基于pytorch的linear Regression

2020-09-10  本文已影响0人  骆旺达

线性回归模型

线性回归是分析一个变量与另外一(多)个变量之间关系的方法。因变量是 y,自变量是 x,关系线性:
y=wx+b
任务就是求解 wb

我们的求解步骤:
1、确定模型:Model => y = wx + b
2、选择损失函数:这里用 MSE
MSE=\frac{1}{M} \sum_{i=1}^M(y-y_{pred} ) ^2
3、求解梯度并更新 w,b
w = w - lr * w.grad
b = b - lr * b.grad

# 首先我们得有训练样本X,Y, 这里我们随机生成
# 随机生成20个(0,1)*10的torch向量
x = torch.rand(20, 1) * 10
# (随机生成20个均值为0,方差为1的torch向量+5)+2*x
y = 2 * x + (5 + torch.randn(20, 1))

# 构建线性回归函数的参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)   # 这俩都需要求梯度

for iteration in range(100):
 # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)
 
    # 计算loss
    loss = (0.5 * (y-y_pred)**2).mean()
    
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)    # 这种_的加法操作时从自身减,相当于-=
    w.data.sub_(lr * w.grad)

    # 梯度清零
    w.grad.data.zero_()
    b.grad.data.zero_()

print(w.data, b.data)

绘图

import matplotlib.pyplot as  plt
wx = torch.mul(w,x)
y_pred = torch.add(wx,b)
plt.scatter(x,y)
plt.plot(x,y_pred.data,"r")
image.png

来自:https://mp.weixin.qq.com/s/cZ4LpzVKS_JFUw1hEwbm5g

上一篇下一篇

猜你喜欢

热点阅读