tensorflow训练流程

2020-06-01  本文已影响0人  青吟乐

tensorflow的训练流程大致分为以下几块

1设置某些初始化信息(也可以用到再给值),具体看下面代码注释

model_root_name = 'model5_200601'  # 模型路径名称
dataset = r"I:\trainset"  # 训练集
label= r"I:\label"  # 标签数据
PATCH_SIZE = (64, 64)  # 块大小
bitch_number = 64  # 每个图像分多少块
ori_lr = 1e-1 # 学习率
MAX_EPOCH = 500

2设置训练的各个部分属性

if __name__ == '__main__':
    train_list = get_train_list(load_file_list(dataset), load_file_list(label))#将data数据和label数据顺次连接

    with tf.name_scope('input_scope'):
        train_input = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
        train_gt = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))

    train_output = model(train_input)
    train_output = tf.clip_by_value(train_output, 0., 1.)

    with tf.name_scope('loss_scope'), tf.device("/gpu:0"):
        loss = tf.reduce_sum(tf.square(tf.subtract(train_output, train_gt)))#loss函数设计成l2
        weights = tf.get_collection(tf.GraphKeys.WEIGHTS)#权重矩阵
        avg_loss = tf.placeholder('float32')#
        tf.summary.scalar("avg_loss", avg_loss)#显示标量信息

    global_step = tf.Variable(0, trainable=False)
    learning_rate = ori_lr
    #优化器设置
    optimizer_adam = tf.train.AdamOptimizer(learning_rate, 0.5)
    opt_adam = optimizer_adam.minimize(loss, global_step=global_step)
    #保存文件使用
    saver = tf.train.Saver(max_to_keep=0)

    # 配置计算方式,打印操作,并行线程,使用最优线程数,自动选择cpugpu
    config = tf.ConfigProto(allow_soft_placement=True)
    # 当使用GPU时候,Tensorflow运行自动慢慢达到最大GPU的内存
    config.gpu_options.allow_growth = True

3进行训练

    #设置一个session运行计算图
    with tf.Session(config=config) as sess:
        merged = tf.summary.merge_all()#自动管理
        #初始计算图变量
        sess.run(tf.global_variables_initializer())
        last_epoch=0

        for epoch in range(last_epoch,last_epoch+MAX_EPOCH):
            #打乱训练集的顺序
            shuffle(train_list)
            total_g_loss, n_iter = 0, 0
            #计算每一轮开始的时间
            epoch_time = time.time()
            total_get_data_time, total_network_time = 0, 0
            for idx in range(1000):
                input_data, gt_data = prepare_nn_data(train_list[:3000])#获取成对的训练集
                feed_dict = {train_input: input_data, train_gt: gt_data}#设置feed_fict
                _, l, output, g_step = sess.run([opt_adam, loss, train_output, global_step], feed_dict=feed_dict)#初始化session的参数
                total_g_loss += l
                n_iter += 1
                del input_data, gt_data, output
            #设置lr和summary
            lr, summary = sess.run([learning_rate, merged], {avg_loss: total_g_loss / n_iter})
            #保存模型
            if ((epoch + 1) % 10 == 0):
                saver.save(sess, os.path.join("%s_%03d_%.ckpt" % (model_root_name,epoch)))
上一篇 下一篇

猜你喜欢

热点阅读