简单线性回归算法
2020-04-17 本文已影响0人
元宝的技术日常
1、算法简介
1-1、算法思路
简单线性回归(SimpleLinearRegression)解决的是回归问题,上一篇是分类,这两个概念的区别是标签值-label的差距。分类问题的label一般是类别型,像性别类别,品牌类别... 而回归问题的label一般为为连续数值型,像身高、体重... 监督学习的样本点是由特征-feature和标签-label组成。
简单线性回归之所以有简单两个字,是因为确实比较简单,如果能把样本点映射到空间中,它的作用就是试图用一条线把所有点表示。比如,样本只有一个特征值-x和一个标签值-y,最终的表示则为一元方程:y=ax+b;a、b则为此算法要求出的未知量。
1-2、图示
简单线性回归如图,样本点中间有一条直线,样本之间的关系试图要用一条直线来模拟。
1-3、算法流程
1--- 假如样本只有一个特征值-x和一个标签值-y,最终的表示则为一元方程:y^=ax+b
2--- 试图求出a、b,使得推理出的-y^和实际标签值-y无限接近;判断是否接近的评测标准一般用:均方差-MSE、均方根差-RMSE、平均绝对误差-MAE和R Square。
3--- 采用最小二乘法,可以求出a、b(推导过程见这里)
4--- 不同于kNN算法,在训练数据集上训练好参数a和b的值之后,推理/预测的时候,只需要使用学习到的参数a和b对每一个待推理/预测的样本进行计算就好了。这就是一个典型的参数学习算法。
1-4、优缺点
1-4-1、优点
a、思想简单,实现容易;
b、结果有可解释性;
c、有强大的非线性模型的基础。
1-4-2、缺点
a、在数学模型中表示一根直线,而现实环境中很多的数据,例如房价,销售涨跌都是曲线结构的,使得推理/预测率低;
b、难以很好地表达高度复杂的数据。
2、实践
2-1、采用bobo老师创建简单测试用例
import numpy as np
import matplotlib.pyplot as plt
# 创建测试数据
x = np.array([1., 2., 3., 4., 5.])
y = np.array([1., 3., 2., 3., 5.])
plt.scatter(x, y)
plt.axis([0, 6, 0, 6])
plt.show() #见plt.show1
plt.show1
x_mean = np.mean(x)
y_mean = np.mean(y)
# 最小化误差的平方,用最小二乘法求解:
num = 0.0
d = 0.0
for x_i, y_i in zip(x, y):
num += (x_i - x_mean) * (y_i - y_mean)
d += (x_i - x_mean) ** 2
a = num/d
b = y_mean - a * x_mean
y_hat = a * x + b
plt.scatter(x, y)
plt.plot(x, y_hat, color='r')
plt.axis([0, 6, 0, 6])
plt.show() #见plt.show2
plt.show2