树回归(一)

2019-11-06  本文已影响0人  RossH

概述

CART是树构建算法,使用二元切分来处理连续型变量。基于ID3算法的决策树使用香农熵来度量集合的无序程度,如果用其他方法来代替香农熵,就可以用树构建算法来完成回归。

树的构建

在树的构建过程中,需要解决多种类型数据的存储问题。这里将使用字典来存储树的数据结构,该字典包含以下4个元素。

后面将构建两种树:

import numpy as np

# 加载数据集
def loadDataSet(fileName): 
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        # 将每行数据都映射成浮点数
        fltLine = map(float,curLine)
        dataMat.append(fltLine)
    return dataMat

def splitDataSet(dataSet, feature, value):
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, regErr, ops)
    if feat == None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    left, right = splitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(left, leafType, errType, ops)
    retTree['right'] = createTree(right, leafType, errType, ops)
    return retTree

splitDataSet函数在给定特征和特征值的情况下,通过数组过滤方式将数据集切分得到两个子集返回。
树构建函数createTree有4个参数,数据集;leafType给出建立叶节点的函数;errType代表误差计算函数;ops元组带有两个值,一个是容许的误差下降值,一个是切分的最少样本数

树的构建最重要的是找到最佳的划分点。用决策树进行分类,会计算数据集的熵,用来度量数据的混乱度。那么连续型数值混乱度如何计算?

首先是计算所有数据的均值,然后计算每条数据到均值的差值的平方。这有点儿类似于方差计算。唯一不同的是,方差是平方误差的均值(均方差),而这里是平方误差的总值(总方差)。总方差可以通过均方差乘以数据集样本个数得到。

# 生成叶节点
def regLeaf(dataSet):
    return np.mean(dataSet[:,-1])

# 误差估计
def regErr(dataSet):
    return np.var(dataSet[:,-1]) * dataSet.shape[0]

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
    tolS = ops[0]
    tolN = ops[1]
    # 如果所有值相等,则退出
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = dataSet.shape
    S = errType(dataSet)
    bestS = np.inf
    bestFeat = 0
    bestVal = 0
    for featIndex in range(n-1):                                    # 遍历所有特征
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):   # 遍历所有特征值
            mat0, mat1 = splitDataSet(dataSet, featIndex, splitVal)
            if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestFeat = featIndex
                bestVal = splitVal
                bestS = newS
    # 如果误差减少不大则退出
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # 如果切分出的数据集很小则退出
    if (mat0.shape[0] < tolN) or (mat1.shape[0] < tolN):
        return None, leafType(dataSet)
    return bestFeat, bestVal

regLeaf函数负责生成叶节点,在回归树中就是目标变量的均值。
chooseBestSplit是回归树构建的核心函数,目的是找到最佳二元划分方式。其参数ops设定了tolNtolS两个值,用于控制函数的停止时机。
执行createTree查看效果。

dataSet = loadDataSet('ex00.txt')
myMat = np.mat(dataSet)
createTree(myMat)

结果如下:

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

数据分布如下


再看一个多次切分的例子。数据分布如下。


构建回归树。

dataSet = loadDataSet('ex0.txt')
myMat = np.mat(dataSet)
createTree(myMat)

结果如下。

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}
上一篇 下一篇

猜你喜欢

热点阅读