Learning private models with mul
2017-10-25 本文已影响13人
yingtaomj
首先训练教师部分
由于试验目的只是为了把程序跑通,nb_teachers设为10,max_steps设为20,batch_size为128,teacher_id要手动设置循环跑,每次要删除tmp文件夹里的npy文件。
train_teacher.py里只有一个主函数train_teacher
- 调用
input.partition_dataset
函数得到(6000,28,28,1)格式的data和(6000,)的label。 - 再调用
deep_cnn.train
对数据进行训练。 - 最后调用
deep_cnn.softmax_preds
获得预测结果,再用metrics.accuracy
来检验。
在deep_cnn.train
中,with tf.Graph().as_default():
- 先定义placeholder,(128,28,28,1)格式的
train_data_node
,(128,)的train_labels_node
。 - 定义
inference
函数得到不同数字的预测概率;定义loss
的计算过程;定义train
训练模型。 - 给placeholder填充image和data,
sess.run
,每隔20步打印loss并保存结果。
其中,inference
定义了CNN模型。依次包括:
-
conv1
,输入(128,28,28,1)的image,输出(128,28,28,64)的conv1 -
pool
,输入conv1,输出(128,14,14,64)的pool1,利用tf.nn.lrn标准化 - 依次类推,(128,14,14,128)的conv2->标准化->(128,7,7,128)的pool2->(128,384)的local3->(128,192)的local4->softmax logits
学生模型(和教师模型基本是一样的)
有两个分函数:
ensemble_preds
利用教师模型对给定数据的预测结果:格式为(10 ,1000,10)(nb_teacher,len(data),nb_label)
的result。每一个result[id_teacher]都是调用deep_cnn.softmax_preds
得到的。
prepare_student_data
打印教师聚合模型的准确率,准备stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels
四部分数据。(stdnt_labels是聚合模型预测的结果)
主函数train_student
调用prepare_student_data准备四部分数据;
将stdnt_data, stdnt_labels用于训练学生模型;
最后调用deep_cnn.softmax_preds
获得预测结果,再用metrics.accuracy
来检验学生模型的结果。