我爱编程

回归模型2018-01-25

2018-01-30  本文已影响0人  Jackpot_0213

Linear regression 线性回归

Machine Learning

个人理解,现行回归就像是初中的数学题,给你几组数据,求出先行回归方程,然后再带入一个X求出相应的Y。有时候我们还会遇到这样的题,就是在坐标系上给你几个点,老师让我们划出一条线,要求尽可能多得点落在这条线上,或者是要求这些点均匀的分布在这条线的两侧,这条线就是你人工预测的回归模型。只不过我们现在需要运用一些数学方法,更精确的求出这条线,这只是一种预测方法,尽管预测结果会差强人意但是这个你入门的一个很好的突破口。

配置环境

Siraj 将在下节视频中使用 Python 代码实现线性回归模型。如需在本机跟随练习,可先使用下列命令配置新的 conda 环境:

参考资源

你或许不一定熟悉 Siraj 视频中使用到的第三方库:pandas、scikit-learn 和 matplotlib。如有兴趣,可利用以下资源进行预习:

简单线性回归

上面我们介绍了如何使用 线性回归模型 根据动物的大脑重量来预测对应体重。在下面练习中,将构建线性回归模型,根据各国男性人口的 身体质量指数 (BMI) 来预测该国人口平均寿命。在此之前,我们先来讲解如何使用所需工具。

将使用 scikit-learn 库中的 LinearRegression 类来创建线性回归模型,其提供的 fit() 方法可用于拟合模型。

>>> from sklearn.linear_model import LinearRegression
>>> model = LinearRegression()
>>> model.fit(x_values, y_values)

在上述示例中,新建的线性回归模型被赋于 model 变量,然后根据 x_valuesy_values 数据进行了拟合。模型拟合在这里是指通过训练数据求取最优拟合线的步骤。然后我们使用模型的 predict() 方法进行两组预测。

>>> print(model.predict([ [127], [248] ]))
[[ 438.94308857, 127.14839521]]

预测结果最终以数组形式返回,每个元素对应各自的输入数组。第一项输入数组 [127] 的预测结果为 438.94308857。第二项输入数组 [248] 的预测结果为 127.14839521。输入数据采取数组形式 [127] 而非 127 的原因是, 线性回归模型可以接受多个特征。多元线性回归模型将在课程稍后部分讲解,本节练习中的模型仅使用单个特征。

线性回归练习

本节练习提供的数据为各国男性人口的 BMI 与该国人口平均寿命。数据来自 Gapminder

数据文件位于页面下方练习中的 "bmi_and_life_expectancy.csv" 标签。其中 "Country" 列记录出生国家,"Life expectancy" 列记录该国平均寿命,"BMI" 列记录该国男性 BMI 数据。你将使用 BMI 数据来预测平均寿命。

任务步骤:

1. 加载数据

2. 构建线性回归模型

3. 使用模型进行预测

# TODO: Add import statements
import pandas as pd
from sklearn.linear_model import LinearRegression

# Assign the dataframe to this variable.
# TODO: Load the data
bmi_life_data = pd.read_csv("bmi_and_life_expectancy.csv")

# Make and fit the linear regression model
#TODO: Fit the model and Assign it to bmi_life_model
bmi_life_model = LinearRegression()
bmi_life_model.fit(bmi_life_data[['BMI']], bmi_life_data[['Life expectancy']])

# Mak a prediction using the model
# TODO: Predict life expectancy for a BMI value of 21.07931
laos_life_exp = bmi_life_model.predict(21.07931)

Country,Life expectancy,BMI
Afghanistan,52.8,20.62058
Albania,76.8,26.44657
Algeria,75.5,24.5962
Andorra,84.6,27.63048
Angola,56.7,22.25083
Armenia,72.3,25.355420000000002
Australia,81.6,27.56373
Austria,80.4,26.467409999999997
Azerbaijan,69.2,25.65117
Bahamas,72.2,27.24594
Bangladesh,68.3,20.39742
Barbados,75.3,26.384390000000003
Belarus,70.0,26.16443
Belgium,79.6,26.75915
Belize,70.7,27.02255
Benin,59.7,22.41835
Bhutan,70.7,22.8218
Bolivia,71.2,24.43335
Bosnia and Herzegovina,77.5,26.611629999999998
Botswana,53.2,22.129839999999998
Brazil,73.2,25.78623
Bulgaria,73.2,26.542859999999997
Burkina Faso,58.0,21.27157
Burundi,59.1,21.50291
Cambodia,66.1,20.80496
Cameroon,56.6,23.681729999999998
Canada,80.8,27.4521
Cape Verde,70.4,23.515220000000003
Chad,54.3,21.485689999999998
Chile,78.5,27.015420000000002
China,73.4,22.92176
Colombia,76.2,24.94041
Comoros,67.1,22.06131
"Congo, Dem. Rep.",57.5,19.86692
"Congo, Rep.",58.8,21.87134
Costa Rica,79.8,26.47897
Cote d'Ivoire,55.4,22.56469
Croatia,76.2,26.596290000000003
Cuba,77.6,25.06867
Cyprus,80.0,27.41899
Denmark,78.9,26.13287
Djibouti,61.8,23.38403
Ecuador,74.7,25.58841
Egypt,70.2,26.732429999999997
El Salvador,73.7,26.36751
Eritrea,60.1,20.885089999999998
Estonia,74.2,26.264459999999996
Ethiopia,60.0,20.247
Fiji,64.9,26.53078
Finland,79.6,26.733390000000004
France,81.1,25.853289999999998
French Polynesia,75.11,30.867520000000003
Gabon,61.7,24.0762
Gambia,65.7,21.65029
Georgia,71.8,25.54942
Germany,80.0,27.165090000000003
Ghana,62.0,22.842470000000002
Greece,80.2,26.33786
Greenland,70.3,26.01359
Grenada,70.8,25.179879999999997
Guatemala,71.2,25.29947
Guinea,57.1,22.52449
Guinea-Bissau,53.6,21.64338
Guyana,65.0,23.68465
Haiti,61.0,23.66302
Honduras,71.8,25.10872
Hungary,73.9,27.115679999999998
Iceland,82.4,27.206870000000002
India,64.7,20.95956
Indonesia,69.4,21.85576
Iran,73.1,25.310029999999998
Iraq,66.6,26.71017
Ireland,80.1,27.65325
Israel,80.6,27.13151
Jamaica,75.1,24.00421
Japan,82.5,23.50004
Jordan,76.9,27.47362
Kazakhstan,67.1,26.290779999999998
Kenya,60.8,21.592579999999998
Kuwait,77.3,29.172109999999996
Latvia,72.4,26.45693
Lesotho,44.5,21.90157
Liberia,59.9,21.89537
Libya,75.6,26.54164
Lithuania,72.1,26.86102
Luxembourg,81.0,27.434040000000003
"Macedonia, FYR",74.5,26.34473
Madagascar,62.2,21.403470000000002
Malawi,52.4,22.034679999999998
Malaysia,74.5,24.73069
Maldives,78.5,23.219910000000002
Mali,58.5,21.78881
Malta,80.7,27.683609999999998
Marshall Islands,65.3,29.37337
Mauritania,67.9,22.62295
Mauritius,72.9,25.15669
Mexico,75.4,27.42468
Moldova,70.4,24.2369
Mongolia,64.8,24.88385
Montenegro,76.0,26.55412
Morocco,73.3,25.63182
Mozambique,54.0,21.93536
Myanmar,59.4,21.44932
Namibia,59.1,22.65008
Nepal,68.4,20.76344
Netherlands,80.3,26.01541
Nicaragua,77.0,25.77291
Niger,58.0,21.21958
Nigeria,59.2,23.03322
Norway,80.8,26.934240000000003
Oman,76.2,26.241090000000003
Pakistan,64.1,22.299139999999998
Panama,77.3,26.26959
Papua New Guinea,58.6,25.015060000000002
Paraguay,74.0,25.54223
Peru,78.2,24.770410000000002
Philippines,69.8,22.872629999999997
Poland,75.4,26.6738
Portugal,79.4,26.68445
Qatar,77.9,28.13138
Romania,73.2,25.41069
Russia,67.9,26.01131
Rwanda,64.1,22.55453
Samoa,72.3,30.42475
Sao Tome and Principe,66.0,23.51233
Senegal,63.5,21.927429999999998
Serbia,74.3,26.51495
Sierra Leone,53.6,22.53139
Singapore,80.6,23.83996
Slovak Republic,74.9,26.92717
Slovenia,78.7,27.43983
Somalia,52.6,21.969170000000002
South Africa,53.4,26.85538
Spain,81.1,27.49975
Sri Lanka,74.0,21.96671
Sudan,65.5,22.40484
Suriname,70.2,25.49887
Swaziland,45.1,23.16969
Sweden,81.1,26.37629
Switzerland,82.0,26.20195
Syria,76.1,26.919690000000003
Tajikistan,69.6,23.77966
Tanzania,60.4,22.47792
Thailand,73.9,23.008029999999998
Timor-Leste,69.9,20.59082
Togo,57.5,21.87875
Tonga,70.3,30.99563
Trinidad and Tobago,71.7,26.396690000000003
Tunisia,76.8,25.15699
Turkey,77.8,26.703709999999997
Turkmenistan,67.2,25.24796
Uganda,56.0,22.35833
Ukraine,67.8,25.42379
United Arab Emirates,75.6,28.053590000000003
United Kingdom,79.7,27.392490000000002
United States,78.3,28.456979999999998
Uruguay,76.0,26.39123
Uzbekistan,69.6,25.32054
Vanuatu,63.4,26.78926
West Bank and Gaza,74.1,26.5775
Vietnam,74.1,20.9163
Zambia,51.1,20.68321
Zimbabwe,47.3,22.0266

线性回归注意事项

线性回归隐含一系列前提假设,并非适合所有情形,因此应当注意以下两个问题。

最适用于线性数据
线性回归会根据训练数据生成直线模型。如果训练数据包含非线性关系,你需要选择:调整数据(进行数据转换)、增加特征数量(参考下节内容)或改用其他模型。

容易受到异常值影响
线性回归的目标是求取对训练数据而言的 “最优拟合” 直线。如果数据集中存在不符合总体规律的异常值,最终结果将会存在不小偏差。
在第一个图表中,模型与数据相当拟合。
但若添加若干不符合规律的异常值,会明显改变模型的预测结果。
在大多数情况下,模型需要基本上能与大部分数据拟合,所以要小心异常值!

多元线性回归

我们在上节练习中使用 BMI 来预测平均寿命。这里的 BMI 是预测变量,也称为自变量。预测变量被用来预测其他变量,而被预测的则称为因变量。在本例中,因变量为平均寿命。

假设我们又获取了各国人口的心率数据。那么可以同时使用 BMI 和心率来预测平均寿命吗?

当然可以!只需使用多元线性回归即可。

如果预测结果取决于多个变量,则需相应创建更加复杂的模型。只要所选自变量/预测变量适合当前场景,增加变量有助于改善预测结果。

在只有单个预测变量时,线性回归模型是一条直线,而增加预测变量,相当于增加图像维度。
此时需要使用三维图像来进行展示,线性回归模型也变成了平面:
预测变量的数量可以超过两个,甚至在合适的情况下多多益善!若使用 <n 个预测变量,那么模型公式则为:

y=m1​x1​+m2​x2​+m3​x3​+...+mn​xn​+b

模型中的预测变量越多,就越难以通过图像展示。幸好线性回归的其他环节不发生变化,仍然可以用相同方式拟合模型和做出预测。下面来试试吧!

编程测验:多元线性回归

在本测验中,你将使用到波士顿房价数据集。数据集中包含 506 栋房屋的 13 个特征与房价中值(单位为 1000 美元)。你将根据这 13 个特征拟合模型,并预测房价。(波士顿房价数据集来自于 UCI机器学习数据集, 但现已下线。你还可以查看我们的波士顿房价项目中的数据集来对数据集有更多了解。

你需要完成以下步骤:

1. 构建线性回归模型

2. 使用模型进行预测

from sklearn.linear_model import LinearRegression
from sklearn.datasets import load_boston

# Load the data from the the boston house-prices dataset 
boston_data = load_boston()
x = boston_data['data']
y = boston_data['target']

# Make and fit the linear regression model
# TODO: Fit the model and Assign it to the model variable
model = LinearRegression()
model.fit(x, y)

# Mak a prediction using the model
sample_house = [[2.29690000e-01, 0.00000000e+00, 1.05900000e+01, 0.00000000e+00, 4.89000000e-01,
                6.32600000e+00, 5.25000000e+01, 4.35490000e+00, 4.00000000e+00, 2.77000000e+02,
                1.86000000e+01, 3.94870000e+02, 1.09700000e+01]]
# TODO: Predict housing price for the sample_house
prediction = model.predict(sample_house)
上一篇下一篇

猜你喜欢

热点阅读