回归树

2020-05-20  本文已影响0人  还闹不闹

传送门:分类树

1、原理

分类与回归树(classification and regression tree,CART)模型由Breiman等人在1984年提出。CART同样由特征选择、树的生成及剪枝组成。

既然是决策树,那么必然会存在以下两个核心问题:如何选择划分点?如何决定叶节点的输出值?
一个回归树对应着输入空间(即特征空间)的一个划分以及在划分单元上的输出值。

2、算法描述

一个简单实例:训练数据见下表,目标是得到一棵最小二乘回归树。

2.1 选择最优切分变量j与最优切分点s

在本数据集中,只有一个变量,因此最优切分变量自然是x。
接下来我们考虑9个切分点[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5].
你可能会问,为什么会带小数点呢?类比于篮球比赛的博彩,倘若两队比分是96:95,而盘口是“让1分 A队胜B队”,那A队让1分之后,到底是A队赢还是B队赢了?所以我们经常可以看到“让0.5分 A队胜B队”这样的盘口。在这个实例中,也是这个道理。

损失函数定义为平方损失函数 Loss(y,f(x))=(f(x)−y)^2,将上述9个切分点一依此代入下面的公式,其中 c_m=ave(y_i|x_i∈R_m) 用选定的(j,s)划分区域,并决定输出值

2.2 对两个子区域继续调用上述步骤

2.3 生成回归树

假设在生成3个区域之后停止划分,那么最终生成的回归树形式如下:

3、代码

# coding=utf-8
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn import linear_model
# 画图支持中文显示
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']

# Data set
x = np.array(list(range(1, 11))).reshape(-1, 1)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]).ravel() # 多维数组转换成一维数组[5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]
print('x:\n' + str(x) + '\ny:\n' + str(y))

# Fit regression model
model1 = DecisionTreeRegressor(max_depth=1)
model2 = DecisionTreeRegressor(max_depth=3)
model3 = linear_model.LinearRegression() # 线性回归模型
model1.fit(x, y)
model2.fit(x, y)
model3.fit(x, y)

# Predict
X_test = np.arange(0.0, 10.0, 0.01)[:, np.newaxis] # 1000行1列
print(X_test.size)
# test = np.zeros(shape=(2, 3, 4), dtype=int)
# print(test.size)
# print(test.ndim)
y_1 = model1.predict(X_test)
y_2 = model2.predict(X_test)
y_3 = model3.predict(X_test)

# Plot the results
plt.figure()
plt.scatter(x, y, s=20, color="darkorange", edgecolors='blue', label="src_data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=1", linewidth=2, linestyle='--')
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=3", linewidth=2, linestyle='--')
plt.plot(X_test, y_3, color='red', label='liner regression', linewidth=2, linestyle='--')
plt.xlim(0, 12) # 设置坐标轴
plt.ylim(4, 11)
plt.xlabel("x_data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

拓展阅读:https://blog.csdn.net/hy592070616/article/details/81628956
上一篇下一篇

猜你喜欢

热点阅读