多项式回归算法

2020-04-18  本文已影响0人  元宝的技术日常

1、算法简介

1-1、算法思路

上一篇,简单线性回归算法的缺点之一是对于标签值是曲线结构的走势,很难拟合。那多项式回归算法出现,就是使得线性回归算法可以对非线性的数据进行回归分析,一个方面的优化、改进。

对于拟合出非线性关系,一般可以想到曲线;提起曲线,就得说在初中时学习的一元二次方程--y = ax^2 + bx + c。

简单线性回归算法的特征值是一次幂,如果想要形成一元二次方程的效果,就得需要在特征值中添加二次幂;这样的话,对于最后的解就变为了a、b、c。


1-2、图示

多项式回归

如图,样本点中间有一条曲线,样本之间的关系试图要用一条曲线来拟合。


1-3、算法流程
简单线性回归算法


1-4、优缺点

1-4-1、优点

a、拟合非线性的数据
b、理解与解释都十分直观
c、可以通过正则化来降低过拟合的风险
d、容易使用随机梯度下降和新数据更新模型权重

1-4-2、缺点

a、需要处理异常值
b、较简单回归算法复杂、困难
c、训练时间会增加


2、实践

2-1、采用bobo老师创建简单测试用例

import numpy as np 
import matplotlib.pyplot as plt

# 创建测试数据
x = np.random.uniform(-3, 3, size=100)
X = x.reshape(-1, 1)
y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)

plt.scatter(x, y)
plt.show() #见plt.show0
plt.show0
from sklearn.linear_model import LinearRegression

# 使用简单线性回归训练
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_predict = lin_reg.predict(X)

plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show() # 见plt.show1
plt.show1
X2 = np.hstack([X, X**2]) # 添加一个二次幂特征
X2.shape
# (100, 2)

# 多项式回归
lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_predict2 = lin_reg2.predict(X2)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r') 
#要对x排序,否则是混乱的折线图
plt.show() # 见plt.show2
plt.show2
lin_reg2.coef_ # 特征和二次幂特征的系数
# array([0.85348244, 0.481137  ])

lin_reg2.intercept_ # 截距-b
# 2.032352537360585
上一篇 下一篇

猜你喜欢

热点阅读