Learning private models with mul

2017-10-25  本文已影响13人  yingtaomj

代码地址

首先训练教师部分

由于试验目的只是为了把程序跑通,nb_teachers设为10,max_steps设为20,batch_size为128,teacher_id要手动设置循环跑,每次要删除tmp文件夹里的npy文件。

train_teacher.py里只有一个主函数train_teacher

  1. 调用input.partition_dataset函数得到(6000,28,28,1)格式的data和(6000,)的label。
  2. 再调用deep_cnn.train对数据进行训练。
  3. 最后调用deep_cnn.softmax_preds获得预测结果,再用metrics.accuracy来检验。

deep_cnn.train中,with tf.Graph().as_default():

  1. 先定义placeholder,(128,28,28,1)格式的train_data_node,(128,)的train_labels_node
  2. 定义inference函数得到不同数字的预测概率;定义loss的计算过程;定义train训练模型。
  3. 给placeholder填充image和data,sess.run,每隔20步打印loss并保存结果。

其中,inference定义了CNN模型。依次包括:

  1. conv1,输入(128,28,28,1)的image,输出(128,28,28,64)的conv1
  2. pool,输入conv1,输出(128,14,14,64)的pool1,利用tf.nn.lrn标准化
  3. 依次类推,(128,14,14,128)的conv2->标准化->(128,7,7,128)的pool2->(128,384)的local3->(128,192)的local4->softmax logits

学生模型(和教师模型基本是一样的)

有两个分函数:

ensemble_preds

利用教师模型对给定数据的预测结果:格式为(10 ,1000,10)(nb_teacher,len(data),nb_label)的result。每一个result[id_teacher]都是调用deep_cnn.softmax_preds得到的。

prepare_student_data

打印教师聚合模型的准确率,准备stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels四部分数据。(stdnt_labels是聚合模型预测的结果)

主函数train_student

调用prepare_student_data准备四部分数据;
将stdnt_data, stdnt_labels用于训练学生模型;
最后调用deep_cnn.softmax_preds获得预测结果,再用metrics.accuracy来检验学生模型的结果。

上一篇下一篇

猜你喜欢

热点阅读