通过鸢尾花数据集理解深度神经网络(DNN)建模、训练、验证和预测
2017-08-22 本文已影响101人
周且南_laygin
理论总是很乏味,学习要结合实例才能掌握扎实。
说明,本文参考TensorFlow官网例子,结合学习所得。英文原文参考 这里
1、获取数据集
可以通过网上搜索得到,不过为了方便,我已经传到 github 上,可直接下载或复制粘贴以供使用。训练集、测试集。
2、数据集说明
2.1 数据如下
花萼长度 | 花萼宽度 | 花瓣长度 | 花瓣宽度 | 种类 |
---|---|---|---|---|
6.5 | 3.6 | 5.6 | 0.9 | 0 |
... | ... | ... | ... | ... |
2.4 | 5.7 | 1.6 | 3.6 | 2 |
2.2 任务说明
根据 花萼长度、花萼宽度、花瓣长度和花瓣宽度 来预测种类,其中种类为整型,取值空间为0,1,2
3、训练模型
训练集包括120个样本,测试集包括30个样本。
3.1 读取数据
以下假设数据集路径与python脚本在同一个目录下。
#先导入所需要的模块
import tensorflow as tf
import numpy as np
iris_training_set = 'iris_training.csv'
iris_test_set = 'iris_test.csv'
#读取数据
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=iris_training_set,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=iris_test_set,
target_dtype=np.int,
features_dtype=np.float32)
target_dtype
即为预测目标的数据类型,这里是整型 ;同样特征的类型设置为浮点型。
3.2 创建DNN分类器
#指定所有特征都是实数型
feature_columns = [tf.contrib.layers.real_valued_column("",dimension=4)]
#创建3层DNN,分别包括10,20,10个神经元
classifier = tf.contrib.learn.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10,20,10],
n_classes=3,
model_dir = 'iris_model')
dimension
:维度为4,因为有四个特征。
n_classes
:预测的目标类别,这里是3类。
hidden_units
:3个隐藏层,神经元数目分别为10,20,10.
3.3 开始训练
def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target)
return x,y
#训练模型
classifier.fit(input_fn=get_train_inputs,steps=2000)
#以上两步可以写成
# classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
# classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
4、验证模型
def get_test_inputs():
x = tf.constant(test_set.data)
y = tf.constant(test_set.target)
return x,y
#验证准确性
evalu = classifier.evaluate(
input_fn=get_test_inputs,
steps=1)
accuracy = evalu['accuracy']
print('accuracy:{}'.format(accuracy))
最后结果为:0.9666666388511658
5、预测
#预测新样本的分类
predictions = list(classifier.predict(
np.array([[5.6,7.6,3.5,1.4],[7.8,9.7,2.4,2.5],[7.6,4.6,8.7,1.3]])))
print('新分类{}'.format(predictions))
结果为:新分类[0, 0, 2]