28. 日月光华 Python Data Analysis - Machine Learning - Underfitting
2023-07-27
薛东弗斯
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
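# 30 evenly spaced points on [0, 20]
x = np.linspace(0,20,30)
# cubic trend plus uniform noise in [0, 160) and a constant offset of 10
y = x**3 + np.random.rand(30)*160 + 10
plt.scatter(x,y)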
![](https://img.haomeiwen.com/i3968643/311b1c9a837bdc69.png)
Underfitting: a straight line is too simple for this cubic data
from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x.reshape(-1,1), y)
y_pred = model.predict(x.reshape(-1,1))
plt.scatter(x,y)
plt.plot(x,y_pred, c='r')
![](https://img.haomeiwen.com/i3968643/694a2fdaf8b5990b.png)
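The underfit can also be checked numerically rather than only by eye. The sketch below is an illustration, not part of the original notebook: it rebuilds the same kind of data with a fixed seed (the seed and the names `rng`, `linear`, `residuals`, and `r2_linear` are assumptions) and looks at the residuals of the straight-line fit. If the line underfits a convex curve, the residuals should be positive at both ends and negative in the middle.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Fixed seed is an assumption, added only so the sketch is reproducible.
rng = np.random.default_rng(0)
x = np.linspace(0, 20, 30)
y = x**3 + rng.random(30) * 160 + 10

# scikit-learn expects a 2-D feature matrix of shape (n_samples, n_features).
X = x.reshape(-1, 1)
linear = LinearRegression().fit(X, y)

# Residuals = observed minus predicted; a systematic sign pattern
# (rather than random scatter) indicates the model is too simple.
residuals = y - linear.predict(X)
r2_linear = r2_score(y, linear.predict(X))

print("first, middle, last residual:",
      residuals[0], residuals[len(residuals) // 2], residuals[-1])
print("R^2 of the linear fit:", r2_linear)
```

A residual pattern that bends in one direction like this is a common sign that the model lacks the capacity to follow the trend, which is what the red line in the plot shows.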
from sklearn.preprocessing import PolynomialFeatures
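q4 = PolynomialFeatures(degree=4) # fit with a 4th-degree polynomial
x4 = q4.fit_transform(x.reshape(-1,1))
model4 = LinearRegression() # linear model trained on the polynomial features
model4.fit(x4, y)
y_pred4 = model4.predict(x4)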
plt.scatter(x,y)
plt.plot(x,y_pred4, c='r')
![](https://img.haomeiwen.com/i3968643/e476d4e0ef323b2c.png)
q20 = PolynomialFeatures(degree=20) # fit with a 20th-degree polynomial
x20 = q20.fit_transform(x.reshape(-1,1))
model20 = LinearRegression()
model20.fit(x20, y)
y_pred20 = model20.predict(x20)
plt.scatter(x,y)
plt.plot(x,y_pred20, c='r')
![](https://img.haomeiwen.com/i3968643/b1d0df0f843d6f5f.png)
Overfitting: the fit gets worse instead of better
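One way to compare the three model sizes is to hold out some points and measure error on data the model has not seen. The following is a minimal sketch under several assumptions that are not in the original notebook: a fixed random seed, an even/odd index split into training and held-out points, and a `MinMaxScaler` step in front of `PolynomialFeatures` to keep high powers of `x` numerically manageable. The names `results`, `X_train`, and `X_test` are illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures

# Fixed seed is an assumption, added only so the sketch is reproducible.
rng = np.random.default_rng(0)
x = np.linspace(0, 20, 30)
y = x**3 + rng.random(30) * 160 + 10

# Assumed split: even-indexed points for training, odd-indexed points held out.
X_train, y_train = x[::2].reshape(-1, 1), y[::2]
X_test, y_test = x[1::2].reshape(-1, 1), y[1::2]

results = {}
for degree in (1, 4, 20):
    # Scaling x before expanding it is an assumption of this sketch;
    # the notebook above feeds raw x into PolynomialFeatures.
    model = make_pipeline(
        MinMaxScaler(),
        PolynomialFeatures(degree=degree),
        LinearRegression(),
    )
    model.fit(X_train, y_train)
    results[degree] = {
        "train": mean_squared_error(y_train, model.predict(X_train)),
        "test": mean_squared_error(y_test, model.predict(X_test)),
    }

for degree, mse in results.items():
    print(f"degree={degree:2d}  train MSE={mse['train']:.1f}  test MSE={mse['test']:.1f}")
```

The degree-1 model should show a large error on both sets (underfitting), while degree 4 should do far better on both. For degree 20 the numbers depend on the split and on numerical conditioning, so no specific value is claimed here; a gap between a small training error and a larger held-out error is the usual sign of overfitting.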