Feature Selection with sklearn
2018-01-20
月见樽
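Theory
Why feature selection matters
Some features in a dataset are "good" features: using them noticeably improves the model's ability to generalize. Others do little to separate the classes, so including them in training wastes compute. Still others actively work against classification, and including them lowers generalization performance.
Feature selection
Unlike PCA (principal component analysis), feature selection does not modify feature values. It looks for the smallest set of original features that contributes most to model performance.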
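The contrast with PCA can be made concrete with a small sketch. The data below is synthetic and the choice of `SelectKBest` with `f_classif` is an assumption made only for illustration; it is not the selector used later in this post.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif

# Synthetic data (an assumption for illustration): 100 samples, 5 features,
# where only column 0 carries information about the label.
rng = np.random.RandomState(0)
X = rng.normal(size=(100, 5))
y = (X[:, 0] > 0).astype(int)

# Feature selection keeps a subset of the original columns, values unchanged.
selector = SelectKBest(f_classif, k=2)
X_selected = selector.fit_transform(X, y)
kept = selector.get_support(indices=True)

# PCA builds new features as linear combinations of all the columns.
X_pca = PCA(n_components=2).fit_transform(X)

print("kept columns:", kept)
print("selected values are original values:",
      np.array_equal(X_selected, X[:, kept]))
```

Both results have two columns, but only the selected matrix can be read back as "column i of the original data".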
Implementation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
Loading the dataset: the Titanic dataset
titan = pd.read_csv("http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt")
titan.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1313 entries, 0 to 1312
Data columns (total 11 columns):
row.names 1313 non-null int64
pclass 1313 non-null object
survived 1313 non-null int64
name 1313 non-null object
age 633 non-null float64
embarked 821 non-null object
home.dest 754 non-null object
room 77 non-null object
ticket 69 non-null object
boat 347 non-null object
sex 1313 non-null object
dtypes: float64(1), int64(2), object(8)
memory usage: 112.9+ KB
Data preprocessing
Separating features and labels
x_source = titan.drop(["row.names","name","survived"],axis=1)
x_source.info()
y_source = titan["survived"]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1313 entries, 0 to 1312
Data columns (total 8 columns):
pclass 1313 non-null object
age 633 non-null float64
embarked 821 non-null object
home.dest 754 non-null object
room 77 non-null object
ticket 69 non-null object
boat 347 non-null object
sex 1313 non-null object
dtypes: float64(1), object(7)
memory usage: 82.1+ KB
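Filling missing values
# Fill missing ages with the mean age. Assigning the result back avoids the
# chained inplace call, which newer pandas versions warn about.
x_source['age'] = x_source['age'].fillna(x_source['age'].mean())
# Fill the remaining (string) columns with a placeholder category
x_source.fillna('UNKNOWN',inplace=True)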
x_source.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1313 entries, 0 to 1312
Data columns (total 8 columns):
pclass 1313 non-null object
age 1313 non-null float64
embarked 1313 non-null object
home.dest 1313 non-null object
room 1313 non-null object
ticket 1313 non-null object
boat 1313 non-null object
sex 1313 non-null object
dtypes: float64(1), object(7)
memory usage: 82.1+ KB
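To see what the two fill steps do, here is a minimal sketch on a toy frame. The values are hypothetical and only the column names are borrowed from the Titanic data.

```python
import numpy as np
import pandas as pd

# Toy frame (hypothetical values) with one numeric and one string column.
df = pd.DataFrame({
    "age": [20.0, np.nan, 40.0],
    "embarked": ["Southampton", None, "Cherbourg"],
})

# Numeric column: replace NaN with the column mean, (20 + 40) / 2 = 30.
df["age"] = df["age"].fillna(df["age"].mean())

# Remaining string column: replace missing entries with a placeholder category.
df = df.fillna("UNKNOWN")

print(df)
```

The placeholder turns "missing" into an ordinary category, which the vectorizer can later encode as its own indicator column.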
Splitting the data
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x_source,y_source,random_state=33,test_size=0.25)
x_train.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 984 entries, 1086 to 1044
Data columns (total 8 columns):
pclass 984 non-null object
age 984 non-null float64
embarked 984 non-null object
home.dest 984 non-null object
room 984 non-null object
ticket 984 non-null object
boat 984 non-null object
sex 984 non-null object
dtypes: float64(1), object(7)
memory usage: 69.2+ KB
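The 984 training rows follow from how `train_test_split` rounds a fractional `test_size`. A small sketch with placeholder arrays (the array contents are an assumption; only the row count matches the Titanic frame):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder arrays with the same number of rows as the Titanic frame (1313).
X = np.arange(1313).reshape(-1, 1)
y = np.zeros(1313, dtype=int)

X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, random_state=33, test_size=0.25
)

# The test size is rounded up: ceil(0.25 * 1313) = 329,
# which leaves 1313 - 329 = 984 rows for training.
print(X_tr.shape[0], X_te.shape[0])
```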
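Feature vectorization
from sklearn.feature_extraction import DictVectorizer
vec = DictVectorizer()
# orient='records' is the full option name; recent pandas versions reject the abbreviation 'record'
x_train = vec.fit_transform(x_train.to_dict(orient='records'))
x_test = vec.transform(x_test.to_dict(orient='records'))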
len(vec.feature_names_)
474
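The jump from 8 columns to 474 features comes from how `DictVectorizer` treats string values. A minimal sketch on two hypothetical records (the variable name `vec_demo` and the record values are made up for illustration):

```python
from sklearn.feature_extraction import DictVectorizer

# Hypothetical records, shaped like the output of
# DataFrame.to_dict(orient="records").
records = [
    {"age": 29.0, "sex": "female", "pclass": "1st"},
    {"age": 2.0, "sex": "male", "pclass": "3rd"},
]

# sparse=False returns a dense array, which is easier to print.
vec_demo = DictVectorizer(sparse=False)
encoded = vec_demo.fit_transform(records)

# Each distinct string value becomes its own "name=value" indicator column;
# numeric values keep a single column holding the raw number.
print(vec_demo.feature_names_)
print(encoded)
```

Columns with many distinct strings, such as `home.dest` or `ticket`, therefore contribute one feature per distinct value, which is why the feature count grows so quickly.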
Model training
from sklearn.tree import DecisionTreeClassifier
Baseline decision tree model
dt = DecisionTreeClassifier(criterion='entropy')
dt.fit(x_train,y_train)
dt.score(x_test,y_test)
0.82066869300911849
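Decision tree with feature selection
from sklearn import feature_selection
# Keep the top 7% of features, ranked by the chi-squared statistic against the label
fs = feature_selection.SelectPercentile(feature_selection.chi2,percentile=7)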
x_train_fs = fs.fit_transform(x_train,y_train)
x_test_fs = fs.transform(x_test)
print(x_train.shape,x_train_fs.shape)
(984, 474) (984, 33)
dt.fit(x_train_fs,y_train)
dt.score(x_test_fs,y_test)
0.85410334346504557
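The value `percentile=7` is used above without showing how it was chosen. One common way to pick it is cross-validation on the training data. The sketch below is not the author's procedure: it runs on synthetic data, and the candidate grid, the 5-fold setting and the fixed `random_state` are all assumptions made for illustration.

```python
import numpy as np
from sklearn.feature_selection import SelectPercentile, chi2
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# Synthetic non-negative data (chi2 requires non-negative features):
# 200 samples, 50 binary features, label tied to the first feature.
rng = np.random.RandomState(33)
X = rng.randint(0, 2, size=(200, 50))
noise = (rng.rand(200) < 0.1).astype(int)  # flip roughly 10% of the labels
y = X[:, 0] ^ noise

percentiles = list(range(10, 101, 10))
mean_scores = []
for p in percentiles:
    # Putting the selector inside the pipeline refits it on each training
    # fold, so the held-out fold never influences which features are kept.
    model = make_pipeline(
        SelectPercentile(chi2, percentile=p),
        DecisionTreeClassifier(criterion="entropy", random_state=0),
    )
    mean_scores.append(cross_val_score(model, X, y, cv=5).mean())

best = percentiles[int(np.argmax(mean_scores))]
print("best percentile:", best)
```

With the Titanic data, the same loop would run over the vectorized `x_train` and `y_train`, keeping `x_test` untouched until the final score.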