%% 画数据散点图
%第214对数据有问题,先删除
Data = xlsread('D:\Matlab\test\数据集\train.csv');
x = Data(:,1);
y = Data(:,2);
%均值归一化,在这里的数据中不进行均值归一化会造成运算为 NaN 的结果 
x = (x- mean(x))./ (max(x)-min(x));
y = (y- mean(y))./ (max(y)-min(y));
plot(x,y,'.');
hold on;

%% 参数初始化
m = length(x);%样本数量
theta = [1;0];%theta初始化
%预先分配空间以节省运行时间
X = [ones(m,1),x];%特征值的增广矩阵
%梯度下降法
pd = zeros(m,2);%J对theta的偏导矩阵 
cost = zeros(m,1);
alpha = 0.1;
itration = 10000;

%% 梯度下降法迭代寻找最小值
for i = 1:itration
    h = X*theta;
    cost = (y-h).*(y-h);
    pd(:,1) = (h-y).*X(:,1);
    pd(:,2) = (h-y).*X(:,2);
    theta(1) = theta(1) - alpha/m*sum(pd(:,1));
    theta(2) = theta(2) - alpha/m*sum(pd(:,2));
    J = 1/(2*m)*sum(cost);
end

%% 正规方程法
% theta = ((X'*X)^(-1))*X'*y;
% n较小时使用正规方程法简单且速度较快但要注意X'X的可逆性问题;
% n较大时使用梯度下降法运算更快。

%% 画线
X = min(x):0.01:max(x);
Y = theta(1)+theta(2)*X;
plot(X,Y,'LineWidth',2);

效果图:

 相关知识点和公式可参考我的另一篇博客中的线性回归部分:

机器学习基础----基于吴恩达机器学习课程的笔记_m0_61112058的博客-CSDN博客

数据集:

线性回归 数据集 - DataFountain

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐