一、理论
二、数据集
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705
三、代码实现
% Univariate linear regression demo.
% Loads two-column data from ex1data1.txt, fits y ~ theta(1) + theta(2)*x
% with batch gradient descent, plots the fitted line over the data, prints
% two predictions, and visualises the cost J(theta) as a surface and as a
% contour plot. Relies on computeCost and gradientDescent being callable.
clear all; clc;

data = load('ex1data1.txt');
X = data(:, 1); y = data(:, 2);
m = length(y); % number of training examples

plot(X, y, 'rx');

%% =================== Part 3: Gradient descent ===================
fprintf('Running Gradient Descent ...\n')

% Why prepend a column of ones: so that theta(1) (the intercept term) is
% multiplied by 1 when the hypothesis X*theta and the cost J are computed.
X = [ones(m, 1), data(:,1)]; % Add a column of ones to x
theta = zeros(2, 1); % initialize fitting parameters

% Some gradient descent settings
iterations = 1500;
alpha = 0.01;

% compute and display initial cost (no trailing semicolon, so it is echoed)
computeCost(X, y, theta)

% run gradient descent
[theta, J_history] = gradientDescent(X, y, theta, alpha, iterations);

hold on; % keep previous plot visible
plot(X(:,2), X*theta, '-')
legend('Training data', 'Linear regression')
hold off % don't overlay any more plots on this figure

% Predict values for population sizes of 35,000 and 70,000
% (inputs 3.5 and 7; the prediction is scaled by 10000 when printed)
predict1 = [1, 3.5] * theta;
fprintf('For population = 35,000, we predict a profit of %f\n',...
    predict1*10000);
predict2 = [1, 7] * theta;
fprintf('For population = 70,000, we predict a profit of %f\n',...
    predict2*10000);

% Grid over which we will calculate J
theta0_vals = linspace(-10, 10, 100);
theta1_vals = linspace(-1, 4, 100);

% initialize J_vals to a matrix of 0's
J_vals = zeros(length(theta0_vals), length(theta1_vals));

% Fill out J_vals
for i = 1:length(theta0_vals)
    for j = 1:length(theta1_vals)
        t = [theta0_vals(i); theta1_vals(j)];
        J_vals(i,j) = computeCost(X, y, t);
    end
end

% Because of the way meshgrids work in the surf command, we need to
% transpose J_vals before calling surf, or else the axes will be flipped
J_vals = J_vals';

% Surface plot
figure;
surf(theta0_vals, theta1_vals, J_vals)
xlabel('\theta_0'); ylabel('\theta_1');

% Contour plot
figure;
% Plot J_vals as 20 contours spaced logarithmically between 0.01 and 1000:
% logspace(-2, 3, 20) yields 20 levels that are powers of 10 from 10^-2 to
% 10^3, which sets both the range and the spacing of the contour levels.
contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20))
xlabel('\theta_0'); ylabel('\theta_1');
hold on;
% Mark the parameters found by gradient descent on the contour plot
plot(theta(1), theta(2), 'rx', 'MarkerSize', 10, 'LineWidth', 2);
...................
function J = computeCost(X, y, theta)
%COMPUTECOST Squared-error cost for linear regression.
%   J = COMPUTECOST(X, y, theta) returns
%       J = 1/(2*m) * sum((X*theta - y).^2)
%   where m = length(y).
%
%   X      m-by-n design matrix, one row per training example.
%   y      m-element vector of targets.
%   theta  n-element parameter vector.
%
%   Why divide by 2*m: the 2 cancels the factor produced when the squared
%   term is differentiated during the parameter update, and the m averages
%   over the examples so J does not grow with the size of the data set.

m = length(y); % number of training examples

% Vectorized residuals; (:) forces column vectors so row or column inputs
% for y and theta are both accepted, and any number of features works.
residual = X * theta(:) - y(:);

J = sum(residual .^ 2) / (2 * m);
end
......
function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
%GRADIENTDESCENT Batch gradient descent for linear regression.
%   [theta, J_history] = GRADIENTDESCENT(X, y, theta, alpha, num_iters)
%   performs num_iters update steps
%       theta := theta - alpha * (1/m) * X' * (X*theta - y)
%   starting from the supplied theta, with m = length(y).
%
%   X          m-by-n design matrix, one row per training example.
%   y          m-element vector of targets.
%   theta      n-element initial parameter vector.
%   alpha      learning rate.
%   num_iters  number of iterations to run.
%
%   Returns the final theta (same shape as the input theta) and J_history,
%   a num_iters-by-1 vector holding the cost after each update.

m = length(y); % number of training examples
J_history = zeros(num_iters, 1);

for iter = 1:num_iters
    % The gradient is recomputed from scratch on every iteration. Keeping
    % running sums that are never reset would carry the previous
    % iteration's gradient into the current step.
    % The division by m is applied once, after summing over all examples.
    grad = (X' * (X * theta(:) - y(:))) / m;

    % All components are updated simultaneously from the same gradient;
    % assigning through (:) preserves the caller's shape of theta.
    theta(:) = theta(:) - alpha * grad;

    % Record the cost after this update
    J_history(iter) = computeCost(X, y, theta);
end
end
四、运行结果