其他分享
首页 > 其他分享> > 线性回归的案例01

线性回归的案例01

作者:互联网

这里篇幅可能过长,单独拎出来水一篇。教程视频只有八分钟,我只是看了一下思路,然后自己写了两个小时才写出来,不断修错误,总算是出了和视频中结果近似的结果。

问题描述

根据上节课的一次方程:y=1.477x+0.089
在这个方程的基础上添加噪声生成一百组数据,然后用线性回归的方法近似求取方程的参数。y=wx+b中的w和b
在这里插入图片描述
在这里插入图片描述

整体代码

废话不多说,直接先上代码

#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: Temmie
@file: test.py
@time: 2021/05/10
@desc:
"""
import pandas as pd
from matplotlib import pyplot
import numpy
#防止中文乱码
from matplotlib import font_manager
font_manager.FontProperties(fname='C:\Windows\Fonts\AdobeSongStd-Light.otf')
pyplot.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
pyplot.rcParams['axes.unicode_minus']=False #用来正常显示负号

def loss_cal(w,b,x,y):#损失函数计算
    loss_value=0
    for i in range(len(x)):
        loss_value +=(w*x[i]+b-y[i])**2
    return loss_value/(len(x))
def change_wb(w,b,x,y,learn_rate):#w和b的修改
    gra_w_t=0
    gra_b_t=0a
    for i in range(len(x)):
        gra_w_t +=(2*(w*x[i]+b-y[i])*x[i])/(len(x))
        gra_b_t +=(2*(w*x[i]+b-y[i]))/(len(x))
    w_now=w-learn_rate*gra_w_t
    b_now=b-learn_rate*gra_b_t
    return w_now,b_now
def the_res(x,y,sx,sy,loss):#绘制图像显示结果
    pyplot.scatter(x,y,color='blue',marker='.',label='原始数据')
    pyplot.plot(sx,sy,color='green',label='近似的线性回归线')
    pyplot.figlegend()
    pyplot.xlabel('x', loc='center')
    pyplot.ylabel('y', loc='center')
    pyplot.show()
    pyplot.plot(loss,color='red',label='损失函数')
    pyplot.figlegend()
    pyplot.xlabel('迭代次数', loc='center')
    pyplot.ylabel('损失函数值', loc='center')
    pyplot.show()
#读入数据
df = pd.read_csv('D:\learning_folder\data.csv')
#将不同数据分开
x=df.loc[:,'x'];print(x)
y=df.loc[:,'y'];print(y)
#设置初始的w和b初始化损失函数值loss_v迭代次数w_num和学习速率learn_rate
w=0;b=0;loss_v=[];w_num=1000;learn_rate=0.0001
for i in range(1,w_num,1):
    #计算损失函数
    loss_v.append(loss_cal(w,b,x,y))
    #迭代更改w和b
    w,b=change_wb(w,b,x,y,learn_rate)
print('w:',w,'\t','b:',b,'\t','final_loss:',loss_v[-1])
#绘制结果
sx=numpy.array(list(range(20,80,1)),dtype='float')
sy=w*sx+b
the_res(x,y,sx,sy,loss_v[1:])

代码解析

导入模块

import pandas as pd
from matplotlib import pyplot
import numpy
from matplotlib import font_manager

第一个pandas用来读入数据文件(.csv文件),对数据进行处理;
第二个pyplot进行画图;
第三个numpy也是数据处理;
第四个是防止中文乱码;

防止中文乱码

固定的内容

from matplotlib import font_manager
font_manager.FontProperties(fname='C:\Windows\Fonts\AdobeSongStd-Light.otf')
pyplot.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
pyplot.rcParams['axes.unicode_minus']=False #用来正常显示负号

初始化内容

这里需要给出初始的w和b来进行迭代计算,迭代的代数,学习速率,初始化一些空的列表方便存储迭代过程中我们希望查询的内容,例如损失函数值的变化等。

#设置初始的w和b初始化损失函数值loss_v迭代次数w_num和学习速率learn_rate
w=0;b=0;loss_v=[];w_num=1000;learn_rate=0.0001

循环迭代

for i in range(1,w_num,1):
    #计算损失函数
    loss_v.append(loss_cal(w,b,x,y))
    #迭代更改w和b
    w,b=change_wb(w,b,x,y,learn_rate)

不要慌,这里面的loss_cal和change_wb是我们自己写的函数

函数loss_cal

这是计算损失函数的一个函数
在这里插入图片描述
这里要说明一点:为了防止梯度下降的值频繁改变,我们可以通过分块求取平均值的方法防止参数不断大幅抖动,所以我们这里是对着100组数据求取瞬时函数值后取平均再改变参数w和b的

def loss_cal(w,b,x,y):#损失函数计算
    loss_value=0
    for i in range(len(x)):
        loss_value +=(w*x[i]+b-y[i])**2
    return loss_value/(len(x))

函数change_wb(w,b,x,y,learn_rate)

这是使用梯度下降来更新w和b的函数
在这里插入图片描述
注意,对w和b求梯度的时候求导的分别是w和b,其他要作为常数处理

def change_wb(w,b,x,y,learn_rate):#w和b的修改
    gra_w_t=0
    gra_b_t=0a
    for i in range(len(x)):
        gra_w_t +=(2*(w*x[i]+b-y[i])*x[i])/(len(x))
        gra_b_t +=(2*(w*x[i]+b-y[i]))/(len(x))
    w_now=w-learn_rate*gra_w_t
    b_now=b-learn_rate*gra_b_t
    return w_now,b_now

迭代结束显示重要参数

print('w:',w,'\t','b:',b,'\t','final_loss:',loss_v[-1])

画图the_res(x,y,sx,sy,loss)

这部分参考我写的plot部分即可,如需要,去python栏下清单查询

def the_res(x,y,sx,sy,loss):#绘制图像显示结果
    pyplot.scatter(x,y,color='blue',marker='.',label='原始数据')
    pyplot.plot(sx,sy,color='green',label='近似的线性回归线')
    pyplot.figlegend()
    pyplot.xlabel('x', loc='center')
    pyplot.ylabel('y', loc='center')
    pyplot.show()
    pyplot.plot(loss,color='red',label='损失函数')
    pyplot.figlegend()
    pyplot.xlabel('迭代次数', loc='center')
    pyplot.ylabel('损失函数值', loc='center')
    pyplot.show()

本来y轴可以进行非线性画图的,但是那个是开发人员和高级用户搞的,以后能搞懂再补效果更好的图。实际上损失函数值并非直线
在这里插入图片描述
在这里插入图片描述

标签:loss,01,函数,pyplot,gra,案例,rate,learn,线性
来源: https://blog.csdn.net/Temmie1024/article/details/116592577