决策树算法实战

2020-06-12  本文已影响0人  程南swimming

一、想要掌握一个算法分以下五步走:

  1. 理论推导 看书+手推+笔记

  2. 该算法的适用场景,解决过拟合,损失函数,优缺点,与其他方法比较,分布式计算方法,复杂度等

  3. 算法实现代码 参考做课后题

  4. kaggle找数据集跑自己的代码

  5. 调包+调参指南

二、决策树优缺点:

优点:1. 逻辑简单,可解释性强

  1. 计算量不大,可以跑大量数据源

缺点:

  1. 容易过拟合

  2. 忽略了数据之间的相关性

三、算法实现代码

1. 数据预处理

data=np.loadtxt()
m,n=np.shape(data)
D=np.arange(0,m,1)
A=[1 for i in range(n)] #如果使用过了标记为-1
A[-1]=-1
attri_list=[{},{},{}] #每个属性的可取属性值
attri_name=[] #每个属性的名称集合

2. 根据理论推导的边界值写辅助函数

class TreeNode(object):
    def __init__(self,title):
        self.val=None #val代表对哪个属性进行测试
        self.childnode=[] #子节点
        self.title=title #属于上游的哪个属性
def isDsameCategory(D):

def isDsameInA(D,A):

def isABlank(A):

def mostCateinD(D):
def calEnt(Dv):
    dic={}
    for i in Dv:
        if data[i,n-1] not in dic:
            dic[data[i,n-1]]=1
        else:
            dic[data[i, n - 1]] += 1
    res=0
    for k in dic:
        res-=(dic[k]/len(Dv))*np.log2(dic[k]/len(Dv))
    return res
        如果是增益率或者基尼指数要改为:
            ans -= (dic[k]/len(Dv))**2
return res

def calEnt(Dv,x,j): #计算离散属性的信息增益

def calFloatEnt(Dv,j): #计算连续属性的信息增益

连续属性的处理:

  1. 找划分点 原有n个取值{a1,a2,...,an} 先求每两个值的中间点{b1,b2,...,bn-1},然后把这n-1个值当做成离散属性,计算每一个分界点处的信息增益,取最大的分界点max_partion
for x in partion:
       tmp=calDigitEnt(Dv,x,j)
       if tmp>max_partion_gain:
               max_partion =x

3. 生成决策树

def TreeGenerate(Dv,Av,title):
    node=TreeNode(title)
    if isDsameCategory(Dv):
        node.val=data[Dv[0],n-1]
        return node
    if isAnull(A) or isDsameInA(Dv,Av):
        node.val=mostCateinD(Dv) #val 类别
        return node
    choose_attri=calGain(Dv,Av)
    Av[choose_attri]=-1
    node.childnode=[]
    if "." in str(attri[choose_attri]):
        partion = attri[choose_attri][0]
        node.val = attri_name[choose_attri]+"<"+str(partion)+"?"
        left=[]
        right=[]
        for i in Dv:
            if float(data[i,choose_attri])<partion:
                left.append(i)
            else:
                right.append(i)
        node.childnode.append(TreeGenerate(left, Av, "<"+str(partion)))
        node.childnode.append(TreeGenerate(right, Av, ">" + str(partion)))
    else:
        node.val =attri_name[choose_attri]+ "=?"
        for a in attri[choose_attri]:
            l=[]
            for i in Dv:
                if data[i,choose_attri]==a:
                    l.append(i)
            if len(l)==0:
                node1=TreeNode(a)
                node1.val=mostCateinD(Dv) #这里出了点小问题 因为把Dv写成了l
                node.childnode.append(node1)
            else:
                node.childnode.append(TreeGenerate(l,Av,attri_name[choose_attri]+str(a)))
    return node

对泰坦尼克号数据的处理:(在以上代码的基础上只用把pd的数据变为np的数据)

data=np.array(rf)
rf["Embarked"]=rf["Embarked"].fillna('S')
rf["Age"]=rf["Age"].fillna(rf["Age"].mean())
rf.loc[rf["Sex"]=="male","Sex"]=0
rf.loc[rf["Sex"]=="female","Sex"]=1
rf.loc[rf["Embarked"]=="S","Embarked"]=0
rf.loc[rf["Embarked"]=="C","Embarked"]=1
rf.loc[rf["Embarked"]=="Q","Embarked"]=2
del rf["PassengerId"]
del rf["Ticket"]
del rf["Cabin"]
del rf["Name"]
attri_name=list(rf.columns)
image.png

踩坑记:

  1. title和val分不清

注意treenode的定义,title代表标题:代表属于哪个属性 如敲声为清晰 是已知的

val 叶结点:对什么属性测试 根结点 类别 是未知的以及有结果的

  1. 不好判断哪个属性是连续属性,这里用的是是否包含小数点判断,因为连续属性给的是float类型

  2. python里面“”+x,对象x必须是字符串,其他类型的要先用str转

str(float) str(set) str(dic)

to do

预剪枝 后剪枝
测试集怎么在已有的决策数上跑

上一篇 下一篇

猜你喜欢

热点阅读