wide&deep试验
背景 :通过人口调查数据来判断收入情况, 分类问题
主要流程:
1. 下载数据
2. 定义训练集input_fn
input_fn作用:通过管道将dataset传输到Estimator中。
TextLineDataset解析文本文件生成dataset
Dataset 支持shuffle
Dataset 执行解析csv的map
map函数: tf.decode_csv, 返回features,classes
根据epochs数量repeat
Dataset.batch 把dataset变成batches
3. 定义评估集input_fn
input_fn复用训练集的input_fn
4. 定义模型
创建estimator_fn
4.1. 获得feature_columns: wide_column和deep_column。
使用numeric_column, categorical_column_with_vocabulary_list, categorical_column_with_hash_bucket, bucketized_column, cross_column, embedding_column.等
4.2 hidden_units
4.3 设置训练参数,保证模型是在cpu上训练的, 因为比在GPU上快
tf.estimator.RunConfig
4.4 定义模型
如果只用wide模型:tf.estimator.LinearClassifier
如果只用deep模型:tf.estimator.DNNClassifier
wide&deep模型:tf.estimator.DNNLinearCombinedClassifier
5. 训练和评估
训练训练集input_fn model.train()
评估测试集input_fn model.evaluate()
6. 导出模型
model.export_savedmodel