python3疯狂之路

通过鸢尾花数据集理解深度神经网络(DNN)建模、训练、验证和预测

2017-08-22  本文已影响101人  周且南_laygin

理论总是很乏味,学习要结合实例才能掌握扎实。

说明,本文参考TensorFlow官网例子,结合学习所得。英文原文参考 这里

1、获取数据集

可以通过网上搜索得到,不过为了方便,我已经传到 github 上,可直接下载或复制粘贴以供使用。训练集测试集

2、数据集说明

2.1 数据如下

花萼长度 花萼宽度 花瓣长度 花瓣宽度 种类
6.5 3.6 5.6 0.9 0
... ... ... ... ...
2.4 5.7 1.6 3.6 2

2.2 任务说明

根据 花萼长度、花萼宽度、花瓣长度和花瓣宽度 来预测种类,其中种类为整型,取值空间为0,1,2

3、训练模型

训练集包括120个样本,测试集包括30个样本。

3.1 读取数据

以下假设数据集路径与python脚本在同一个目录下。

#先导入所需要的模块
import tensorflow as tf
import numpy as np

iris_training_set = 'iris_training.csv'
iris_test_set = 'iris_test.csv'

#读取数据
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=iris_training_set,
target_dtype=np.int,
features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=iris_test_set,
target_dtype=np.int,
features_dtype=np.float32)

target_dtype即为预测目标的数据类型,这里是整型 ;同样特征的类型设置为浮点型。

3.2 创建DNN分类器

#指定所有特征都是实数型
feature_columns = [tf.contrib.layers.real_valued_column("",dimension=4)]
#创建3层DNN,分别包括10,20,10个神经元
classifier = tf.contrib.learn.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10,20,10],
n_classes=3,
model_dir = 'iris_model')

dimension:维度为4,因为有四个特征。
n_classes:预测的目标类别,这里是3类。
hidden_units:3个隐藏层,神经元数目分别为10,20,10.

3.3 开始训练

def get_train_inputs():
    x =  tf.constant(training_set.data)
    y =  tf.constant(training_set.target)
    
    return x,y
#训练模型
classifier.fit(input_fn=get_train_inputs,steps=2000)

#以上两步可以写成
# classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
# classifier.fit(x=training_set.data, y=training_set.target, steps=1000)

4、验证模型

def get_test_inputs():
    x  = tf.constant(test_set.data)
    y = tf.constant(test_set.target)
    
    return x,y

#验证准确性
evalu = classifier.evaluate(
input_fn=get_test_inputs,
steps=1)
accuracy = evalu['accuracy']

print('accuracy:{}'.format(accuracy))

最后结果为:0.9666666388511658

5、预测

#预测新样本的分类
predictions = list(classifier.predict(
    np.array([[5.6,7.6,3.5,1.4],[7.8,9.7,2.4,2.5],[7.6,4.6,8.7,1.3]])))

print('新分类{}'.format(predictions))

结果为:新分类[0, 0, 2]

上一篇 下一篇

猜你喜欢

热点阅读