21. 日月光华 Python数据分析 - 机器学习 - 多元线
数据描述: 数据集包含了 200 个不同市场的产品销售额, 每个销售额对应 3 种广告媒体,分别是 TV, radio 和 newspaper
任务描述:分析广告媒体与销售额之间的关系,基于广告媒体预算,预测销售额
评价指标:销售额为连续值,为回归问题,可以采用均方误差作为评价指标 开发环境
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
plt.style.use('ggplot')
data = pd.read_csv('Advertising.csv')
data.head()
Unnamed: 0 TV radio newspaper sales
0 1 230.1 37.8 69.2 22.1
1 2 44.5 39.3 45.1 10.4
2 3 17.2 45.9 69.3 9.3
3 4 151.5 41.3 58.5 18.5
4 5 180.8 10.8 58.4 12.9
查看TV媒体投入与销售额之间的关系
plt.scatter(data.TV, data.sales)

查看radio媒体投入与销售额之间的关系
plt.scatter(data.radio, data.sales)

查看newspaper媒体投入与销售额之间的关系
plt.scatter(data.newspaper, data.sales)

# 如何客观的评价我们的模型
x = data[['TV','radio','newspaper']]
y = data.sales
x_train,x_test,y_train,y_test = train_test_split(x, y)
len(x_train),len(y_train)
(150, 150)
len(x_test)
50
model = LinearRegression()
model.fit(x_train, y_train)
model.coef_
array([ 0.04371177, 0.19392061, -0.00030117])
for i in zip(x_train.columns, model.coef_): # zip打包
print(i)
('TV', 0.04371177253650498)
('radio', 0.193920608289641)
('newspaper', -0.0003011660943684182) # 发现 newspaper投入对销量影响最小
mean_squared_error(model.predict(x_test), y_test)
5.22399779076229
# 模型的改进
newspaper对结果影响不大,因此去掉该特征,反而是预测效果更好
x = data[['TV','radio']]
y = data.sales
x_train,x_test,y_train,y_test = train_test_split(x, y)
model2 = LinearRegression()
model2.fit(x_train,y_train)
model2.coef_
array([0.04603784, 0.18428908])
mean_squared_error(model2.predict(x_test),y_test)
1.7138986453138052