28. 日月光华 Python Data Analysis - Machine Learning - Underfitting

2023-07-27  薛东弗斯
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Jupyter/IPython magic that renders figures inside the notebook; drop it in a plain .py script
%matplotlib inline

x = np.linspace(0,20,30)                   # 30 evenly spaced points on [0, 20]
y = x**3 + np.random.rand(30)*160 + 10     # cubic trend plus uniform noise in [10, 170)

plt.scatter(x,y)

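Underfitting: plain linear regression can only draw a straight line, which is too simple for data generated from a cubic (y = x**3 plus noise).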

from sklearn.linear_model import LinearRegression

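model = LinearRegression()                 # plain linear regression: a straight line
model.fit(x.reshape(-1,1), y)              # sklearn expects a 2-D X of shape (n_samples, n_features)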
y_pred = model.predict(x.reshape(-1,1))
plt.scatter(x,y)
plt.plot(x,y_pred, c='r')
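The underfitting can also be checked without looking at the plot. A minimal sketch, assuming the same data recipe as above with a fixed random seed added for reproducibility (the seed and the variable names are illustrative, not from the original): the R² score alone may still look acceptable, but the residuals of the straight line have a systematic sign pattern instead of random scatter.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Same recipe as the article's data; the fixed seed is an addition for reproducibility
rng = np.random.default_rng(0)
x = np.linspace(0, 20, 30)
y = x**3 + rng.random(30) * 160 + 10

X = x.reshape(-1, 1)
model = LinearRegression().fit(X, y)
residuals = y - model.predict(X)

# Training R^2 of the straight line
print(f"linear R^2 = {model.score(X, y):.3f}")
# Residual signs at the left end, near the middle and at the right end of the x range
print("residual signs (left, middle, right):", np.sign(residuals[[0, 15, -1]]))
```

Positive residuals at both ends and a negative one in the middle mean the line undershoots the data at the edges and overshoots in the center: the model is missing the curvature, which is what underfitting looks like.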
from sklearn.preprocessing import PolynomialFeatures

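q4 = PolynomialFeatures(degree=4)      # expand x into 1, x, x^2, x^3, x^4 for a 4th-degree fit
x4 = q4.fit_transform(x.reshape(-1,1))

model4 = LinearRegression()            # a separate linear model fitted on the polynomial features
model4.fit(x4, y)
y_pred4 = model4.predict(x4)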
plt.scatter(x,y)
plt.plot(x,y_pred4, c='r')
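PolynomialFeatures does not fit anything by itself; it only expands each x into powers of x, and an ordinary LinearRegression is then fitted on those columns. A small check of what the transform produces (the input values 2 and 20 are chosen only for illustration): x = 2 at degree 4 becomes the row 1, 2, 4, 8, 16, and degree 20 yields 21 columns.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# A single sample x = 2 expanded to degree 4: columns are 1, x, x^2, x^3, x^4
q4 = PolynomialFeatures(degree=4)
row4 = q4.fit_transform(np.array([[2.0]]))
print(row4)

# The same expansion at degree 20 for x = 20 (the largest x in the article's data)
q20 = PolynomialFeatures(degree=20)
row20 = q20.fit_transform(np.array([[20.0]]))
print(row20.shape, row20[0, -1])  # 21 columns; the last one is 20**20
```

The second case also hints at why high degrees are numerically fragile: at x = 20 the x^20 column is about 1.05e26 while the constant column is 1, so the columns span many orders of magnitude.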
q20 = PolynomialFeatures(degree=20)    # 20th-degree polynomial fit
x20 = q20.fit_transform(x.reshape(-1,1))

model20 = LinearRegression()
model20.fit(x20, y)
y_pred20 = model20.predict(x20)
plt.scatter(x,y)
plt.plot(x,y_pred20, c='r')

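Overfitting: at degree 20 the result gets worse rather than better, because the polynomial has enough freedom to start following the random noise instead of the underlying cubic trend.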
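One way to see the overfitting as a number rather than only in a plot is to compare the error on the training data with the error on points the model did not see. A minimal sketch using scikit-learn's cross_val_score, assuming the same data recipe as above with a fixed seed. Two choices here are additions, not part of the article: x is rescaled to [-1, 1] with MinMaxScaler before the polynomial expansion (so the degree-20 columns stay numerically manageable), and the bias column is dropped (include_bias=False) because LinearRegression already fits an intercept.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures

# Same recipe as the article's data; the seed is added for reproducibility
rng = np.random.default_rng(0)
x = np.linspace(0, 20, 30)
y = x**3 + rng.random(30) * 160 + 10
X = x.reshape(-1, 1)

train_mse, cv_mse = {}, {}
for degree in (1, 4, 20):
    # Rescale x to [-1, 1] first, then expand into powers, then fit a linear model
    model = make_pipeline(
        MinMaxScaler(feature_range=(-1, 1)),
        PolynomialFeatures(degree=degree, include_bias=False),
        LinearRegression(),
    )
    model.fit(X, y)
    train_mse[degree] = np.mean((y - model.predict(X)) ** 2)
    # 5-fold CV; without shuffling, each fold holds out a contiguous block of x values
    scores = cross_val_score(model, X, y, cv=5, scoring="neg_mean_squared_error")
    cv_mse[degree] = -scores.mean()
    print(f"degree {degree:2d}: train MSE = {train_mse[degree]:.3g}, CV MSE = {cv_mse[degree]:.3g}")
```

The training MSE can only go down as the degree grows, since the degree-20 feature set contains the degree-4 one; the cross-validated MSE is the number that exposes overfitting, and it is expected to be far larger for degree 20 than for degree 4. Because the folds are contiguous blocks, the outer folds also test extrapolation, which is where a high-degree polynomial goes most wrong.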
