生活不易 我用python / Python语言与信息数据获取和机器学习 / 机器学习与数据挖掘

使用决策树ID3算法,预测收入是否大于50k

2017-04-04  本文已影响0人  来个芒果

数据集是在UCI上下载的:http://archive.ics.uci.edu/ml/datasets/Adult

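算法比较简单,没有涉及剪枝和限制树深,懒得写…… 实现完算法在数据集上跑了一遍,发现数据集太大导致栈溢出,所以只能预测部分数据了,谁知道比较好的优化办法欢迎交流~(注:栈溢出的真正原因并不是数据量大,而是按中位数划分时,若所选列在当前节点上的所有取值都不大于中位数(例如多条样本特征完全相同但标签不同),右子集为空、左子集与父节点完全相同,同一划分会被反复选中,递归永不终止,因此调大递归深度也无济于事。只要在任一子集为空时停止划分、改为生成多数类叶节点,应当就能用全部数据训练。)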

数据:


列描述信息


代码:

import pandas as pd 
import numpy as np
import math
import sys
col_names = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'sex',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'high_income',
]
# Load the comma-separated data. Column names are passed via ``names=``, so
# pandas treats the first line as data (no header row is expected).
income = pd.read_csv('./data/income.data', sep=',', names=col_names)


#sys.setrecursionlimit(1000) 尝试用sys解决溢出,无效


#处理数据
# Replace every string-valued column (including the target) with its integer
# category codes, so calc_entropy's np.bincount and the median splits can
# operate on numbers. pd.Categorical.from_array was deprecated and later
# removed from pandas; pd.Categorical(values) is the supported constructor and
# produces the same codes.
columns = ['workclass', 'education', 'marital_status', 'occupation',
           'relationship', 'race', 'sex', 'native_country', 'high_income']
for name in columns:
    income[name] = pd.Categorical(income[name]).codes

#Splitting data
# NOTE(review): 4 is presumably the category code assigned to the "Private"
# workclass by the encoding step above -- confirm. Neither subset is used
# later in the visible script.
private_incomes, public_incomes = (
    income[income['workclass'] == 4],
    income[income['workclass'] != 4],
)

#Calculating_entropy
def calc_entropy(column):
    """Return the Shannon entropy, in bits, of a column of class labels.

    Label frequencies are obtained with ``np.bincount``, so ``column`` must
    contain non-negative integers. Empty buckets are skipped, i.e. the
    ``0 * log(0)`` term is treated as zero.
    """
    total = len(column)
    weighted_logs = [
        p * math.log(p, 2)
        for p in np.bincount(column) / total
        if p > 0
    ]
    return -sum(weighted_logs)

#Calculating information_gain
def calc_information_gain(data, split_name, target_name):
    """Return the information gain of splitting ``data`` at the median of a column.

    Rows whose ``split_name`` value is <= the column median go to the left
    subset and the rest go to the right subset. The gain is the entropy of
    ``target_name`` over all of ``data`` minus the size-weighted entropy of
    the two subsets.

    Args:
        data: DataFrame holding both the split column and the target column.
        split_name: Name of the numeric column to split on.
        target_name: Name of the integer-coded label column.

    Returns:
        The information gain in bits. An empty ``data`` yields 0.0, because
        there is nothing to split and the weighting below would otherwise
        divide by zero.
    """
    n_rows = data.shape[0]
    if n_rows == 0:
        return 0.0

    original_entropy = calc_entropy(data[target_name])
    column = data[split_name]
    median = column.median()

    left_split = data[column <= median]
    right_split = data[column > median]

    to_subtract = 0
    for subset in (left_split, right_split):
        prob = subset.shape[0] / n_rows
        to_subtract += prob * calc_entropy(subset[target_name])
    return original_entropy - to_subtract

#Finding best split column
def find_best_column(data, columns, target_column):
    """Return the entry of ``columns`` whose median split gives the highest gain.

    Gain is measured by ``calc_information_gain`` on ``target_column``. When
    several columns tie, the one listed first wins. An empty ``columns``
    raises ``ValueError``.
    """
    def gain_of(col):
        return calc_information_gain(data, col, target_column)

    return max(columns, key=gain_of)


#Constructing DecisionTree-using id3 algorithm and storing it .
def id3(data, columns, target, tree):
    """Grow a median-split decision tree into ``tree`` (in place) and return it.

    Every node gets a sequential id under ``'number'``. The id comes from the
    module-level ``nodes`` list, which must exist before the first call.
    Internal nodes store the split ``'column'``, its ``'median'``, and
    ``'left'`` / ``'right'`` child dicts (left holds rows <= median). Leaves
    store ``'label'``.

    Recursion stops when the node is pure. It also stops when the median
    split of the chosen column leaves one side empty. Without that check the
    non-empty side equals ``data`` itself, the same split is picked again,
    and the recursion never terminates (RecursionError). Such a node becomes
    a leaf labelled with its most frequent target value.

    Args:
        data: DataFrame of training rows for this node.
        columns: Candidate feature column names.
        target: Name of the label column.
        tree: Dict to fill with this node's contents.

    Returns:
        ``tree``.
    """
    unique_targets = pd.unique(data[target])

    nodes.append(len(nodes) + 1)
    tree['number'] = nodes[-1]
    if len(unique_targets) == 1:
        tree['label'] = unique_targets[0]
        return tree

    best_column = find_best_column(data, columns, target)
    column_median = data[best_column].median()

    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]

    if left_split.empty or right_split.empty:
        # The split makes no progress; fall back to the majority label.
        tree['label'] = data[target].value_counts().idxmax()
        return tree

    tree['column'] = best_column
    tree['median'] = column_median

    for name, split in (("left", left_split), ("right", right_split)):
        tree[name] = {}
        id3(split, columns, target, tree[name])
    return tree


#Printing a more attractive tree
def print_with_depth(string, depth):
    """Print ``string`` on its own line, indented three spaces per ``depth`` level."""
    indent = " " * (3 * depth)
    line = "{indent}{text}".format(indent=indent, text=string)
    print(line)
def print_node(tree, depth):
    """Print ``tree`` as an indented outline starting at ``depth``; return None.

    A leaf prints as ``Leaf:Label <label>``. An internal node prints as
    ``<column>><median>``, followed by its left subtree and then its right
    subtree one level deeper.
    """
    if 'label' not in tree:
        print_with_depth("{0}>{1}".format(tree['column'], tree['median']), depth)
        for child in (tree['left'], tree['right']):
            print_node(child, depth + 1)
        return
    print_with_depth("Leaf:Label {0}".format(tree['label']), depth)

#Making predictions
def predict(tree, row):
    """Return the label ``tree`` assigns to ``row``.

    Starting at the root, the row moves to the ``'left'`` child when its
    value in the node's split column is <= the stored median, and to the
    ``'right'`` child otherwise. The walk ends at the first node that has a
    ``'label'``.
    """
    node = tree
    while 'label' not in node:
        split_col, split_val = node['column'], node['median']
        node = node['left'] if row[split_col] <= split_val else node['right']
    return node['label']
def batch_predict(tree, df):
    """Apply ``predict(tree, row)`` to every row of ``df``.

    The result comes from ``df.apply(..., axis=1)``, so it holds one
    prediction per row.
    """
    def _predict_row(row):
        return predict(tree, row)

    return df.apply(_predict_row, axis=1)

# Feature columns used for splitting; the target 'high_income' is excluded.
columns = ["age", "workclass", "education_num", "marital_status", "occupation",
           "relationship", "race", "sex", "hours_per_week", "native_country"]
tree = {}
nodes = []  # running list of node ids; id3 appends to this module-level list
# Only a small slice of the data is used: the original author hit a stack
# overflow (RecursionError) when training on the full data set.
train = income[:100]
test = income[100:110]
id3(train, columns, 'high_income', tree)

# 'actual' is the true label and 'pred' the tree's prediction. Both Series
# share test's index, so they line up row by row. The index is then
# renumbered from 0 whatever the size of the test slice.
actual_prediction = pd.DataFrame({
    'actual': test['high_income'],
    'pred': batch_predict(tree, test),
}).reset_index(drop=True)

print("======>>>Decision Tree:")
# print_node does its own printing and returns None, so it is not wrapped in
# print() (that emitted a stray "None" line).
print_node(tree, 0)
print("=====>>>>>预测:")
print(actual_prediction)

结果:


上一篇下一篇

猜你喜欢

热点阅读