线性回归与梯度下降实现

2020-04-12  本文已影响0人  罗泽坤

先创建一个数据样本集合

Y = 2X+E+3,(E为标准正态分布误差)

import numpy as np
import pandas as pd
from pandas import DataFrame,Series
import torch
e = np.random.randn(100)
#print(data)
x = np.random.uniform(1,10,size=(100,1)) #从均匀分布中100个取值

数据清洗连接操作

X = DataFrame(data=x) 
print(X)
E = DataFrame(data=e)
print(E)
Y = 2*X+E+3
#Y = Y.rename(columns={0:'1'})
Y.columns  = ['1']
print(Y)
Dt = pd.concat([X,Y.reindex(X.index)],axis=1)#将样本数据
print(Dt)
Dt.to_csv('LinR.csv',index = False,header = False) #去掉头索引和列索引写入LinR.csv文件
           0
0   7.959296
1   2.914118
2   6.531930
3   6.033350
4   3.329694
..       ...
95  6.689540
96  9.719845
97  1.911316
98  6.540799
99  2.833886

[100 rows x 1 columns]
           0
0  -0.988615
1  -0.298217
2   1.846047
3   1.982957
4   0.644001
..       ...
95 -1.317616
96  0.844359
97  0.475709
98 -0.182511
99 -0.075437

[100 rows x 1 columns]
            1
0   17.929977
1    8.530018
2   17.909907
3   17.049656
4   10.303389
..        ...
95  15.061463
96  23.284048
97   7.298340
98  15.899087
99   8.592336

[100 rows x 1 columns]
           0          1
0   7.959296  17.929977
1   2.914118   8.530018
2   6.531930  17.909907
3   6.033350  17.049656
4   3.329694  10.303389
..       ...        ...
95  6.689540  15.061463
96  9.719845  23.284048
97  1.911316   7.298340
98  6.540799  15.899087
99  2.833886   8.592336

[100 rows x 2 columns]
import matplotlib.pyplot as plt
sample = np.genfromtxt('LinR.csv',delimiter=',')
print(sample)
x = sample[:,0]  #提取出自变量
y = sample[:,1]  #提取出函数值
[[ 7.95929583 17.92997689]
 [ 2.91411788  8.5300183 ]
 [ 6.53192986 17.90990674]
 [ 6.0333498  17.04965623]
 [ 3.32969394 10.30338909]
 [ 9.78091003 24.11104012]
 [ 5.50414492 15.01162983]
 [ 2.3905083   7.98267035]
 [ 9.10073718 20.63593365]
 [ 8.6771435  20.00840521]
 [ 8.59297259 20.47870044]
 [ 3.71421162 13.64842939]
 [ 9.53717825 20.65189706]
 [ 1.65102104  5.72905434]
 [ 2.25873778  6.97617708]
 [ 4.22231735 11.55259733]
 [ 1.53631731  4.94287203]
 [ 5.34000791 13.16707094]
 [ 6.17545718 15.06864934]
 [ 4.90415263 12.41321868]
 [ 7.75327441 18.6794576 ]
 [ 5.69725718 12.84726204]
 [ 3.56909863 10.5966961 ]
 [ 9.94681425 21.33072322]
 [ 3.59799413  9.56658453]
 [ 3.87835658 11.11208729]
 [ 5.03228019 13.68579256]
 [ 9.48569476 22.77879716]
 [ 4.66228634 12.75729249]
 [ 2.4908625   7.26331882]
 [ 4.81153515 12.35458447]
 [ 9.58406536 22.38109883]
 [ 4.28121897 10.74117052]
 [ 2.45652558  8.71965016]
 [ 1.9736664   6.4809288 ]
 [ 9.51051954 22.72786535]
 [ 6.75959417 17.81727331]
 [ 5.61578003 14.22442185]
 [ 3.85178666  9.75542375]
 [ 5.68632445 13.52993683]
 [ 5.4966711  15.11321743]
 [ 1.11686002  6.05160856]
 [ 8.82731607 19.98729226]
 [ 3.23499699  8.89761767]
 [ 3.22340228  9.82438038]
 [ 3.79773535  9.75677403]
 [ 7.91327594 19.74140776]
 [ 6.96384595 16.55981935]
 [ 8.04811251 18.00044451]
 [ 5.95047664 15.91780519]
 [ 6.86179131 14.57196115]
 [ 9.92403963 24.85022378]
 [ 2.28427921  5.8765021 ]
 [ 3.65702265  9.91875919]
 [ 2.94110273 10.2254981 ]
 [ 1.44233169  6.68000409]
 [ 8.11509328 19.30801919]
 [ 9.61113455 22.19692088]
 [ 5.31715886 10.11119983]
 [ 4.18841633 10.38081049]
 [ 3.38696086 10.89271924]
 [ 6.79149049 16.38415715]
 [ 9.43502226 22.16533282]
 [ 2.39967181  8.2302037 ]
 [ 7.36239661 17.6441651 ]
 [ 2.94883598  9.58722116]
 [ 8.07898524 18.88577106]
 [ 7.27030879 17.48551352]
 [ 3.82638905 10.42645614]
 [ 2.21195155  6.25717108]
 [ 2.62373052  7.75337578]
 [ 4.2888883  13.05445549]
 [ 4.79372014 13.47932057]
 [ 7.27504625 18.03400179]
 [ 7.6804564  18.91495491]
 [ 4.96708467 11.14166529]
 [ 1.3403768   3.22305645]
 [ 3.31317609 10.54117355]
 [ 8.83562974 21.6072844 ]
 [ 5.2485316  15.68878013]
 [ 3.09095248  8.00604601]
 [ 1.40126741  5.61514191]
 [ 5.97243366 16.6769609 ]
 [ 9.27273317 21.88579855]
 [ 1.10719167  7.34636051]
 [ 7.64405604 16.02157126]
 [ 1.46871259  7.17774005]
 [ 8.04512646 19.50343429]
 [ 6.44703966 14.47562085]
 [ 2.53241481  6.62603648]
 [ 7.88971862 19.72029653]
 [ 5.25505582 13.99824232]
 [ 6.42160172 17.67089747]
 [ 4.71902614 10.58448458]
 [ 3.69982573 11.40370471]
 [ 6.68953965 15.06146343]
 [ 9.7198446  23.28404806]
 [ 1.91131556  7.29834036]
 [ 6.54079876 15.89908666]
 [ 2.83388631  8.59233584]]
# 做出y = x*2+e+3的散点图
plt.scatter(x, y, marker = 'o',color = 'red', s = 40 )
plt.show()
output_5_0.png
#损失函数
def loss_function(data,b,w):
    Total_Error = 0
    for i in range(len(data)):
        x = data[i][0]
        y = data[i][1]
        Total_Error += ((w*x+b)-y)**2
    return Total_Error/float(len(data))

# 梯度下降步骤
def gradient_step(b_current,w_current,data,learning_rate):
    #初始化梯度
    b_grad = 0
    w_grad = 0
    N = float(len(data))
    for i in range(len(data)):
        x = data[i][0]
        y = data[i][1]
        b_grad +=  (2/N)*(w_current*x+b_current-y)
        w_grad +=  x*(2/N)*(w_current*x+b_current-y)
    new_b  = b_current - learning_rate*b_grad
    #print(new_b)
    #w_current -= learning_rate*w_grad
    new_w = b_current - learning_rate*w_grad
    #print(new_w)
    return [new_b,new_w]
#梯度下降迭代
def gradient_descent_iter(initial_b,initial_w,iteration_num,data,learning_rate):
    b_current = initial_b
    w_current = initial_w
    for i in range(iteration_num):
        b_current,w_current = gradient_step(b_current,w_current,np.array(data),learning_rate)
    return [b_current,w_current]
#运行函数
def run_fun():
    initial_b = 0
    initial_w = 0
    iteration_num = 20000
    learning_rate = 0.0001
    data = np.genfromtxt('LinR.csv',delimiter=',')
    print('Starting gradient decent at b = {0},w={1},loss_value = {2}'.format(initial_b,initial_w,loss_function(data,initial_b,initial_w)))
    print('Running...')
    [b,w] = gradient_descent_iter(initial_b,initial_w,iteration_num,data,learning_rate)
    print('After {0} iterations b = {1},w = {2},loss_value = {3}'.format(iteration_num,b,w,loss_function(data,b,w)))
    
#if __name__ == '__main__'():
run_fun()

    
Starting gradient decent at b = 0,w=0,loss_value = 217.6852494453724
Running...
After 20000 iterations b = 2.1616063195346547,w = 2.1614204809921613,loss_value = 1.4195063550286242

可以看到迭代出的b值与w值与预先设定的2和3很接近损失函数值也很小了

import matplotlib.pyplot as plt
sample = np.genfromtxt('LinR.csv',delimiter=',')
print(sample)
x_1 = sample[:,0]
y_1 = sample[:,1]
[[ 7.95929583 17.92997689]
 [ 2.91411788  8.5300183 ]
 [ 6.53192986 17.90990674]
 [ 6.0333498  17.04965623]
 [ 3.32969394 10.30338909]
 [ 9.78091003 24.11104012]
 [ 5.50414492 15.01162983]
 [ 2.3905083   7.98267035]
 [ 9.10073718 20.63593365]
 [ 8.6771435  20.00840521]
 [ 8.59297259 20.47870044]
 [ 3.71421162 13.64842939]
 [ 9.53717825 20.65189706]
 [ 1.65102104  5.72905434]
 [ 2.25873778  6.97617708]
 [ 4.22231735 11.55259733]
 [ 1.53631731  4.94287203]
 [ 5.34000791 13.16707094]
 [ 6.17545718 15.06864934]
 [ 4.90415263 12.41321868]
 [ 7.75327441 18.6794576 ]
 [ 5.69725718 12.84726204]
 [ 3.56909863 10.5966961 ]
 [ 9.94681425 21.33072322]
 [ 3.59799413  9.56658453]
 [ 3.87835658 11.11208729]
 [ 5.03228019 13.68579256]
 [ 9.48569476 22.77879716]
 [ 4.66228634 12.75729249]
 [ 2.4908625   7.26331882]
 [ 4.81153515 12.35458447]
 [ 9.58406536 22.38109883]
 [ 4.28121897 10.74117052]
 [ 2.45652558  8.71965016]
 [ 1.9736664   6.4809288 ]
 [ 9.51051954 22.72786535]
 [ 6.75959417 17.81727331]
 [ 5.61578003 14.22442185]
 [ 3.85178666  9.75542375]
 [ 5.68632445 13.52993683]
 [ 5.4966711  15.11321743]
 [ 1.11686002  6.05160856]
 [ 8.82731607 19.98729226]
 [ 3.23499699  8.89761767]
 [ 3.22340228  9.82438038]
 [ 3.79773535  9.75677403]
 [ 7.91327594 19.74140776]
 [ 6.96384595 16.55981935]
 [ 8.04811251 18.00044451]
 [ 5.95047664 15.91780519]
 [ 6.86179131 14.57196115]
 [ 9.92403963 24.85022378]
 [ 2.28427921  5.8765021 ]
 [ 3.65702265  9.91875919]
 [ 2.94110273 10.2254981 ]
 [ 1.44233169  6.68000409]
 [ 8.11509328 19.30801919]
 [ 9.61113455 22.19692088]
 [ 5.31715886 10.11119983]
 [ 4.18841633 10.38081049]
 [ 3.38696086 10.89271924]
 [ 6.79149049 16.38415715]
 [ 9.43502226 22.16533282]
 [ 2.39967181  8.2302037 ]
 [ 7.36239661 17.6441651 ]
 [ 2.94883598  9.58722116]
 [ 8.07898524 18.88577106]
 [ 7.27030879 17.48551352]
 [ 3.82638905 10.42645614]
 [ 2.21195155  6.25717108]
 [ 2.62373052  7.75337578]
 [ 4.2888883  13.05445549]
 [ 4.79372014 13.47932057]
 [ 7.27504625 18.03400179]
 [ 7.6804564  18.91495491]
 [ 4.96708467 11.14166529]
 [ 1.3403768   3.22305645]
 [ 3.31317609 10.54117355]
 [ 8.83562974 21.6072844 ]
 [ 5.2485316  15.68878013]
 [ 3.09095248  8.00604601]
 [ 1.40126741  5.61514191]
 [ 5.97243366 16.6769609 ]
 [ 9.27273317 21.88579855]
 [ 1.10719167  7.34636051]
 [ 7.64405604 16.02157126]
 [ 1.46871259  7.17774005]
 [ 8.04512646 19.50343429]
 [ 6.44703966 14.47562085]
 [ 2.53241481  6.62603648]
 [ 7.88971862 19.72029653]
 [ 5.25505582 13.99824232]
 [ 6.42160172 17.67089747]
 [ 4.71902614 10.58448458]
 [ 3.69982573 11.40370471]
 [ 6.68953965 15.06146343]
 [ 9.7198446  23.28404806]
 [ 1.91131556  7.29834036]
 [ 6.54079876 15.89908666]
 [ 2.83388631  8.59233584]]
#将拟合曲线与样本做比较
import matplotlib.pyplot as plt
import numpy as np
plt.scatter(x_1, y_1, marker = 'o',color = 'red', s = 40,label = 'sampelSC' )
#plt.show()
x_2 = np.arange(0,10)
w = 2.1614204809921613
b = u2.1616063195346547
y_2 = w*x+b
plt.plot(x_2,y_2,label='Fitting Curve')
plt.show()
output_11_0.png

可以看到曲线拟合的效果还是很不错的


上一篇下一篇

猜你喜欢

热点阅读