wide&deep试验

2018-12-09  本文已影响0人  rwj_pku

背景 :通过人口调查数据来判断收入情况, 分类问题

主要流程:

1. 下载数据

2. 定义训练集input_fn

    input_fn作用:通过管道将dataset传输到Estimator中。

    TextLineDataset解析文本文件生成dataset

    Dataset 支持shuffle

    Dataset 执行解析csv的map

        map函数: tf.decode_csv, 返回features,classes

        根据epochs数量repeat

        Dataset.batch  把dataset变成batches

3. 定义评估集input_fn

    input_fn复用训练集的input_fn

4. 定义模型

    创建estimator_fn

        4.1. 获得feature_columns: wide_column和deep_column。

           使用numeric_column, categorical_column_with_vocabulary_list, categorical_column_with_hash_bucket, bucketized_column, cross_column, embedding_column.等

        4.2 hidden_units 

        4.3 设置训练参数,保证模型是在cpu上训练的, 因为比在GPU上快

                tf.estimator.RunConfig

        4.4 定义模型

                如果只用wide模型:tf.estimator.LinearClassifier

                如果只用deep模型:tf.estimator.DNNClassifier

                wide&deep模型:tf.estimator.DNNLinearCombinedClassifier

 5. 训练和评估

        训练训练集input_fn   model.train()

        评估测试集input_fn   model.evaluate()

  6. 导出模型

        model.export_savedmodel

        

    

上一篇下一篇

猜你喜欢

热点阅读