香农熵-计算最好的数据集划分方式

2018-07-20  本文已影响0人  ButtersoO

对于以下数据

data = [
# 不浮出水面是否可以生存, 有脚蹼, 属于鱼类
        1, 1, 'yes',  # 究竟什么鱼满足这个条件我很好奇
        1, 1, 'yes',  # 
        1, 0, 'no',  # 海胆
        0, 1, 'no',  # 鸭子
        0, 1, 'no' ] # 企鹅

计算出 “不浮出水面是否可以生存”与“有脚蹼”这两个特征值,哪个与是否属于鱼类更相关。
以下算法的原理是:拿出指定的特征值,计算剩下的数据的熵,熵越大,也就是数据越混乱,说明被拿出的数据越重要
代码如下

# -*- encoding:utf-8 -*-
import math
import numpy as np
__author__ = 'Butters'


def get_gain(p):
    """
    信息增益值
    """
    return -math.log(p, 2)


def get_ent(*p):
    """
    熵
    """
    return sum([i * get_gain(i) for i in p])


def test():
    data = [
        # 不浮出水面是否可以生存, 有脚蹼, 属于鱼类
        1, 1, 'yes',  # 究竟什么鱼满足这个条件我很好奇
        1, 1, 'yes',  # 
        1, 0, 'no',  # 海胆
        0, 1, 'no',  # 鸭子
        0, 1, 'no' ] # 企鹅
    dataset = np.reshape(data, (5, 3))
    chooseBestFeatureToSplit(dataset)


def calcShannonEnt(dataset):
    print '=====>data set is'
    print dataset
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        feat = featVec[-1]
        if feat not in labelCounts:
            labelCounts[feat] = 0
        labelCounts[feat] += 1
    shannonEnt = get_ent(*[float(labelCounts[key]) / float(numEntries) for key in labelCounts])
    print 'shannon ent is ', shannonEnt
    return shannonEnt


def splitDataset(dataset, axis, value):
    """
    获取dataset里的第axis轴值等于 value的
    :param dataset:
    :param axis:第axis列特征值
    :param value:
    :return:
    """
    m = np.array([row for row in dataset if row[axis] == value])
    return np.delete(m, axis, axis=1)


def chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1  # 2
    baseEntries = calcShannonEnt(dataset)
    print 'base entries is', baseEntries
    bestInfoGain = 0.0
    bestFeature = -1
    for i in xrange(numFeatures):
        featList = [example[i] for example in dataset]
        featSet = set(featList)  # 获取取值范围
        print 'feat set is', featSet
        tempEntry = 0.0
        for value in featSet:
            subDataset = splitDataset(dataset, i, value)
            prob = float(len(subDataset)) / float(len(dataset))
            tempEntry += (prob * calcShannonEnt(subDataset))
        infoGain = baseEntries - tempEntry
        print '-----------after calculate ----------'
        print tempEntry
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    print 'best feature is', bestFeature
    print 'best info gain is', bestInfoGain


if __name__ == '__main__':
    test()

运行结果:

best feature is 0
best info gain is 0.419973094022

所以第0个特征值是我们要的。
我们可以通过简单的逻辑来验证一下,5个例子中:

以上部分代码来源于《机器学习实战》,略有简化与修改

上一篇 下一篇

猜你喜欢

热点阅读