树回归
2018-12-04 本文已影响12人
JasonChiu17
原理:
- 将数据集切分成很多份易建模的数据
- 在每份数据上分别建模(回归树的叶节点取目标变量的均值,模型树的叶节点用线性回归)
优点
- 可以对复杂和非线性的数据建模
缺点
- 结果不易理解
适用数据类型
- 数值型和标称型数据
选择最佳特征之后,数据划分方法:
- ID3: 按最佳特征的所有可能取值来划分
- CART:二元切分法
import numpy as np
定义加载数据函数:特征x和目标变量y都放在同一个dataMat里(最后一列为y)
def loadDataSet(fileName):
    """Load a whitespace-separated numeric text file into a list of rows.

    Each non-blank line becomes one list of floats. Features and target stay
    together in the same row; the tree code in this file reads the target
    from the last column.

    Args:
        fileName: path of the text file to read.

    Returns:
        list[list[float]]: one inner list per non-blank line.
    """
    dataMat = []
    # `with` guarantees the file is closed; the original never closed it.
    with open(fileName) as fr:
        for line in fr:
            curLine = line.strip().split()
            if not curLine:
                # Skip blank lines: an empty row would make the data ragged
                # and break the later conversion to a matrix.
                continue
            # map() returns an iterator in Python 3, so materialise a list.
            dataMat.append(list(map(float, curLine)))
    return dataMat
定义二元切分法
#返回两个数据集,大于value或小于value
def binSplitDataSet(dataSet, feature, value):
    """Binary (CART-style) split of dataSet on a single feature.

    Args:
        dataSet: 2-D np.matrix or np.ndarray, one sample per row.
        feature: index of the column to compare against the threshold.
        value: the threshold.

    Returns:
        (mat0, mat1): rows whose `feature` column is > value, and rows whose
        `feature` column is <= value. Both keep every column.
    """
    column = dataSet[:, feature]
    # np.nonzero(...)[0] gives the row indices both for an (m, 1) matrix
    # column and for a 1-D ndarray column.
    above = np.nonzero(column > value)[0]
    at_or_below = np.nonzero(column <= value)[0]
    return dataSet[above, :], dataSet[at_or_below, :]
定义构建树的函数
#计算目标变量的均值
def regLeaf(dataSet):
    """Leaf value of a regression tree: the mean of the target (last) column."""
    targets = dataSet[:, -1]
    return targets.mean()
#计算目标变量的总方差
def regErr(dataSet):
    """Total squared error of the target column: variance times row count."""
    target = dataSet[:, -1]
    n_rows = dataSet.shape[0]
    return n_rows * np.var(target)
def createTree(dataSet, leafType = regLeaf, errType= regErr, ops=(1,4)):
    """Recursively build a CART regression tree.

    Args:
        dataSet: 2-D data set (np.matrix in this file) whose last column is
            the target.
        leafType: callable returning the leaf value for a data subset.
        errType: callable returning the error of a data subset.
        ops: (tolS, tolN) pre-pruning parameters, forwarded to
            chooseBestSplit.

    Returns:
        A leaf value when chooseBestSplit finds no worthwhile split.
        Otherwise a dict with the keys:
            'spInd' -- feature index of the split;
            'spVal' -- threshold of the split;
            'left'  -- subtree for rows with feature > spVal;
            'right' -- subtree for rows with feature <= spVal.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "make a leaf here" by returning None as feature.
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }
测试函数
# np.mat was removed in NumPy 2.0; np.asmatrix performs the same conversion.
testMat = np.asmatrix(np.eye(4))  # np.matrix
# testMat = np.eye(4)  # plain ndarray alternative
testMat
matrix([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
# Split testMat on feature (column) 1 with threshold 0: rows whose column 1
# is > 0 go to mat0, all other rows go to mat1.
mat0,mat1 = binSplitDataSet(testMat,1,0)
print(mat0)
print(mat1)
[[0. 1. 0. 0.]]
[[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
回归树的切分函数
# #计算目标变量的均值
# def regLeaf(dataSet):
# return np.mean(dataSet[:,-1])
# #计算目标变量的总方差
# def regErr(dataSet):
# return np.var(dataSet[:,-1])*dataSet.shape[0]
#选择最优划分特征
def chooseBestSplit(dataSet, leafType = regLeaf, errType=regErr, ops=(1,4)):
    """Find the best binary split (feature index, threshold) for a regression tree.

    Every distinct value of every feature column (all columns except the
    last, which holds the target) is tried as a threshold. The candidate
    minimising errType(left) + errType(right) wins.

    Args:
        dataSet: 2-D np.matrix or np.ndarray; the last column is the target.
        leafType: callable producing a leaf value from a data subset.
        errType: callable producing the error of a data subset.
        ops: (tolS, tolN) pre-pruning parameters:
            tolS -- minimum error reduction required to accept a split;
            tolN -- minimum number of rows allowed on each side of a split.

    Returns:
        (featureIndex, splitValue) when a split is worthwhile. Otherwise
        (None, leafType(dataSet)), meaning "make a leaf here".
    """
    tolS = ops[0]  # minimum error reduction required to split
    tolN = ops[1]  # minimum number of samples on each side of a split

    def _column_values(col):
        # Flatten one column to a list of Python floats (in row order).
        # This works for an np.matrix column (shape (m, 1)) as well as an
        # np.ndarray column (shape (m,)). The original `.T.tolist()[0]` only
        # worked for np.matrix.
        return np.asarray(dataSet[:, col]).ravel().tolist()

    # All targets identical: splitting cannot help, so make a leaf.
    if len(set(_column_values(-1))) == 1:
        return None, leafType(dataSet)
    n = dataSet.shape[1]
    S = errType(dataSet)  # error before splitting
    bestS = np.inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):  # skip the last column (the target)
        for splitVal in set(_column_values(featIndex)):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Pre-pruning: reject splits leaving too few samples on a side.
            if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)  # error after this split
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Pre-pruning: error reduction too small. This also catches the case
    # where no candidate passed the tolN check, since bestS is then still inf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Defensive re-check of tolN. Any split chosen in the loop already
    # satisfies it, so this only matters if the check above was bypassed
    # (e.g. tolS = -inf with no valid candidate).
    if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
测试数据
# %matplotlib inline
import matplotlib.pyplot as plt
def data2show(data):
    """Scatter-plot the second-to-last column against the last (target) column.

    Args:
        data: 2-D np.matrix or np.ndarray. The second-to-last column goes on
            the x-axis and the last column on the y-axis.
    """
    # np.asarray accepts both np.matrix and np.ndarray; `.A` only exists on
    # np.matrix.
    arr = np.asarray(data)
    xArr = arr[:, -2]
    yArr = arr[:, -1]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Draw on the axes created above rather than pyplot's implicit
    # "current axes", so the legend and labels below refer to the same axes.
    ax.scatter(xArr, yArr, s=5, label='raw data')
    ax.legend(loc='upper left')
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    plt.show()
ex00.txt
myDat = loadDataSet('../../Reference Code/Ch09/ex00.txt')
# np.mat was removed in NumPy 2.0; np.asmatrix performs the same conversion.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
output_17_1.png
- 两个叶节点
# Display the data loaded from ex00.txt (the cell output follows).
myMat
matrix([[ 3.609800e-02, 1.550960e-01],
[ 9.933490e-01, 1.077553e+00],
[ 5.308970e-01, 8.934620e-01],
[ 7.123860e-01, 5.648580e-01],
[ 3.435540e-01, -3.717000e-01],
[ 9.801600e-02, -3.327600e-01],
[ 6.911150e-01, 8.343910e-01],
[ 9.135800e-02, 9.993500e-02],
[ 7.270980e-01, 1.000567e+00],
[ 9.519490e-01, 9.452550e-01],
[ 7.685960e-01, 7.602190e-01],
[ 5.413140e-01, 8.937480e-01],
[ 1.463660e-01, 3.428300e-02],
[ 6.731950e-01, 9.150770e-01],
[ 1.835100e-01, 1.848430e-01],
[ 3.395630e-01, 2.067830e-01],
[ 5.179210e-01, 1.493586e+00],
[ 7.037550e-01, 1.101678e+00],
[ 8.307000e-03, 6.997600e-02],
[ 2.439090e-01, -2.946700e-02],
[ 3.069640e-01, -1.773210e-01],
[ 3.649200e-02, 4.081550e-01],
[ 2.955110e-01, 2.882000e-03],
[ 8.375220e-01, 1.229373e+00],
[ 2.020540e-01, -8.774400e-02],
[ 9.193840e-01, 1.029889e+00],
[ 3.772010e-01, -2.435500e-01],
[ 8.148250e-01, 1.095206e+00],
[ 6.112700e-01, 9.820360e-01],
[ 7.224300e-02, -4.209830e-01],
[ 4.102300e-01, 3.317220e-01],
[ 8.690770e-01, 1.114825e+00],
[ 6.205990e-01, 1.334421e+00],
[ 1.011490e-01, 6.883400e-02],
[ 8.208020e-01, 1.325907e+00],
[ 5.200440e-01, 9.619830e-01],
[ 4.881300e-01, -9.779100e-02],
[ 8.198230e-01, 8.352640e-01],
[ 9.750220e-01, 6.735790e-01],
[ 9.531120e-01, 1.064690e+00],
[ 4.759760e-01, -1.637070e-01],
[ 2.731470e-01, -4.552190e-01],
[ 8.045860e-01, 9.240330e-01],
[ 7.479500e-02, -3.496920e-01],
[ 6.253360e-01, 6.236960e-01],
[ 6.562180e-01, 9.585060e-01],
[ 8.340780e-01, 1.010580e+00],
[ 7.819300e-01, 1.074488e+00],
[ 9.849000e-03, 5.659400e-02],
[ 3.022170e-01, -1.486500e-01],
[ 6.782870e-01, 9.077270e-01],
[ 1.805060e-01, 1.036760e-01],
[ 1.936410e-01, -3.275890e-01],
[ 3.434790e-01, 1.752640e-01],
[ 1.458090e-01, 1.369790e-01],
[ 9.967570e-01, 1.035533e+00],
[ 5.902100e-01, 1.336661e+00],
[ 2.380700e-01, -3.584590e-01],
[ 5.613620e-01, 1.070529e+00],
[ 3.775970e-01, 8.850500e-02],
[ 9.914200e-02, 2.528000e-02],
[ 5.395580e-01, 1.053846e+00],
[ 7.902400e-01, 5.332140e-01],
[ 2.422040e-01, 2.093590e-01],
[ 1.523240e-01, 1.328580e-01],
[ 2.526490e-01, -5.561300e-02],
[ 8.959300e-01, 1.077275e+00],
[ 1.333000e-01, -2.231430e-01],
[ 5.597630e-01, 1.253151e+00],
[ 6.436650e-01, 1.024241e+00],
[ 8.772410e-01, 7.970050e-01],
[ 6.137650e-01, 1.621091e+00],
[ 6.457620e-01, 1.026886e+00],
[ 6.513760e-01, 1.315384e+00],
[ 6.977180e-01, 1.212434e+00],
[ 7.425270e-01, 1.087056e+00],
[ 9.010560e-01, 1.055900e+00],
[ 3.623140e-01, -5.564640e-01],
[ 9.482680e-01, 6.318620e-01],
[ 2.340000e-04, 6.090300e-02],
[ 7.500780e-01, 9.062910e-01],
[ 3.254120e-01, -2.192450e-01],
[ 7.268280e-01, 1.017112e+00],
[ 3.480130e-01, 4.893900e-02],
[ 4.581210e-01, -6.145600e-02],
[ 2.807380e-01, -2.288800e-01],
[ 5.677040e-01, 9.690580e-01],
[ 7.509180e-01, 7.481040e-01],
[ 5.758050e-01, 8.990900e-01],
[ 5.079400e-01, 1.107265e+00],
[ 7.176900e-02, -1.109460e-01],
[ 5.535200e-01, 1.391273e+00],
[ 4.011520e-01, -1.216400e-01],
[ 4.066490e-01, -3.663170e-01],
[ 6.521210e-01, 1.004346e+00],
[ 3.478370e-01, -1.534050e-01],
[ 8.193100e-02, -2.697560e-01],
[ 8.216480e-01, 1.280895e+00],
[ 4.801400e-02, 6.449600e-02],
[ 1.309620e-01, 1.842410e-01],
[ 7.734220e-01, 1.125943e+00],
[ 7.896250e-01, 5.526140e-01],
[ 9.699400e-02, 2.271670e-01],
[ 6.257910e-01, 1.244731e+00],
[ 5.895750e-01, 1.185812e+00],
[ 3.231810e-01, 1.808110e-01],
[ 8.224430e-01, 1.086648e+00],
[ 3.603230e-01, -2.048300e-01],
[ 9.501530e-01, 1.022906e+00],
[ 5.275050e-01, 8.795600e-01],
[ 8.600490e-01, 7.174900e-01],
[ 7.044000e-03, 9.415000e-02],
[ 4.383670e-01, 3.401400e-02],
[ 5.745730e-01, 1.066130e+00],
[ 5.366890e-01, 8.672840e-01],
[ 7.821670e-01, 8.860490e-01],
[ 9.898880e-01, 7.442070e-01],
[ 7.614740e-01, 1.058262e+00],
[ 9.854250e-01, 1.227946e+00],
[ 1.325430e-01, -3.293720e-01],
[ 3.469860e-01, -1.503890e-01],
[ 7.687840e-01, 8.997050e-01],
[ 8.489210e-01, 1.170959e+00],
[ 4.492800e-01, 6.909800e-02],
[ 6.617200e-02, 5.243900e-02],
[ 8.137190e-01, 7.066010e-01],
[ 6.619230e-01, 7.670400e-01],
[ 5.294910e-01, 1.022206e+00],
[ 8.464550e-01, 7.200300e-01],
[ 4.486560e-01, 2.697400e-02],
[ 7.950720e-01, 9.657210e-01],
[ 1.181560e-01, -7.740900e-02],
[ 8.424800e-02, -1.954700e-02],
[ 8.458150e-01, 9.526170e-01],
[ 5.769460e-01, 1.234129e+00],
[ 7.720830e-01, 1.299018e+00],
[ 6.966480e-01, 8.454230e-01],
[ 5.950120e-01, 1.213435e+00],
[ 6.486750e-01, 1.287407e+00],
[ 8.970940e-01, 1.240209e+00],
[ 5.529900e-01, 1.036158e+00],
[ 3.329820e-01, 2.100840e-01],
[ 6.561500e-02, -3.069700e-01],
[ 2.786610e-01, 2.536280e-01],
[ 7.731680e-01, 1.140917e+00],
[ 2.036930e-01, -6.403600e-02],
[ 3.556880e-01, -1.193990e-01],
[ 9.888520e-01, 1.069062e+00],
[ 5.187350e-01, 1.037179e+00],
[ 5.145630e-01, 1.156648e+00],
[ 9.764140e-01, 8.629110e-01],
[ 9.190740e-01, 1.123413e+00],
[ 6.977770e-01, 8.278050e-01],
[ 9.280970e-01, 8.832250e-01],
[ 9.002720e-01, 9.968710e-01],
[ 3.441020e-01, -6.153900e-02],
[ 1.480490e-01, 2.042980e-01],
[ 1.300520e-01, -2.616700e-02],
[ 3.020010e-01, 3.171350e-01],
[ 3.371000e-01, 2.633200e-02],
[ 3.149240e-01, -1.952000e-03],
[ 2.696810e-01, -1.659710e-01],
[ 1.960050e-01, -4.884700e-02],
[ 1.290610e-01, 3.051070e-01],
[ 9.367830e-01, 1.026258e+00],
[ 3.055400e-01, -1.159910e-01],
[ 6.839210e-01, 1.414382e+00],
[ 6.223980e-01, 7.663300e-01],
[ 9.025320e-01, 8.616010e-01],
[ 7.125030e-01, 9.334900e-01],
[ 5.900620e-01, 7.055310e-01],
[ 7.231200e-01, 1.307248e+00],
[ 1.882180e-01, 1.136850e-01],
[ 6.436010e-01, 7.825520e-01],
[ 5.202070e-01, 1.209557e+00],
[ 2.331150e-01, -3.481470e-01],
[ 4.656250e-01, -1.529400e-01],
[ 8.845120e-01, 1.117833e+00],
[ 6.632000e-01, 7.016340e-01],
[ 2.688570e-01, 7.344700e-02],
[ 7.292340e-01, 9.319560e-01],
[ 4.296640e-01, -1.886590e-01],
[ 7.371890e-01, 1.200781e+00],
[ 3.785950e-01, -2.960940e-01],
[ 9.301730e-01, 1.035645e+00],
[ 7.743010e-01, 8.367630e-01],
[ 2.739400e-01, -8.571300e-02],
[ 8.244420e-01, 1.082153e+00],
[ 6.260110e-01, 8.405440e-01],
[ 6.793900e-01, 1.307217e+00],
[ 5.782520e-01, 9.218850e-01],
[ 7.855410e-01, 1.165296e+00],
[ 5.974090e-01, 9.747700e-01],
[ 1.408300e-02, -1.325250e-01],
[ 6.638700e-01, 1.187129e+00],
[ 5.523810e-01, 1.369630e+00],
[ 6.838860e-01, 9.999850e-01],
[ 2.103340e-01, -6.899000e-03],
[ 6.045290e-01, 1.212685e+00],
[ 2.507440e-01, 4.629700e-02]])
ex0.txt
myDat = loadDataSet('../../Reference Code/Ch09/ex0.txt')
# np.mat was removed in NumPy 2.0; np.asmatrix performs the same conversion.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 1, 'spVal': 0.39435, 'left': {'spInd': 1, 'spVal': 0.582002, 'left': {'spInd': 1, 'spVal': 0.797583, 'left': 3.9871632, 'right': 2.9836209534883724}, 'right': 1.980035071428571}, 'right': {'spInd': 1, 'spVal': 0.197834, 'left': 1.0289583666666666, 'right': -0.023838155555555553}}
output_21_1.png
- 五个叶节点,即五个预测值
- 总方差越小,划分的数据集越集中。
# Display the data loaded from ex0.txt (the cell output follows).
myMat
matrix([[ 1.000000e+00, 4.091750e-01, 1.883180e+00],
[ 1.000000e+00, 1.826030e-01, 6.390800e-02],
[ 1.000000e+00, 6.636870e-01, 3.042257e+00],
[ 1.000000e+00, 5.173950e-01, 2.305004e+00],
[ 1.000000e+00, 1.364300e-02, -6.769800e-02],
[ 1.000000e+00, 4.696430e-01, 1.662809e+00],
[ 1.000000e+00, 7.254260e-01, 3.275749e+00],
[ 1.000000e+00, 3.943500e-01, 1.118077e+00],
[ 1.000000e+00, 5.077600e-01, 2.095059e+00],
[ 1.000000e+00, 2.373950e-01, 1.181912e+00],
[ 1.000000e+00, 5.753400e-02, 2.216630e-01],
[ 1.000000e+00, 3.698200e-01, 9.384530e-01],
[ 1.000000e+00, 9.768190e-01, 4.149409e+00],
[ 1.000000e+00, 6.160510e-01, 3.105444e+00],
[ 1.000000e+00, 4.137000e-01, 1.896278e+00],
[ 1.000000e+00, 1.052790e-01, -1.213450e-01],
[ 1.000000e+00, 6.702730e-01, 3.161652e+00],
[ 1.000000e+00, 9.527580e-01, 4.135358e+00],
[ 1.000000e+00, 2.723160e-01, 8.590630e-01],
[ 1.000000e+00, 3.036970e-01, 1.170272e+00],
[ 1.000000e+00, 4.866980e-01, 1.687960e+00],
[ 1.000000e+00, 5.118100e-01, 1.979745e+00],
[ 1.000000e+00, 1.958650e-01, 6.869000e-02],
[ 1.000000e+00, 9.867690e-01, 4.052137e+00],
[ 1.000000e+00, 7.856230e-01, 3.156316e+00],
[ 1.000000e+00, 7.975830e-01, 2.950630e+00],
[ 1.000000e+00, 8.130600e-02, 6.893500e-02],
[ 1.000000e+00, 6.597530e-01, 2.854020e+00],
[ 1.000000e+00, 3.752700e-01, 9.997430e-01],
[ 1.000000e+00, 8.191360e-01, 4.048082e+00],
[ 1.000000e+00, 1.424320e-01, 2.309230e-01],
[ 1.000000e+00, 2.151120e-01, 8.166930e-01],
[ 1.000000e+00, 4.127000e-02, 1.307130e-01],
[ 1.000000e+00, 4.413600e-02, -5.377060e-01],
[ 1.000000e+00, 1.313370e-01, -3.391090e-01],
[ 1.000000e+00, 4.634440e-01, 2.124538e+00],
[ 1.000000e+00, 6.719050e-01, 2.708292e+00],
[ 1.000000e+00, 9.465590e-01, 4.017390e+00],
[ 1.000000e+00, 9.041760e-01, 4.004021e+00],
[ 1.000000e+00, 3.066740e-01, 1.022555e+00],
[ 1.000000e+00, 8.190060e-01, 3.657442e+00],
[ 1.000000e+00, 8.454720e-01, 4.073619e+00],
[ 1.000000e+00, 1.562580e-01, 1.199400e-02],
[ 1.000000e+00, 8.571850e-01, 3.640429e+00],
[ 1.000000e+00, 4.001580e-01, 1.808497e+00],
[ 1.000000e+00, 3.753950e-01, 1.431404e+00],
[ 1.000000e+00, 8.858070e-01, 3.935544e+00],
[ 1.000000e+00, 2.399600e-01, 1.162152e+00],
[ 1.000000e+00, 1.486400e-01, -2.273300e-01],
[ 1.000000e+00, 1.431430e-01, -6.872800e-02],
[ 1.000000e+00, 3.215820e-01, 8.250510e-01],
[ 1.000000e+00, 5.093930e-01, 2.008645e+00],
[ 1.000000e+00, 3.558910e-01, 6.645660e-01],
[ 1.000000e+00, 9.386330e-01, 4.180202e+00],
[ 1.000000e+00, 3.480570e-01, 8.648450e-01],
[ 1.000000e+00, 4.388980e-01, 1.851174e+00],
[ 1.000000e+00, 7.814190e-01, 2.761993e+00],
[ 1.000000e+00, 9.113330e-01, 4.075914e+00],
[ 1.000000e+00, 3.246900e-02, 1.102290e-01],
[ 1.000000e+00, 4.999850e-01, 2.181987e+00],
[ 1.000000e+00, 7.716630e-01, 3.152528e+00],
[ 1.000000e+00, 6.703610e-01, 3.046564e+00],
[ 1.000000e+00, 1.762020e-01, 1.289540e-01],
[ 1.000000e+00, 3.921700e-01, 1.062726e+00],
[ 1.000000e+00, 9.111880e-01, 3.651742e+00],
[ 1.000000e+00, 8.722880e-01, 4.401950e+00],
[ 1.000000e+00, 7.331070e-01, 3.022888e+00],
[ 1.000000e+00, 6.102390e-01, 2.874917e+00],
[ 1.000000e+00, 7.327390e-01, 2.946801e+00],
[ 1.000000e+00, 7.148250e-01, 2.893644e+00],
[ 1.000000e+00, 7.638600e-02, 7.213100e-02],
[ 1.000000e+00, 5.590090e-01, 1.748275e+00],
[ 1.000000e+00, 4.272580e-01, 1.912047e+00],
[ 1.000000e+00, 8.418750e-01, 3.710686e+00],
[ 1.000000e+00, 5.589180e-01, 1.719148e+00],
[ 1.000000e+00, 5.332410e-01, 2.174090e+00],
[ 1.000000e+00, 9.566650e-01, 3.656357e+00],
[ 1.000000e+00, 6.203930e-01, 3.522504e+00],
[ 1.000000e+00, 5.661200e-01, 2.234126e+00],
[ 1.000000e+00, 5.232580e-01, 1.859772e+00],
[ 1.000000e+00, 4.768840e-01, 2.097017e+00],
[ 1.000000e+00, 1.764080e-01, 1.794000e-03],
[ 1.000000e+00, 3.030940e-01, 1.231928e+00],
[ 1.000000e+00, 6.097310e-01, 2.953862e+00],
[ 1.000000e+00, 1.777400e-02, -1.168030e-01],
[ 1.000000e+00, 6.226160e-01, 2.638864e+00],
[ 1.000000e+00, 8.865390e-01, 3.943428e+00],
[ 1.000000e+00, 1.486540e-01, -3.285130e-01],
[ 1.000000e+00, 1.043500e-01, -9.986600e-02],
[ 1.000000e+00, 1.168680e-01, -3.083600e-02],
[ 1.000000e+00, 5.165140e-01, 2.359786e+00],
[ 1.000000e+00, 6.648960e-01, 3.212581e+00],
[ 1.000000e+00, 4.327000e-03, 1.889750e-01],
[ 1.000000e+00, 4.255590e-01, 1.904109e+00],
[ 1.000000e+00, 7.436710e-01, 3.007114e+00],
[ 1.000000e+00, 9.351850e-01, 3.845834e+00],
[ 1.000000e+00, 6.973000e-01, 3.079411e+00],
[ 1.000000e+00, 4.445510e-01, 1.939739e+00],
[ 1.000000e+00, 6.837530e-01, 2.880078e+00],
[ 1.000000e+00, 7.559930e-01, 3.063577e+00],
[ 1.000000e+00, 9.026900e-01, 4.116296e+00],
[ 1.000000e+00, 9.449100e-02, -2.409630e-01],
[ 1.000000e+00, 8.738310e-01, 4.066299e+00],
[ 1.000000e+00, 9.918100e-01, 4.011834e+00],
[ 1.000000e+00, 1.856110e-01, 7.771000e-02],
[ 1.000000e+00, 6.945510e-01, 3.103069e+00],
[ 1.000000e+00, 6.572750e-01, 2.811897e+00],
[ 1.000000e+00, 1.187460e-01, -1.046300e-01],
[ 1.000000e+00, 8.430200e-02, 2.521600e-02],
[ 1.000000e+00, 9.453410e-01, 4.330063e+00],
[ 1.000000e+00, 7.858270e-01, 3.087091e+00],
[ 1.000000e+00, 5.309330e-01, 2.269988e+00],
[ 1.000000e+00, 8.795940e-01, 4.010701e+00],
[ 1.000000e+00, 6.527700e-01, 3.119542e+00],
[ 1.000000e+00, 8.793380e-01, 3.723411e+00],
[ 1.000000e+00, 7.647390e-01, 2.792078e+00],
[ 1.000000e+00, 5.048840e-01, 2.192787e+00],
[ 1.000000e+00, 5.542030e-01, 2.081305e+00],
[ 1.000000e+00, 4.932090e-01, 1.714463e+00],
[ 1.000000e+00, 3.637830e-01, 8.858540e-01],
[ 1.000000e+00, 3.164650e-01, 1.028187e+00],
[ 1.000000e+00, 5.802830e-01, 1.951497e+00],
[ 1.000000e+00, 5.428980e-01, 1.709427e+00],
[ 1.000000e+00, 1.126610e-01, 1.440680e-01],
[ 1.000000e+00, 8.167420e-01, 3.880240e+00],
[ 1.000000e+00, 2.341750e-01, 9.218760e-01],
[ 1.000000e+00, 4.028040e-01, 1.979316e+00],
[ 1.000000e+00, 7.094230e-01, 3.085768e+00],
[ 1.000000e+00, 8.672980e-01, 3.476122e+00],
[ 1.000000e+00, 9.933920e-01, 3.993679e+00],
[ 1.000000e+00, 7.115800e-01, 3.077880e+00],
[ 1.000000e+00, 1.336430e-01, -1.053650e-01],
[ 1.000000e+00, 5.203100e-02, -1.647030e-01],
[ 1.000000e+00, 3.668060e-01, 1.096814e+00],
[ 1.000000e+00, 6.975210e-01, 3.092879e+00],
[ 1.000000e+00, 7.872620e-01, 2.987926e+00],
[ 1.000000e+00, 4.767100e-01, 2.061264e+00],
[ 1.000000e+00, 7.214170e-01, 2.746854e+00],
[ 1.000000e+00, 2.303760e-01, 7.167100e-01],
[ 1.000000e+00, 1.043970e-01, 1.038310e-01],
[ 1.000000e+00, 1.978340e-01, 2.377600e-02],
[ 1.000000e+00, 1.292910e-01, -3.329900e-02],
[ 1.000000e+00, 5.285280e-01, 1.942286e+00],
[ 1.000000e+00, 9.493000e-03, -6.338000e-03],
[ 1.000000e+00, 9.985330e-01, 3.808753e+00],
[ 1.000000e+00, 3.635220e-01, 6.527990e-01],
[ 1.000000e+00, 9.013860e-01, 4.053747e+00],
[ 1.000000e+00, 8.326930e-01, 4.569290e+00],
[ 1.000000e+00, 1.190020e-01, -3.277300e-02],
[ 1.000000e+00, 4.876380e-01, 2.066236e+00],
[ 1.000000e+00, 1.536670e-01, 2.227850e-01],
[ 1.000000e+00, 2.386190e-01, 1.089268e+00],
[ 1.000000e+00, 2.081970e-01, 1.487788e+00],
[ 1.000000e+00, 7.509210e-01, 2.852033e+00],
[ 1.000000e+00, 1.834030e-01, 2.448600e-02],
[ 1.000000e+00, 9.956080e-01, 3.737750e+00],
[ 1.000000e+00, 1.513110e-01, 4.501700e-02],
[ 1.000000e+00, 1.268040e-01, 1.238000e-03],
[ 1.000000e+00, 9.831530e-01, 3.892763e+00],
[ 1.000000e+00, 7.724950e-01, 2.819376e+00],
[ 1.000000e+00, 7.841330e-01, 2.830665e+00],
[ 1.000000e+00, 5.693400e-02, 2.346330e-01],
[ 1.000000e+00, 4.255840e-01, 1.810782e+00],
[ 1.000000e+00, 9.987090e-01, 4.237235e+00],
[ 1.000000e+00, 7.078150e-01, 3.034768e+00],
[ 1.000000e+00, 4.138160e-01, 1.742106e+00],
[ 1.000000e+00, 2.171520e-01, 1.169250e+00],
[ 1.000000e+00, 3.605030e-01, 8.311650e-01],
[ 1.000000e+00, 9.779890e-01, 3.729376e+00],
[ 1.000000e+00, 5.079530e-01, 1.823205e+00],
[ 1.000000e+00, 9.207710e-01, 4.021970e+00],
[ 1.000000e+00, 2.105420e-01, 1.262939e+00],
[ 1.000000e+00, 9.286110e-01, 4.159518e+00],
[ 1.000000e+00, 5.803730e-01, 2.039114e+00],
[ 1.000000e+00, 8.413900e-01, 4.101837e+00],
[ 1.000000e+00, 6.815300e-01, 2.778672e+00],
[ 1.000000e+00, 2.927950e-01, 1.228284e+00],
[ 1.000000e+00, 4.569180e-01, 1.736620e+00],
[ 1.000000e+00, 1.341280e-01, -1.950460e-01],
[ 1.000000e+00, 1.624100e-02, -6.321500e-02],
[ 1.000000e+00, 6.912140e-01, 3.305268e+00],
[ 1.000000e+00, 5.820020e-01, 2.063627e+00],
[ 1.000000e+00, 3.031020e-01, 8.988400e-01],
[ 1.000000e+00, 6.225980e-01, 2.701692e+00],
[ 1.000000e+00, 5.250240e-01, 1.992909e+00],
[ 1.000000e+00, 9.967750e-01, 3.811393e+00],
[ 1.000000e+00, 8.810250e-01, 4.353857e+00],
[ 1.000000e+00, 7.234570e-01, 2.635641e+00],
[ 1.000000e+00, 6.763460e-01, 2.856311e+00],
[ 1.000000e+00, 2.546250e-01, 1.352682e+00],
[ 1.000000e+00, 4.886320e-01, 2.336459e+00],
[ 1.000000e+00, 5.198750e-01, 2.111651e+00],
[ 1.000000e+00, 1.601760e-01, 1.217260e-01],
[ 1.000000e+00, 6.094830e-01, 3.264605e+00],
[ 1.000000e+00, 5.318810e-01, 2.103446e+00],
[ 1.000000e+00, 3.216320e-01, 8.968550e-01],
[ 1.000000e+00, 8.451480e-01, 4.220850e+00],
[ 1.000000e+00, 1.200300e-02, -2.172830e-01],
[ 1.000000e+00, 1.888300e-02, -3.005770e-01],
[ 1.000000e+00, 7.147600e-02, 6.014000e-03]])
预剪枝
- tolS和tolN其实就是预剪枝操作
- 调节tolS为0、tolN为1,即容许的误差下降为0、最少划分样本数为1,则会分出很多叶节点
# Pre-pruning effectively disabled: tolS=0 accepts any non-negative error
# reduction and tolN=1 lets a split leave a single sample on a side,
# so the resulting tree is very deep.
createTree(myMat,ops=(0,1))
{'spInd': 1,
'spVal': 0.39435,
'left': {'spInd': 1,
'spVal': 0.582002,
'left': {'spInd': 1,
'spVal': 0.797583,
'left': {'spInd': 1,
'spVal': 0.819006,
'left': {'spInd': 1,
'spVal': 0.832693,
'left': {'spInd': 1,
'spVal': 0.867298,
'left': {'spInd': 1,
'spVal': 0.872288,
'left': {'spInd': 1,
'spVal': 0.952758,
'left': {'spInd': 1,
'spVal': 0.998533,
'left': 4.237235,
'right': {'spInd': 1,
'spVal': 0.956665,
'left': {'spInd': 1,
'spVal': 0.993392,
'left': {'spInd': 1,
'spVal': 0.995608,
...
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.mat was removed in NumPy 2.0; np.asmatrix performs the same conversion.
myMat = np.asmatrix(myDat)
print(createTree(myMat))
data2show(myMat)
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': 105.24862350000001, 'right': 112.42895575000001}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': 87.3103875, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.910975, 'left': 96.452867, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': 104.825409, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.181793, 'right': 102.25234449999999}}}, 'right': 95.27584316666666}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 81.110152, 'right': 88.78449880000001}}, 'right': 102.35780185714285}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': 114.554706, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': 104.82495374999999, 'right': 108.92921799999999}}, 'right': 114.1516242857143}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': 93.67344971428572, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 123.2101316, 'right': {'spInd': 0, 'spVal': 0.553797, 'left': 97.20018024999999, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 109.38961049999999, 'right': 110.979946}, 'right': 101.73699325000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': 12.50675925, 'right': 3.4331330000000007}, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -12.558604833333334, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': 14.38417875, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': -0.8923554999999995, 'right': 3.6584772500000016}}}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.350725, 'left': -15.08511175, 
'right': -22.693879600000002}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 15.05929075, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -19.9941552, 'right': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.217214, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': 0.40377471428571476, 'right': -13.070501}, 'right': 6.770429}, 'right': -11.822278500000001}, 'right': 3.4496025}, 'right': {'spInd': 0, 'spVal': 0.156067, 'left': -12.1079725, 'right': -6.247900000000001}}}}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 6.509843285714284, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': -2.544392714285715, 'right': 4.091626}}}}}
output_27_1.png
- 这个数据集的目标变量(图中纵轴,标为x2)取值范围很大,所以目标变量的总方差(regErr)也更大;因此tolS=1这个容许误差下降值对ex2.txt来说太小,树会被划分出很多叶节点,出现过拟合,效果差
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.mat was removed in NumPy 2.0; np.asmatrix performs the same conversion.
myMat2 = np.asmatrix(myDat)
# A much larger tolS (minimum error reduction) suits the large target scale
# of ex2.txt.
print(createTree(myMat2,ops=(10000,4)))
data2show(myMat2)
{'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}
output_29_1.png
- 增大tolS后,构建的树只有两个叶节点,效果好
后剪枝
1. 判断是否为子树(dict),若是则返回True
def isTree(obj):
    """Return True if obj is an internal tree node (a dict), False for a leaf value.

    Uses isinstance, so dict subclasses (e.g. OrderedDict) are recognised too.
    The former comparison of the type *name* with 'dict' rejected them.
    """
    return isinstance(obj, dict)
2. 塌陷子树:递归地用左右子节点值的平均代替子树,返回该平均值
def getMean(tree):
    """Collapse a (sub)tree into a single value without modifying it.

    Each internal node is replaced by the average of its two collapsed
    children. The result is therefore a pairwise recursive average, not
    the plain arithmetic mean of all leaves when the tree is unbalanced.

    Args:
        tree: tree dict with 'left' and 'right' entries, as built by
            createTree.

    Returns:
        The collapsed value.
    """
    # Collapse the children into locals instead of writing them back into
    # `tree`. The previous version overwrote tree['right']/tree['left'] and
    # so destroyed the structure of the caller's tree.
    right = getMean(tree['right']) if isTree(tree['right']) else tree['right']
    left = getMean(tree['left']) if isTree(tree['left']) else tree['left']
    return (right + left) / 2.0
3. 剪枝:在测试数据上判断合并后的误差是否比合并前更小,若更小则合并
def prune(tree,testData):
    """Post-prune a regression tree against held-out test data.

    The function descends the tree, partitioning testData exactly as the
    tree splits. Whenever both children of a node are leaves, it merges them
    into their average if that lowers the squared error on the test rows
    reaching the node.

    Note: `tree` is modified in place (its 'left'/'right' entries are
    replaced by their pruned versions).

    Args:
        tree: tree dict built by createTree.
        testData: 2-D matrix/array of test samples; last column is the target.

    Returns:
        The pruned tree dict, or a single value if the node was merged or
        collapsed.
    """
    # No test rows reach this node: collapse the whole subtree via getMean.
    if testData.shape[0] == 0:
        return getMean(tree)
    # Recurse into child subtrees with the matching partition of testData.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            # Bug fix: the right subtree must be pruned with the right-hand
            # partition (rSet); the original passed lSet here.
            tree['right'] = prune(tree['right'], rSet)
    # Both children are leaves now: compare the error with and without merging.
    if not isTree(tree['right']) and not isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # Squared error on the test rows while the two leaves stay separate.
        errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2))
                        + np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
        treeMean = (tree['right'] + tree['left']) / 2.0
        # Squared error on the same rows if the two leaves are merged.
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print('merging')
            return treeMean
        return tree
    return tree
测试
# Grow an almost unrestricted tree on ex2.txt (tolS=0, tolN=1) so that
# post-pruning has something to work on.
myDat = loadDataSet('../../Reference Code/Ch09/ex2.txt')
# np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0.
myMat2 = np.asmatrix(myDat)
myTree = createTree(myMat2, ops=(0, 1))
print(myTree)
data2show(myMat2)  # plotting helper defined earlier in the file (not visible here)
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': {'spInd': 0, 'spVal': 0.968621, 'left': 86.399637, 'right': 98.648346}, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': {'spInd': 0, 'spVal': 0.953902, 'left': {'spInd': 0, 'spVal': 0.954711, 'left': 82.016541, 'right': 100.935789}, 'right': 130.92648}}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.841547, 'left': {'spInd': 0, 'spVal': 0.841625, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': {'spInd': 0, 'spVal': 0.948822, 'left': {'spInd': 0, 'spVal': 0.949198, 'left': {'spInd': 0, 'spVal': 0.952377, 'left': 100.649591, 'right': 73.520802}, 'right': 105.752508}, 'right': 69.318649}, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.936524, 'left': {'spInd': 0, 'spVal': 0.937766, 'left': 100.120253, 'right': 119.949824}, 'right': {'spInd': 0, 'spVal': 0.934853, 'left': 65.548418, 'right': {'spInd': 0, 'spVal': 0.925782, 'left': 115.753994, 'right': {'spInd': 0, 'spVal': 0.910975, 'left': {'spInd': 0, 'spVal': 0.912161, 'left': {'spInd': 0, 'spVal': 0.915263, 'left': 92.074619, 'right': 96.71761}, 'right': 85.005351}, 'right': {'spInd': 0, 'spVal': 0.901444, 'left': {'spInd': 0, 'spVal': 0.908629, 'left': 106.814667, 'right': 118.513475}, 'right': {'spInd': 0, 'spVal': 0.901421, 'left': 87.300625, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': {'spInd': 0, 'spVal': 0.900699, 'left': 100.133819, 'right': {'spInd': 0, 'spVal': 0.896683, 'left': 109.188248, 'right': 107.00162}}, 
'right': {'spInd': 0, 'spVal': 0.888426, 'left': 82.436686, 'right': {'spInd': 0, 'spVal': 0.872199, 'left': {'spInd': 0, 'spVal': 0.883615, 'left': {'spInd': 0, 'spVal': 0.885676, 'left': 94.896354, 'right': 108.045948}, 'right': {'spInd': 0, 'spVal': 0.872883, 'left': 95.348184, 'right': 95.887712}}, 'right': {'spInd': 0, 'spVal': 0.866451, 'left': 111.552716, 'right': {'spInd': 0, 'spVal': 0.856421, 'left': 94.402102, 'right': 107.166848}}}}}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.84294, 'left': {'spInd': 0, 'spVal': 0.847219, 'left': 89.20993, 'right': 76.240984}, 'right': 95.893131}}}, 'right': 60.552308}, 'right': {'spInd': 0, 'spVal': 0.838587, 'left': 115.669032, 'right': 134.089674}}, 'right': {'spInd': 0, 'spVal': 0.823848, 'left': 76.723835, 'right': {'spInd': 0, 'spVal': 0.819722, 'left': 59.342323, 'right': 70.054508}}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 118.319942, 'right': {'spInd': 0, 'spVal': 0.811363, 'left': 99.841379, 'right': 112.981216}}}, 'right': {'spInd': 0, 'spVal': 0.799873, 'left': 62.877698, 'right': {'spInd': 0, 'spVal': 0.798198, 'left': 91.368473, 'right': 76.853728}}}, 'right': {'spInd': 0, 'spVal': 0.786865, 'left': {'spInd': 0, 'spVal': 0.787755, 'left': 110.15973, 'right': 118.642009}, 'right': {'spInd': 0, 'spVal': 0.785574, 'left': 100.598825, 'right': {'spInd': 0, 'spVal': 0.777582, 'left': 107.024467, 'right': 100.838446}}}}, 'right': 64.041941}, 'right': 115.199195}, 'right': {'spInd': 0, 'spVal': 0.740859, 'left': {'spInd': 0, 'spVal': 0.757527, 'left': 81.106762, 'right': 63.549854}, 'right': {'spInd': 0, 'spVal': 0.731636, 'left': 93.773929, 'right': 73.912028}}}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.642373, 'left': {'spInd': 0, 'spVal': 0.642707, 'left': {'spInd': 0, 'spVal': 0.665329, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': {'spInd': 0, 'spVal': 0.70889, 'left': {'spInd': 0, 'spVal': 0.716211, 'left': 110.90283, 'right': {'spInd': 0, 'spVal': 0.710234, 'left': 
103.345308, 'right': 108.553919}}, 'right': 135.416767}, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': {'spInd': 0, 'spVal': 0.69892, 'left': {'spInd': 0, 'spVal': 0.699873, 'left': {'spInd': 0, 'spVal': 0.70639, 'left': 106.180427, 'right': 105.062147}, 'right': 115.586605}, 'right': 92.470636}, 'right': {'spInd': 0, 'spVal': 0.689099, 'left': 120.521925, 'right': {'spInd': 0, 'spVal': 0.666452, 'left': {'spInd': 0, 'spVal': 0.667851, 'left': {'spInd': 0, 'spVal': 0.680486, 'left': 112.378209, 'right': 110.367074}, 'right': 92.449664}, 'right': {'spInd': 0, 'spVal': 0.665652, 'left': 120.014736, 'right': 105.547997}}}}}, 'right': {'spInd': 0, 'spVal': 0.661073, 'left': 121.980607, 'right': {'spInd': 0, 'spVal': 0.652462, 'left': 115.687524, 'right': 112.715799}}}, 'right': 82.500766}, 'right': 140.613941}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': {'spInd': 0, 'spVal': 0.623909, 'left': {'spInd': 0, 'spVal': 0.628061, 'left': {'spInd': 0, 'spVal': 0.637999, 'left': 82.713621, 'right': {'spInd': 0, 'spVal': 0.632691, 'left': 91.656617, 'right': 93.645293}}, 'right': {'spInd': 0, 'spVal': 0.624827, 'left': 117.628346, 'right': 105.970743}}, 'right': {'spInd': 0, 'spVal': 0.618868, 'left': 87.181863, 'right': 76.917665}}, 'right': {'spInd': 0, 'spVal': 0.606417, 'left': 168.180746, 'right': {'spInd': 0, 'spVal': 0.513332, 'left': {'spInd': 0, 'spVal': 0.533511, 'left': {'spInd': 0, 'spVal': 0.548539, 'left': {'spInd': 0, 'spVal': 0.553797, 'left': {'spInd': 0, 'spVal': 0.560301, 'left': {'spInd': 0, 'spVal': 0.599142, 'left': 93.521396, 'right': {'spInd': 0, 'spVal': 0.589806, 'left': 130.378529, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': {'spInd': 0, 'spVal': 0.585413, 'left': 98.674874, 'right': 125.295113}, 'right': {'spInd': 0, 'spVal': 0.571214, 'left': 82.589328, 'right': {'spInd': 0, 'spVal': 0.569327, 'left': 114.872056, 'right': 108.435392}}}}}, 'right': 82.903945}, 'right': {'spInd': 0, 'spVal': 0.549814, 'left': 120.857321, 'right': 
137.267576}}, 'right': {'spInd': 0, 'spVal': 0.546601, 'left': 83.114502, 'right': {'spInd': 0, 'spVal': 0.537834, 'left': {'spInd': 0, 'spVal': 0.543843, 'left': 96.319043, 'right': 98.36201}, 'right': 90.995536}}}, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.531944, 'left': 129.766743, 'right': 124.795495}, 'right': 116.176162}}, 'right': {'spInd': 0, 'spVal': 0.508548, 'left': 101.075609, 'right': {'spInd': 0, 'spVal': 0.508542, 'left': 93.292829, 'right': 96.403373}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.465561, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': {'spInd': 0, 'spVal': 0.483803, 'left': {'spInd': 0, 'spVal': 0.487381, 'left': {'spInd': 0, 'spVal': 0.487537, 'left': 11.924204, 'right': 5.149336}, 'right': 27.729263}, 'right': 5.224234}, 'right': {'spInd': 0, 'spVal': 0.46568, 'left': -9.712925, 'right': -23.777531}}, 'right': {'spInd': 0, 'spVal': 0.463241, 'left': 30.051931, 'right': 17.171057}}, 'right': {'spInd': 0, 'spVal': 0.455761, 'left': -34.044555, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.130626, 'left': {'spInd': 0, 'spVal': 0.382037, 'left': {'spInd': 0, 'spVal': 0.388789, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': {'spInd': 0, 'spVal': 0.454312, 'left': {'spInd': 0, 'spVal': 0.454375, 'left': 9.841938, 'right': 3.043912}, 'right': {'spInd': 0, 'spVal': 0.446196, 'left': {'spInd': 0, 'spVal': 0.451087, 'left': -20.360067, 'right': -28.724685}, 'right': -5.108172}}, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': {'spInd': 0, 'spVal': 0.418943, 'left': {'spInd': 0, 'spVal': 0.426711, 'left': {'spInd': 0, 'spVal': 0.428582, 'left': 19.745224, 'right': 15.224266}, 'right': -21.594268}, 'right': 44.161493}, 'right': {'spInd': 0, 'spVal': 0.403228, 'left': -26.419289, 'right': {'spInd': 0, 'spVal': 0.391609, 'left': -1.729244, 'right': 3.001104}}}}, 'right': {'spInd': 0, 'spVal': 0.385021, 'left': 21.578007, 'right': 24.816941}}, 'right': 
{'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.370042, 'left': {'spInd': 0, 'spVal': 0.378965, 'left': -29.007783, 'right': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.377383, 'left': 13.583555, 'right': 5.241196}, 'right': -8.228297}}, 'right': {'spInd': 0, 'spVal': 0.35679, 'left': -32.124495, 'right': {'spInd': 0, 'spVal': 0.350725, 'left': {'spInd': 0, 'spVal': 0.351478, 'left': -19.526539, 'right': -0.461116}, 'right': {'spInd': 0, 'spVal': 0.350065, 'left': -40.086564, 'right': {'spInd': 0, 'spVal': 0.342761, 'left': -1.319852, 'right': {'spInd': 0, 'spVal': 0.342155, 'left': -31.584855, 'right': {'spInd': 0, 'spVal': 0.3417, 'left': -16.930416, 'right': -23.547711}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': {'spInd': 0, 'spVal': 0.32889, 'left': {'spInd': 0, 'spVal': 0.331364, 'left': {'spInd': 0, 'spVal': 0.3349, 'left': 2.768225, 'right': 18.97665}, 'right': -1.290825}, 'right': 39.783113}, 'right': {'spInd': 0, 'spVal': 0.309133, 'left': {'spInd': 0, 'spVal': 0.310956, 'left': {'spInd': 0, 'spVal': 0.318309, 'left': -13.189243, 'right': -27.605424}, 'right': -49.939516}, 'right': {'spInd': 0, 'spVal': 0.131833, 'left': {'spInd': 0, 'spVal': 0.138619, 'left': {'spInd': 0, 'spVal': 0.156067, 'left': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.193282, 'left': {'spInd': 0, 'spVal': 0.211633, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': {'spInd': 0, 'spVal': 0.284794, 'left': {'spInd': 0, 'spVal': 0.300318, 'left': 8.814725, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -18.051318, 'right': {'spInd': 0, 'spVal': 0.295993, 'left': -1.798377, 'right': {'spInd': 0, 'spVal': 0.290749, 'left': -14.988279, 'right': -14.391613}}}}, 'right': {'spInd': 0, 'spVal': 0.273863, 'left': 35.623746, 'right': {'spInd': 0, 'spVal': 0.264926, 'left': -9.457556, 'right': {'spInd': 0, 'spVal': 0.264639, 'left': 5.280579, 'right': 2.557923}}}}, 'right': {'spInd': 0, 
'spVal': 0.228628, 'left': {'spInd': 0, 'spVal': 0.228751, 'left': {'spInd': 0, 'spVal': 0.232802, 'left': -20.425137, 'right': 1.222318}, 'right': -30.812912}, 'right': -2.266273}}, 'right': {'spInd': 0, 'spVal': 0.222271, 'left': {'spInd': 0, 'spVal': 0.2232, 'left': 19.425158, 'right': 15.501642}, 'right': {'spInd': 0, 'spVal': 0.218321, 'left': -9.255852, 'right': {'spInd': 0, 'spVal': 0.217214, 'left': 1.410768, 'right': -3.958752}}}}, 'right': {'spInd': 0, 'spVal': 0.202161, 'left': {'spInd': 0, 'spVal': 0.203993, 'left': {'spInd': 0, 'spVal': 0.206207, 'left': -8.332207, 'right': -12.619036}, 'right': -22.379119}, 'right': {'spInd': 0, 'spVal': 0.199903, 'left': -1.983889, 'right': -3.372472}}}, 'right': {'spInd': 0, 'spVal': 0.176523, 'left': 18.208423, 'right': 0.946348}}, 'right': {'spInd': 0, 'spVal': 0.156273, 'left': {'spInd': 0, 'spVal': 0.164134, 'left': {'spInd': 0, 'spVal': 0.166431, 'left': -14.740059, 'right': -6.512506}, 'right': -27.405211}, 'right': 0.225886}}, 'right': {'spInd': 0, 'spVal': 0.13988, 'left': 7.557349, 'right': 7.336784}}, 'right': -29.087463}, 'right': 22.478291}}}}}, 'right': -39.524461}, 'right': {'spInd': 0, 'spVal': 0.124723, 'left': 22.891675, 'right': {'spInd': 0, 'spVal': 0.085111, 'left': {'spInd': 0, 'spVal': 0.108801, 'left': {'spInd': 0, 'spVal': 0.11515, 'left': -1.402796, 'right': 13.795828}, 'right': {'spInd': 0, 'spVal': 0.10796, 'left': -16.106164, 'right': {'spInd': 0, 'spVal': 0.085873, 'left': -1.293195, 'right': -10.137104}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 37.820659, 'right': {'spInd': 0, 'spVal': 0.080061, 'left': -24.132226, 'right': {'spInd': 0, 'spVal': 0.068373, 'left': {'spInd': 0, 'spVal': 0.079632, 'left': 2.229873, 'right': 29.420068}, 'right': {'spInd': 0, 'spVal': 0.061219, 'left': -15.160836, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': {'spInd': 0, 'spVal': 0.053764, 'left': {'spInd': 0, 'spVal': 0.055862, 'left': 6.695567, 'right': -3.131497}, 'right': -13.731698}, 
'right': {'spInd': 0, 'spVal': 0.028546, 'left': {'spInd': 0, 'spVal': 0.039914, 'left': 3.855393, 'right': 11.220099}, 'right': {'spInd': 0, 'spVal': 0.000256, 'left': -8.377094, 'right': 9.668106}}}}}}}}}}}}}
output_39_1.png
- ops=(0,1) 几乎不限制树的生长,叶节点太多(过拟合),因此使用测试集进行后剪枝
# Post-prune the overgrown tree with the held-out ex2test.txt data.
myDatTest = loadDataSet('../../Reference Code/Ch09/ex2test.txt')
# np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0.
myDatTest = np.asmatrix(myDatTest)
# prune() reassigns children of myTree in place and returns the pruned tree.
prune(myTree, myDatTest)
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
{'spInd': 0,
'spVal': 0.499171,
'left': {'spInd': 0,
'spVal': 0.729397,
'left': {'spInd': 0,
'spVal': 0.952833,
'left': {'spInd': 0,
'spVal': 0.965969,
'left': 92.5239915,
'right': {'spInd': 0,
'spVal': 0.956951,
'left': {'spInd': 0,
'spVal': 0.958512,
...
模型树
- 每个叶节点存放一个线性模型(而不是常数),整棵树因此构成一个分段线性函数
# Ordinary least-squares linear regression (no ridge penalty is applied).
def linearSolve(dataSet):
    """Fit y = w0 + w1*x1 + ... + w_{n-1}*x_{n-1} by ordinary least squares.

    Solves the normal equations ws = (X^T X)^-1 X^T Y.

    Parameters
    ----------
    dataSet : (m, n) matrix or 2-D array
        Feature columns first, target in the last column.

    Returns
    -------
    ws : (n, 1) matrix
        Regression weights; ws[0] is the intercept.
    X : (m, n) matrix
        Design matrix: a column of ones followed by the n-1 feature columns.
    Y : (m, 1) matrix
        Target column.

    Raises
    ------
    NameError
        If det(X^T X) is exactly 0 (e.g. too few samples, or a constant
        feature, in this node); increasing tolN (ops[1]) avoids tiny nodes.
    """
    m, n = dataSet.shape
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    X = np.asmatrix(np.ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n - 1]
    # Always an (m, 1) column, even when dataSet is a plain 2-D ndarray
    # (dataSet[:, -1] would then be 1-D and break X.T * Y).
    Y = np.asmatrix(dataSet[:, -1]).reshape(m, 1)
    xTx = X.T * X
    # Exact-zero test only: nearly singular matrices are not caught here.
    if np.linalg.det(xTx) == 0:
        raise NameError('This matrix is singular,cannot do inverse,\n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
# Leaf generator for model trees.
def modelLeaf(dataSet):
    """Fit a linear model to the samples of one node and return its weights.

    Only the weight matrix from linearSolve is kept; the design matrix and
    target column it also returns are discarded.
    """
    return linearSolve(dataSet)[0]
# Error function for model trees.
def modelErr(dataSet):
    """Return the sum of squared residuals of a linear fit to *dataSet*.

    The fit comes from linearSolve. The builtin ``sum`` runs over the rows of
    an (m, 1) residual matrix, so the result is a 1x1 matrix, not a float.
    """
    weights, design, target = linearSolve(dataSet)
    residual = target - design * weights
    return sum(np.power(residual, 2))
# Fit a model tree (linear-model leaves) to exp2.txt; ops=(1,10) means
# tolS=1 and tolN=10 as read by chooseBestSplit.
myDat2 = loadDataSet('../../Reference Code/Ch09/exp2.txt')
# np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0.
myDat2 = np.asmatrix(myDat2)
myTree = createTree(myDat2, leafType=modelLeaf, errType=modelErr, ops=(1, 10))
print(myTree)
data2show(myDat2)  # plotting helper defined earlier in the file (not visible here)
{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
[1.19647739e+01]]), 'right': matrix([[3.46877936],
[1.18521743]])}
output_44_1.png
- 可以看到在 'spVal': 0.285477 处把数据切分成两段,并分别建立了线性模型;在图上也能看到样本在 x≈0.285 处分段(左子树对应 x > spVal,右子树对应 x ≤ spVal):
- x > 0.285477(左子树):y = 0.0016986 + 11.9648x
- x ≤ 0.285477(右子树):y = 3.4688 + 1.1852x
树回归和标准回归的比较
# Leaf evaluator for regression trees: the leaf model is a constant (the mean).
def regTreeEval(model, inDat):
    """Return the constant stored in a regression-tree leaf as a float.

    *inDat* is ignored; it is accepted only so that the signature matches
    modelTreeEval and both can be passed as ``modelEval``.
    """
    prediction = float(model)
    return prediction
# Leaf evaluator for model trees: the leaf model is a weight matrix ws.
def modelTreeEval(model, inDat):
    """Evaluate a linear-model leaf for one sample.

    Parameters
    ----------
    model : (n+1, 1) matrix
        Leaf weights from modelLeaf / linearSolve; model[0] is the intercept.
    inDat : 1 x n matrix (as passed by createForeCast) or any array-like
        holding the sample's n feature values.

    Returns
    -------
    float
        w0 + w1*x1 + ... + wn*xn.
    """
    features = np.asarray(inDat).reshape(1, -1)
    n = features.shape[1]
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    X = np.asmatrix(np.ones((1, n + 1)))
    # Column 0 stays 1 so that model[0] acts as the intercept.
    X[:, 1:n + 1] = features
    # Index the 1x1 product instead of calling float() on the matrix:
    # converting an ndim > 0 array to a scalar is deprecated since NumPy 1.25.
    return float((X * model)[0, 0])
# Walk from the root down to the leaf that a single sample falls into.
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Return the prediction of ``tree`` for one sample ``inData``.

    Internal nodes are dicts with 'spInd', 'spVal', 'left' and 'right'. A
    sample goes to 'left' when its feature ``spInd`` is > ``spVal`` and to
    'right' otherwise, the same rule binSplitDataSet uses when the tree is
    built. At the leaf, ``modelEval(leaf, inData)`` gives the prediction
    (regTreeEval for constant leaves, modelTreeEval for linear-model leaves).

    inData is one sample's features, e.g. the 1 x n matrix passed by
    createForeCast. It is handed to ``modelEval`` unchanged.
    """
    # Flat view used only to read split features. Indexing the 1 x n matrix
    # directly (inData[spInd]) selects a *row*: it raises IndexError for
    # spInd > 0 and gives an ambiguous truth value when n > 1.
    features = np.asarray(inData).reshape(-1)
    node = tree
    while isTree(node):
        if features[node['spInd']] > node['spVal']:
            node = node['left']
        else:
            node = node['right']
    return modelEval(node, inData)
# Predict every sample in testData with treeForeCast.
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Return an (m, 1) matrix with one prediction per sample in ``testData``.

    Each element of testData (a 1 x n matrix row, or a scalar when testData
    is a 1-D array such as the x grid built for reDraw) is wrapped as a
    matrix before being passed to treeForeCast together with ``modelEval``.
    """
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    yHat = np.asmatrix(np.zeros((len(testData), 1)))
    for i, sample in enumerate(testData):
        yHat[i, 0] = treeForeCast(tree, np.asmatrix(sample), modelEval)
    return yHat
# Load the bikeSpeedVsIq train/test sets; column 0 is used as the input and
# column 1 as the target in the cells below.
# np.asmatrix instead of np.mat: the np.mat alias was removed in NumPy 2.0.
trainMat = np.asmatrix(loadDataSet('../../Reference Code/Ch09/bikeSpeedVsIq_train.txt'))
testMat = np.asmatrix(loadDataSet('../../Reference Code/Ch09/bikeSpeedVsIq_test.txt'))
trainMat[0:10, :]
matrix([[ 3. , 46.852122],
[ 23. , 178.676107],
[ 0. , 86.154024],
[ 6. , 68.707614],
[ 15. , 139.737693],
[ 17. , 141.988903],
[ 12. , 94.477135],
[ 8. , 86.083788],
[ 9. , 97.265824],
[ 7. , 80.400027]])
# Show the first 10 test rows (column 0: input, column 1: target).
testMat[0:10,:]
matrix([[ 12. , 121.010516],
[ 19. , 157.337044],
[ 12. , 116.031825],
[ 15. , 132.124872],
[ 2. , 52.719612],
[ 6. , 39.058368],
[ 3. , 50.757763],
[ 20. , 166.740333],
[ 11. , 115.808227],
[ 21. , 165.582995]])
1. 构建回归树
# Regression tree (constant leaves via regLeaf / regErr) on the training set;
# ops=(1,20) means tolS=1 and tolN=20 as read by chooseBestSplit.
myTree = createTree(trainMat,leafType = regLeaf, errType=regErr,ops=(1,20))
print(myTree)
data2show(trainMat)  # plotting helper defined earlier in the file (not visible here)
data2show(testMat)
{'spInd': 0, 'spVal': 10.0, 'left': {'spInd': 0, 'spVal': 17.0, 'left': {'spInd': 0, 'spVal': 20.0, 'left': 168.34161286956524, 'right': 157.0484078846154}, 'right': {'spInd': 0, 'spVal': 14.0, 'left': 141.06067981481482, 'right': 122.90893026923078}}, 'right': {'spInd': 0, 'spVal': 7.0, 'left': 94.7066578125, 'right': {'spInd': 0, 'spVal': 5.0, 'left': 69.02117757692308, 'right': 50.94683665}}}
output_52_1.png
output_52_2.png
# NOTE(review): x and y are not used below.
x=testMat[:,0]
y=testMat[:,1]
# Predict the test inputs with the regression tree and report the correlation
# coefficient between the predictions and the true targets.
yHat = createForeCast(myTree,testMat[:,0])
np.corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] #rowvar=0: each column is one variable
0.9640852318222141
2. 构建模型树
# Model tree (linear-model leaves via modelLeaf / modelErr) on the same data,
# evaluated with modelTreeEval and compared by correlation coefficient.
myTree = createTree(trainMat,leafType = modelLeaf, errType=modelErr,ops=(1,20))
print(myTree)
yHat = createForeCast(myTree,testMat[:,0],modelEval=modelTreeEval)
np.corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] #rowvar=0: each column is one variable
{'spInd': 0, 'spVal': 4.0, 'left': {'spInd': 0, 'spVal': 12.0, 'left': {'spInd': 0, 'spVal': 16.0, 'left': {'spInd': 0, 'spVal': 20.0, 'left': matrix([[47.58621512],
[ 5.51066299]]), 'right': matrix([[37.54851927],
[ 6.23298637]])}, 'right': matrix([[43.41251481],
[ 6.37966738]])}, 'right': {'spInd': 0, 'spVal': 9.0, 'left': matrix([[-2.87684083],
[10.20804482]]), 'right': {'spInd': 0, 'spVal': 6.0, 'left': matrix([[-11.84548851],
[ 12.12382261]]), 'right': matrix([[-17.21714265],
[ 13.72153115]])}}}, 'right': matrix([[ 68.87014372],
[-11.78556471]])}
0.9760412191380593
3. 普通线性回归
# Single global least-squares fit on the training set (no tree) for comparison;
# ws[0,0] is the intercept and ws[1,0] the slope.
ws,X,Y=linearSolve(trainMat)
ws
matrix([[37.58916794],
[ 6.18978355]])
# Predictions of the single global linear fit y = ws[0,0] + ws[1,0] * x.
m = testMat.shape[0]
# Build a fresh (m, 1) column in one vectorized step instead of writing into
# the yHat left over from the model-tree cell above, so this cell no longer
# depends on that cell having been run (or on its length).
yHat = testMat[:, 0] * ws[1, 0] + ws[0, 0]
np.corrcoef(yHat, testMat[:, 1], rowvar=0)[0, 1]  # rowvar=0: each column is one variable
0.9434684235674763
- 按测试集上预测值与真实值的相关系数比较:模型树(0.976) > 回归树(0.964) > 普通线性回归(0.943)。注意 linearSolve 是普通最小二乘,并非岭回归
4. 集成Matplotlib和Tkinter
# Run under Windows (original author's note).
# Minimal Tk check: opens a window containing a single label.
import tkinter as tk
root = tk.Tk()
myLabel = tk.Label(root, text='Hello,World')
myLabel.grid()
root.mainloop()  # enters the Tk event loop
from numpy import *
from tkinter import *
#import matplotlib
#matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS, tolN):
    """Rebuild the tree with (tolS, tolN) and redraw data and predictions.

    State is attached to the function object: reDraw.f (Figure),
    reDraw.canvas (FigureCanvasTkAgg), reDraw.rawDat (data matrix; column 0
    plotted as x, column 1 as y) and reDraw.testDat (x values to predict).
    The global chkBtnVar selects a model tree (checked) or a regression tree.
    """
    reDraw.f.clf()  # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # A model-tree leaf fits an intercept and a slope; fewer than 2
        # samples per leaf would make X.T*X singular in linearSolve.
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    # scatter for the data set, line plot for the predictions
    reDraw.a.scatter(reDraw.rawDat[:, 0].flatten().A[0], reDraw.rawDat[:, 1].flatten().A[0], s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    # FigureCanvasTkAgg.show() was deprecated in Matplotlib 2.2 and later
    # removed; draw() is the supported way to re-render the canvas.
    reDraw.canvas.draw()
def getInputs():
    """Read tolN (int) and tolS (float) from the two Entry widgets.

    If an entry cannot be parsed, fall back to the default (tolN=10,
    tolS=1.0), print a hint and reset the entry text to that default.

    Returns
    -------
    (tolN, tolS)
    """
    # Catch only the parse failure; the former bare ``except:`` also
    # swallowed unrelated errors (including KeyboardInterrupt).
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS
def drawNewTree():
    """ReDraw button callback: read the entry values and redraw the plot."""
    # getInputs returns (tolN, tolS) while reDraw takes (tolS, tolN).
    min_samples, err_tol = getInputs()
    reDraw(err_tol, min_samples)
# Build the Tk window: plot canvas, tolN / tolS entries, ReDraw button and
# the "Model Tree" checkbox, then draw an initial tree on sine.txt.
root = Tk()
reDraw.f = Figure(figsize=(5, 4), dpi=100)  # create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# draw() replaces FigureCanvasTkAgg.show(), deprecated in Matplotlib 2.2 and later removed.
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
Label(root, text="tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text="tolS").grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
# np.asmatrix instead of mat: the mat alias was removed in NumPy 2.0.
reDraw.rawDat = np.asmatrix(loadDataSet('sine.txt'))
# Use the matrix's own min()/max() so arange receives scalars; the builtin
# min/max over the matrix rows returned 1x1 matrices.
reDraw.testDat = np.arange(reDraw.rawDat[:, 0].min(), reDraw.rawDat[:, 0].max(), 0.01)
reDraw(1.0, 10)
root.mainloop()
回归树.png
模型树.png