【阿旭机器学习实战】[7] Ridge Regression: Basic Principle and How to Choose λ
2022-11-16
阿旭123
The 【阿旭机器学习实战】 series introduces various machine learning algorithms and models with hands-on case studies. Likes and follows are welcome; let's learn together.
This post gives a brief introduction to ridge regression, an improved variant of the linear regression model, and uses an example to walk through its basic usage and the principle for choosing lambda.
From Linear Regression to Ridge Regression
1. Principle
What should we do if the data has more features than sample points? Can ordinary linear regression still be used for prediction?
The answer is no: the data matrix X is then not of full column rank, so X^T X is singular and cannot be inverted when solving the normal equations.
To solve this problem, statisticians introduced ridge regression.
As a shrinkage method, ridge regression shrinks the unimportant parameters toward zero, which makes the data easier to interpret. Moreover, compared with plain linear regression, shrinkage can achieve better predictive performance.
[Note] In ridge regression, model performance is determined not only by the data and the algorithm, but also by the shrinkage term lambda * I.
Ridge regression is least squares with an added second-order (L2) regularization term, the lambda*I in the normal equations. It is mainly suited to cases of severe overfitting or multicollinearity among the variables. Ridge estimates are biased, and this bias is accepted in exchange for a smaller variance.
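In equation form (a standard statement of ridge regression, consistent with the closed-form solution quoted in the code comments of section 4.2):

$$\min_{w}\ \lVert Xw - y \rVert_2^2 + \lambda \lVert w \rVert_2^2 \quad\Longrightarrow\quad \hat{w} = (X^{\top}X + \lambda I)^{-1} X^{\top} y$$

For any λ > 0, the matrix X^T X + λI is positive definite and therefore always invertible, which is exactly what fixes the rank-deficiency problem described above.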
2. The Problems Ridge Regression Mainly Handles
Ridge regression is mainly used for the following two kinds of problems:
1. Fewer data points than feature variables
2. Collinearity among the variables (the least-squares coefficients are unstable, with very large variance)
3. Summary
1. Ridge regression can handle the case where the number of features exceeds the number of samples
2. As a shrinkage algorithm, ridge regression can indicate which features are important and which are not, somewhat like a dimensionality-reduction effect
3. A shrinkage algorithm can be viewed as adding bias to a model in exchange for a reduction in variance (see the simulation sketch below)
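To make point 3 concrete, here is a minimal simulation sketch (the data-generating setup is invented for illustration and is not from the original example): on nearly collinear data, ridge coefficients fluctuate far less across noisy resamples than the least-squares coefficients.

import numpy as np
from sklearn.linear_model import LinearRegression, Ridge

rng = np.random.default_rng(0)
n, trials = 50, 200
ols_w, ridge_w = [], []
for _ in range(trials):
    x1 = rng.normal(size=n)
    x2 = x1 + 0.01 * rng.normal(size=n)   # nearly collinear with x1
    X = np.column_stack([x1, x2])
    y = x1 + x2 + rng.normal(size=n)      # true coefficients: (1, 1)
    ols_w.append(LinearRegression().fit(X, y).coef_)
    ridge_w.append(Ridge(alpha=1.0).fit(X, y).coef_)

print("OLS coefficient std per feature:  ", np.std(ols_w, axis=0))   # large
print("Ridge coefficient std per feature:", np.std(ridge_w, axis=0)) # much smaller

The ridge estimates are pulled slightly away from the true values (bias), but their spread across resamples (variance) drops dramatically.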
4. A Ridge Regression Example
# Import the ridge regression model
from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression
import numpy as np
# Manually create training data: 4 samples with 5 features each
x_train = np.array([[1,1,2,1,4],[2,3,4,1,2],[1,3,2,4,1],[2,1,3,4,5]])
y_train = np.array([1,2,3,4])
# Test data
x_test = np.array([[1,2,3,1,2]])
4.1 Prediction with Ordinary Linear Regression
linear = LinearRegression()
linear.fit(x_train,y_train)
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
linear.predict(x_test)  # This prediction is unreliable: with more features than samples, X^T X is singular
array([1.20837809])
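We can verify the rank deficiency directly (a quick check that is not in the original post):

# X^T X is 5x5, but its rank is at most 4 (the number of samples),
# so it is singular and the least-squares solution is not unique
print(np.linalg.matrix_rank(x_train.T @ x_train))  # 4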
4.2 Prediction with Ridge Regression
ridge = Ridge(alpha=1000)
# alpha is the lambda in the lambda*I term
ridge.fit(x_train,y_train)
Ridge(alpha=1000, copy_X=True, fit_intercept=True, max_iter=None,
normalize=False, random_state=None, solver='auto', tol=0.001)
ridge.predict(x_test)
array([2.48971789])
# Regression coefficients: W = (X^T*X + lambda*I)^(-1) * X^T*Y
# Prediction: Y = W^T * X + b
ridge.coef_  # How strongly the coefficients shrink depends on alpha: the larger alpha is, the stronger the shrinkage
array([9.97254560e-04, 5.40720570e-06, 5.06027985e-04, 5.94723394e-03,
9.89143751e-04])
# 普通线性回归系数
linear.coef_
array([ 0.30827068, 0.11385607, 0.36949517, 0.72824919, 0.13748657])
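As a sanity check on the closed-form formula quoted above, the sketch below (not in the original post) recomputes W directly. Note that it matches sklearn only with fit_intercept=False, because Ridge centers the data before applying the penalty when an intercept is fitted:

lam = 1000.0
# Closed form: W = (X^T X + lambda*I)^(-1) X^T y
w_closed = np.linalg.solve(x_train.T @ x_train + lam * np.eye(x_train.shape[1]),
                           x_train.T @ y_train)
w_sklearn = Ridge(alpha=lam, fit_intercept=False).fit(x_train, y_train).coef_
print(np.allclose(w_closed, w_sklearn))  # True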
4.3 The Core Problem in Ridge Regression: Finding a Suitable alpha
The general principle for choosing alpha: pick a value at which the ridge estimates of all the regression coefficients have become basically stable.
# Build the training set: this construction yields the 10x10 Hilbert matrix,
# a classic severely ill-conditioned (highly collinear) matrix
x_train = 1/(np.arange(1,11)+ np.arange(0,10).reshape((10,1)))
x_train
array([[1. , 0.5 , 0.33333333, 0.25 , 0.2 ,
0.16666667, 0.14285714, 0.125 , 0.11111111, 0.1 ],
[0.5 , 0.33333333, 0.25 , 0.2 , 0.16666667,
0.14285714, 0.125 , 0.11111111, 0.1 , 0.09090909],
[0.33333333, 0.25 , 0.2 , 0.16666667, 0.14285714,
0.125 , 0.11111111, 0.1 , 0.09090909, 0.08333333],
[0.25 , 0.2 , 0.16666667, 0.14285714, 0.125 ,
0.11111111, 0.1 , 0.09090909, 0.08333333, 0.07692308],
[0.2 , 0.16666667, 0.14285714, 0.125 , 0.11111111,
0.1 , 0.09090909, 0.08333333, 0.07692308, 0.07142857],
[0.16666667, 0.14285714, 0.125 , 0.11111111, 0.1 ,
0.09090909, 0.08333333, 0.07692308, 0.07142857, 0.06666667],
[0.14285714, 0.125 , 0.11111111, 0.1 , 0.09090909,
0.08333333, 0.07692308, 0.07142857, 0.06666667, 0.0625 ],
[0.125 , 0.11111111, 0.1 , 0.09090909, 0.08333333,
0.07692308, 0.07142857, 0.06666667, 0.0625 , 0.05882353],
[0.11111111, 0.1 , 0.09090909, 0.08333333, 0.07692308,
0.07142857, 0.06666667, 0.0625 , 0.05882353, 0.05555556],
[0.1 , 0.09090909, 0.08333333, 0.07692308, 0.07142857,
0.06666667, 0.0625 , 0.05882353, 0.05555556, 0.05263158]])
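The condition number confirms how extreme the collinearity is (a quick check, not in the original post):

# A condition number on the order of 1e13 means plain least squares is numerically unstable here
print(np.linalg.cond(x_train))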
y_train = np.ones(10)
y_train
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
# Create a series of alpha values (200 points, log-spaced from 1e-2 to 1e5) to use as the ridge shrinkage coefficient
alphas = np.logspace(-2,5,200)
alphas
array([1.00000000e-02, 1.08436597e-02, 1.17584955e-02, 1.27505124e-02,
1.38262217e-02, 1.49926843e-02, 1.62575567e-02, 1.76291412e-02,
1.91164408e-02, 2.07292178e-02, 2.24780583e-02, 2.43744415e-02,
2.64308149e-02, 2.86606762e-02, 3.10786619e-02, 3.37006433e-02,
3.65438307e-02, 3.96268864e-02, 4.29700470e-02, 4.65952567e-02,
5.05263107e-02, 5.47890118e-02, 5.94113398e-02, 6.44236351e-02,
6.98587975e-02, 7.57525026e-02, 8.21434358e-02, 8.90735464e-02,
9.65883224e-02, 1.04737090e-01, 1.13573336e-01, 1.23155060e-01,
1.33545156e-01, 1.44811823e-01, 1.57029012e-01, 1.70276917e-01,
1.84642494e-01, 2.00220037e-01, 2.17111795e-01, 2.35428641e-01,
2.55290807e-01, 2.76828663e-01, 3.00183581e-01, 3.25508860e-01,
3.52970730e-01, 3.82749448e-01, 4.15040476e-01, 4.50055768e-01,
4.88025158e-01, 5.29197874e-01, 5.73844165e-01, 6.22257084e-01,
6.74754405e-01, 7.31680714e-01, 7.93409667e-01, 8.60346442e-01,
9.32930403e-01, 1.01163798e+00, 1.09698580e+00, 1.18953407e+00,
1.28989026e+00, 1.39871310e+00, 1.51671689e+00, 1.64467618e+00,
1.78343088e+00, 1.93389175e+00, 2.09704640e+00, 2.27396575e+00,
2.46581108e+00, 2.67384162e+00, 2.89942285e+00, 3.14403547e+00,
3.40928507e+00, 3.69691271e+00, 4.00880633e+00, 4.34701316e+00,
4.71375313e+00, 5.11143348e+00, 5.54266452e+00, 6.01027678e+00,
6.51733960e+00, 7.06718127e+00, 7.66341087e+00, 8.30994195e+00,
9.01101825e+00, 9.77124154e+00, 1.05956018e+01, 1.14895100e+01,
1.24588336e+01, 1.35099352e+01, 1.46497140e+01, 1.58856513e+01,
1.72258597e+01, 1.86791360e+01, 2.02550194e+01, 2.19638537e+01,
2.38168555e+01, 2.58261876e+01, 2.80050389e+01, 3.03677112e+01,
3.29297126e+01, 3.57078596e+01, 3.87203878e+01, 4.19870708e+01,
4.55293507e+01, 4.93704785e+01, 5.35356668e+01, 5.80522552e+01,
6.29498899e+01, 6.82607183e+01, 7.40196000e+01, 8.02643352e+01,
8.70359136e+01, 9.43787828e+01, 1.02341140e+02, 1.10975250e+02,
1.20337784e+02, 1.30490198e+02, 1.41499130e+02, 1.53436841e+02,
1.66381689e+02, 1.80418641e+02, 1.95639834e+02, 2.12145178e+02,
2.30043012e+02, 2.49450814e+02, 2.70495973e+02, 2.93316628e+02,
3.18062569e+02, 3.44896226e+02, 3.73993730e+02, 4.05546074e+02,
4.39760361e+02, 4.76861170e+02, 5.17092024e+02, 5.60716994e+02,
6.08022426e+02, 6.59318827e+02, 7.14942899e+02, 7.75259749e+02,
8.40665289e+02, 9.11588830e+02, 9.88495905e+02, 1.07189132e+03,
1.16232247e+03, 1.26038293e+03, 1.36671636e+03, 1.48202071e+03,
1.60705282e+03, 1.74263339e+03, 1.88965234e+03, 2.04907469e+03,
2.22194686e+03, 2.40940356e+03, 2.61267523e+03, 2.83309610e+03,
3.07211300e+03, 3.33129479e+03, 3.61234270e+03, 3.91710149e+03,
4.24757155e+03, 4.60592204e+03, 4.99450512e+03, 5.41587138e+03,
5.87278661e+03, 6.36824994e+03, 6.90551352e+03, 7.48810386e+03,
8.11984499e+03, 8.80488358e+03, 9.54771611e+03, 1.03532184e+04,
1.12266777e+04, 1.21738273e+04, 1.32008840e+04, 1.43145894e+04,
1.55222536e+04, 1.68318035e+04, 1.82518349e+04, 1.97916687e+04,
2.14614120e+04, 2.32720248e+04, 2.52353917e+04, 2.73644000e+04,
2.96730241e+04, 3.21764175e+04, 3.48910121e+04, 3.78346262e+04,
4.10265811e+04, 4.44878283e+04, 4.82410870e+04, 5.23109931e+04,
5.67242607e+04, 6.15098579e+04, 6.66991966e+04, 7.23263390e+04,
7.84282206e+04, 8.50448934e+04, 9.22197882e+04, 1.00000000e+05])
# Fit one model per alpha value from the grid above
# A list to hold the regression coefficients from each fit
w = []
# Create the ridge regression model (no intercept, so the trace reflects the coefficients alone)
r = Ridge(fit_intercept=False)
for alpha in alphas:
    # Set a different alpha on the model
    r.set_params(alpha=alpha)
    # Train the model with this alpha
    r.fit(x_train, y_train)
    # Record the fitted coefficients
    w.append(r.coef_)
Plot the ridge trace (the regression coefficients as a function of alpha):
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(12,9))
axes = plt.subplot(111)
# One curve per coefficient; a log-scaled x-axis spreads out the wide alpha range
axes.plot(alphas,w)
axes.set_xscale("log")
[Figure: ridge trace plot, one curve per regression coefficient, with alpha on a log-scaled x-axis]
Observing the ridge trace, the coefficient curves become essentially stable once alpha exceeds about 100, so a suitable alpha lies roughly beyond 100.
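Besides reading the ridge trace by eye, scikit-learn also provides RidgeCV, which selects alpha by cross-validation; a minimal sketch on the same data and alpha grid:

from sklearn.linear_model import RidgeCV

# With the default cv=None, RidgeCV uses efficient leave-one-out generalized cross-validation
rcv = RidgeCV(alphas=alphas, fit_intercept=False)
rcv.fit(x_train, y_train)
print(rcv.alpha_)  # the alpha selected from the grid

The two approaches are complementary: the trace shows where the coefficients stabilize, while cross-validation scores predictive performance directly.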
If this content helped you, a like and a follow would be much appreciated!
More quality content is on the way…