树回归(二)

2019-11-07  本文已影响0人  RossH

树剪枝

一棵树如果节点过多,说明该模型存在过拟合问题。

通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。树回归(一)中的chooseBestSplit函数中的提前终止条件,实际上是一种预剪枝(prepruning)操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。

预剪枝的不足

树回归(一)中的树构建算法对参数tolStolN非常敏感,下面用树回归(一)中的第一个数据集,采用不同的ops参数,来观察结果。

dataSet = loadDataSet('ex00.txt')
myMat = np.mat(dataSet)
createTree(myMat, ops=(0,1))

结果如下:

{'spInd': 0,
 'spVal': 0.48813,
 'left': {'spInd': 0,
  'spVal': 0.620599,
  'left': {'spInd': 0,
 ......
  'right': {'spInd': 0,
   'spVal': 0.325412,
   'left': {'spInd': 0, 'spVal': 0.3371, 'left': 0.1910235, 'right': 0.118208},
   'right': -0.028594120689655174}}}

由于输出过长,这里省略部分内容。

与上文中只包含两个节点的树相比,这里构建的树过于臃肿。

下面用一个与ex00.txt数据集分布类似,但y轴数量级是其100倍的ex2.txt数据集来构建树。

dataSet = loadDataSet('ex2.txt')
myMat2 = np.mat(dataSet)
createTree(myMat2)

结果如下

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': 108.838789625,
 ......
 'right': {'spInd': 0,
  'spVal': 0.457563,
  'left': 7.969946125,
  'right': -3.6244789069767447}}

用默认参数构建的树显得比较臃肿。下面是其分布。


ex00.txtex2.txt两个数据集分布类似,但在都采用默认参数的情况下,ex00.txt构建的树只有两个叶节点,而ex2.txt却有很多。产生这种现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项上花费时间并对上述误差容忍度取平方值,也能得到两个叶节点的树:

createTree(myMat2, ops=(10000, 4))

output:
{'spInd': 0,
 'spVal': 0.499171,
 'left': 101.35815937735848,
 'right': -2.637719329787234}

然而,通过不断修改参数来得到合理结果并不是很好的办法。

下面将介绍后剪枝,利用测试集来对树进行剪枝,并不需要指定参数,是一种更理想化的剪枝方法。

后剪枝

剪枝函数prune()的伪代码如下:

基于已有的树切分测试数据:
    如果存在任意子集是一棵树,则在该子集递归剪枝
    计算将当前两个叶节点合并后的误差
    计算不合并的误差
    如果合并会降低误差的话,就将叶节点合并
def isTree(obj):
    return (type(obj).__name__ == 'dict')

def getMean(tree):
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    return (tree['left'] + tree['right'])/2

def prune(tree, testData):
    # 没有测试数据则对树进行塌陷处理
    if testData.shape[0] == 0: 
        return getMean(tree)
    
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): 
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): 
        tree['right'] =  prune(tree['right'], rSet)
    
    # 如果都是叶节点,看能不能合并
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = splitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(np.power(lSet[:,-1] - tree['left'],2)) +\
            sum(np.power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(np.power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            print("merging")
            return treeMean
        else: return tree
    else: return tree

isTree()判断是否为树。
getMean()从上往下遍历树直到叶节点为止。该函数对树进行塌陷处理,即返回树平均值。
接下来看看实际效果。

# 构建一个过拟合的树
myTree = createTree(myMat2, ops=(0,1))
# 加载测试集
testData = loadDataSet('ex2test.txt')
testMat = np.mat(testData)
# 剪枝
prune(myTree, testMat)

运行后观察两棵树,可以发现大量节点被剪枝掉,但没有预期那样剪枝成两部分。
一般地,为了寻求最佳模型可以同时使用两种剪枝技术。

上一篇 下一篇

猜你喜欢

热点阅读