机器学习

机器学习之决策树ID3/C4.5/CART学习算法比较

2019-01-13  本文已影响0人  郑壮杰
一、概述
二、决策树学习算法
2.1 相亲数据集
编号 年龄 长相 工资 编程 类别
1 不会 不见
2 年轻 一般 中等
3 年轻 不会 不见
4 年轻 一般
5 年轻 一般 不会 不见
2.2 ID3(Iterative Dichotomiser 3)--信息增益

  信息熵(information entropy)是度量样本集合纯度最常用的一种指标。对于样本集合D,类别数为K,信息熵定义为:
H(D)=-\sum_{k=1}^K\frac{|C_k|}{|D|}\log_2\frac{|C_k|}{|D|}
  其中,C_k是样本集合D中属于第k(k=1,2,3...K)类的样本子集,|C_k|表示该子集的元素个数,|D|表示样本集合的元素个数。特征属性A的条件熵H(D|A)定义为:
H(D|A)=\sum_{i=1}^n\frac{|D_i|}{|D|}H(D_i)=\sum_{i=1}^n\frac{|D_i|}{|D|}(-\sum_{k=1}^K\frac{|D_{ik}|}{|D_i|}\log_2\frac{|D_{ik}|}{|D_i|})
  其中,|D_i|表示D中特征A取第i个值的样本子集,D_{ik}表示D_i中属于第k类的样本子集。因此,特征A的信息增益等于两者之差:
g(D,A)=H(D)-H(D|A)
  以表2.1中相亲数据为例,该数据包含5个样本集,正样本占比\frac{2}{5},负样本占比\frac{3}{5}。于是,根据公式计算出根节点的信息熵为:
H(D)=-\frac{3}{5}\log_2\frac{3}{5}-\frac{2}{5}\log_2\frac{2}{5}=0.971
  然后,计算当前属性集合{年龄,长相,工资,编程}中每个属性的条件熵。以属性“年龄”为例,它有2个可能的取值:{老,年轻}。若使用该属性对D进行划分,则可得到2个子集,分别为:D_1(年龄=老),D_2(年龄=年轻)。

信息熵和条件熵计算的scala-spark实现:

/**
    * 计算信息熵
    *
    * @param df
    * @param column
    * @return
    */
  def calculate(df: DataFrame, column: String = "label"): Double = {
    val counts = df.select(column).groupBy(column).agg(count(column)).collect().map(row => row.getLong(1))
    val totalCount = counts.sum.toDouble
    if (totalCount == 0) {
      return 0
    }
    counts.map {
      count =>
        var impurity = 0.0
        if (count != 0) {
          val freq = count / totalCount
          impurity -= freq * log2(freq)
        }
        impurity
    }.reduce((v1, v2) => v1 + v2)
  }

  /**
    * 计算每个特征的条件熵
    *
    * @param df
    * @param column
    * @return
    */
  def calculateFeature(df: DataFrame, column: String): Double = {
    val counts = df.select(column).groupBy(column).agg(count(column)).collect()
    val totalCount = counts.map(row => row.getLong(1)).sum.toDouble
    if (totalCount == 0) {
      return 0
    }
    val impurity = counts.map {
      row =>
        val featureValue = row.get(0).toString.toDouble
        val featureCount = row.getLong(1)
        val freq = featureCount / totalCount
        val tmp = df.filter(col(column) === featureValue)
        freq * calculate(tmp, "label")
    }.reduce((v1, v2) => v1 + v2)
    impurity
  }
2.3 C4.5--信息增益比

  特征A对于数据集D的信息增益比定义为:
g_R(D,A)=\frac{g(D,A)}{H_A(D)}
其中
H_A(D)=-\sum_{i=1}^n\frac{|D_i|}{|D|}\log_2\frac{|D_i|}{|D|}
称为数据集D关于A的取值熵。因此,可以根据上面的公式求出每个特征的取值熵:
H_{年龄}(D)=-\frac{1}{5}\log_2\frac{1}{5}-\frac{4}{5}\log_2\frac{4}{5}=0.722
H_{长相}(D)=-\frac{1}{5}\log_2\frac{1}{5}-\frac{3}{5}\log_2\frac{3}{5}-\frac{1}{5}\log_2\frac{1}{5}=1.371
H_{工资}(D)=-\frac{3}{5}\log_2\frac{3}{5}-\frac{1}{5}\log_2\frac{1}{5}-\frac{1}{5}\log_2\frac{1}{5}=1.371
H_{编程}(D)=-\frac{3}{5}\log_2\frac{3}{5}-\frac{2}{5}\log_2\frac{2}{5}=0.971
最终可计算出各个特征的信息增益比:
g_R(D,年龄)=\frac{g(D,年龄)}{H_{年龄}(D)}=\frac{0.171}{0.722}=0.2368
g_R(D,长相)=\frac{g(D,长相)}{H_{长相}(D)}=\frac{0.42}{1.371}=0.3063
g_R(D,工资)=\frac{g(D,工资)}{H_{工资}(D)}=\frac{0.42}{1.371}=0.3063
g_R(D,编程)=\frac{g(D,编程)}{H_{编程}(D)}=\frac{0.971}{0.971}=1
通过信息增益比,特征年龄对应的指标上升了,而特征长相和工资有所下降。

2.4 CART--基尼指数(Gini)

  CART决策树使用基尼指数来选择划分属性,Gini描述的是数据的纯度,其定义为:
Gini(D)=1-\sum_{k=1}^n(\frac{|C_k|}{|D|})^2
  其中,C_k是样本集合D中属于第k(k=1,2,3...K)类的样本子集,|C_k|表示该子集的元素个数,|D|表示样本集合的元素个数。如果所有样本都属于同一个类别,则|C_k|=|D|Gini(D)=0,此时impurity最小。CART利用基尼指数构造二叉决策树。如果特征是离散型变量,将样本按特征A的取值切分成两份;如果特征是连续型变量,CART的处理方式和C4.5相同,先将特征值进行升序排序,然后把左边第一个值(index=1)作为一个分类,右边其他值作为另一个分类,计算其Gini指数,然后移动index的位置,直到计算完所有的分类结果,然后选取Gini最小的位置对应的index作为切分点。特征A的Gini指数定义为:
Gini(D|A)=\sum_{i=1}^n\frac{|D_i|}{|D|}Gini(D_i)
当n=2时,该公式可以简化为:
Gini(D|A)=\frac{|D_1|}{|D|}Gini(D_1) + \frac{|D_2|}{|D|}Gini(D_2)

使用CART分类准则,选取年龄维度,把老作为特征标签,那么年轻就被划分到另外一类

老(总数=1) 年轻(总数=4)
类别 不见 见、不见、见、不见

Gini(D|年龄=老)=\frac{1}{5}(1-(\frac{1}{1})^2-(\frac{0}{1})^2)+\frac{4}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.4
Gini(D|年龄=年轻)=\frac{1}{5}(1-(\frac{1}{1})^2-(\frac{0}{1})^2)+\frac{4}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.4

帅(总数=1) 一般、丑(总数=4)
类别 不见 见、不见、见、不见

Gini(D|长相=帅)=\frac{1}{5}(1-(\frac{1}{1})^2-(\frac{0}{1})^2)+\frac{4}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.4
Gini(D|长相=丑)=\frac{1}{5}(1-(\frac{1}{1})^2-(\frac{0}{1})^2)+\frac{4}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.4

一般(总数=3) 帅、丑(总数=2)
类别 见、不见 不见

Gini(D|长相=一般)=\frac{3}{5}(1-(\frac{2}{3})^2-(\frac{1}{3})^2)+\frac{2}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.47

高(总数=3) 中等、低(总数=2)
类别 见、不见 见、不见

Gini(D|工资=高)=\frac{3}{5}(1-(\frac{2}{3})^2-(\frac{1}{3})^2)+\frac{2}{5}(1-(\frac{1}{2})^2-(\frac{1}{2})^2)=0.47

会(总数=2) 不会(总数=3)
类别 不见

Gini(D|编程=会)=\frac{2}{5}(1-(\frac{2}{2})^2)+\frac{3}{5}(1-(\frac{3}{3})^2)=0
因此,特征编程的Gini指数最小,选择该特征作为最优的切分点。

三、小结

  通过比较ID3、C4.5和CART三种决策树的构造准则,在同一个样本集上,表现出不同的划分行为。

参考文献:
[1]诸葛越,葫芦娃.百面机器学习
[2]周志华.机器学习

上一篇下一篇

猜你喜欢

热点阅读