Machine Learning & Recommendation & NLP & DL

自然语言处理N天-Day1103从0搭建一个RNN神经网络作诗(

2019-02-19  本文已影响2人  我的昵称违规了

说明:本文依据Github上面的一个2000星项目完成。项目作者jinfagang项目地址,在这里感谢那些开源的程序员,让我们学到更多。
我会尽量将项目进行拆解,希望对大家的学习有所帮助吧。

第十一课 使用RNN生成古诗

上一节数据预处理和模型构建

4.模型的训练

获取数据 batch 的代码位于poem.py的generate_batch方法,作用是用来获取每一个batch的数值。作为接下来模型训练的输入数据集。
传入参数有batch_size:batch的大小;poems_vec:前面生成的诗文中字ID;word_to_int:前面生成的每一个字ID。

def generate_batch(batch_size, poems_vec, word_to_int):
    # 每次取batch_size首诗进行训练
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    #使用for循环,生成n_chunk个batch。
    for i in range(n_chunk):
        #每一个batch开始和结束的index
        start_index = i * batch_size
        end_index = start_index + batch_size
        batches = poems_vec[start_index:end_index]
        
        # 找到这个batch中所有poem最长的poem的长度,以这个长度为最大值生成batch中每一行的长度。
        length = max(map(len, batches))
        # 填充一个空batch,空的地方放空格对应的index标号
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)

        for row, batch in enumerate(batches):
            # 每一行就是一首诗,在原本的长度上把诗还原上去
            x_data[row, :len(batch)] = batch
        y_data = np.copy(x_data)
        # y就是x向左边移动一个,最后一位使用倒数第二位的数值填充
        y_data[:, :-1] = x_data[:, 1:]
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches

模型的训练代码位于train.py的run_training方法

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import poems
import models

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs')

FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    # 读取诗集文件
    # 依次得到数字ID表示的诗句、汉字-ID的映射map、所有的汉字的列表
    poems_vector, word_to_int, vocabularies = poems.process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = poems.generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # 通过rnn模型得到结果状态集
    end_points = models.rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                                  vocab_size=len(vocabularies), rnn_size=128, num_layers=2, batch_size=64,
                                  learning_rate=FLAGS.learning_rate)

    # 初始化saver和session
    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print('## restore from the checkpointt {0}'.format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## strat training...')

        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epoches):
                n = 0
                for batch in range(n_chunk):
                    # 训练并计算loss
                    # batches_inputs[n]: 第n个batch的输入数据
                    # batches_outputs[n]: 第n个batch的输出数据
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']],
                        feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]}
                    )
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                    # 每训练6个epoch进行一次模型保存
                    if epoch % 6 == 0:
                        saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

上一篇下一篇

猜你喜欢

热点阅读