LSTMTensorFlow笔记我爱编程

TensorFlow seq2seq模型实战

2017-05-11  本文已影响3434人  风驰电掣一瓜牛

中法翻译模型

教程: https://www.tensorflow.org/versions/r0.12/tutorials/seq2seq/

目标: 训练一个端到端的英语到法语的翻译模型

下面分如下几个部分讲述:

  1. 准备代码
  2. 准备数据
  3. 运行
  4. 代码分析

准备代码

代码地址: https://github.com/tensorflow/models/tree/master/tutorials/rnn/translate

有3个文件:

data_utils.py
seq2seq_model.py
translate.py

其中,translate.py是主脚本,运行python translate.py -h 可查看参数。seq2seq_model.py是seq2seq+attention的翻译模型实现,data_utils.py是处理数据的脚本,包括下载、解压、分词、构建词表、文档id化等预处理流程。

运行python translate.py --self_test可以测试代码是否可以正常运行。如果不报错就是ok的。

但是由于TensorFlow框架目前更新很快,代码很有可能不能运行。我的机器安装的是tensorflow-1.1.0,测试上面的代码会报错。查看github issue区的一些讨论,定位到是cell的定义有问题,解决方法时修改seq2seq_model.py中的关于cell的定义:

def single_cell():
  return tf.contrib.rnn.GRUCell(size, reuse=tf.get_variable_scope().reuse)
if use_lstm:
  def single_cell():
    return tf.contrib.rnn.BasicLSTMCell(size, reuse=tf.get_variable_scope().reuse)
def mycell():
    return single_cell();
if num_layers > 1:
    def mycell():
        return tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)])
# The seq2seq function: we use embedding for the input and attention.
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
  return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
      encoder_inputs,
      decoder_inputs,
      mycell(),
      num_encoder_symbols=source_vocab_size,
      num_decoder_symbols=target_vocab_size,
      embedding_size=size,
      output_projection=output_projection,
      feed_previous=do_decode,
      dtype=dtype)

修改后,再测试就是ok的了。

准备数据

数据有两份,一份是训练数据,一份是验证数据。数据都是平行语料,即英语到法语的句子对。

训练数据: training-giga-fren.tar 下载地址

验证数据: dev-v2.tgz 下载地址

在脚本data_utils.py中调用相关函数会自动下载上述数据,但是如果训练的机器没有网络或者下载很慢,可以单独下载之后再训练。下载后解压的数据情况如下:

60M     dev-v2.tar
3.6G    giga-fren.release2.fixed.en
4.3G    giga-fren.release2.fixed.fr
2.5G    training-giga-fren.tar

文件行数为2千万行,每个文件都是一行一个句子:

22520376 giga-fren.release2.fixed.en
22520376 giga-fren.release2.fixed.fr

dev-v2.tar解压后有许多文件,只需要其中的两个文件,每个3000行:

newstest2013.en
newstest2013.fr

运行

代码和数据准备好之后,接下来就是运行了。

按照官网教程,第一步是预处理加训练:

python translate.py --data_dir ./data --train_dir ./train_data --en_vocab_size=40000 --fr_vocab_size=40000

其中, ./data目录存放的是下载的训练语料和验证语料,./train_data存在训练的模型,英文和法文词表大小都设置为40000,这样其他的词会用UNK代替,这样设定也是为了节省训练时间。

预处理较长,需要慢慢等待。所有文件处理后+原文件有18G,如下:

21M     dev-v2.tgz
3.6G    giga-fren.release2.fixed.en
1.2G    giga-fren.release2.fixed.en.gz
2.3G    giga-fren.release2.fixed.en.ids40000
4.3G    giga-fren.release2.fixed.fr
1.3G    giga-fren.release2.fixed.fr.gz
2.7G    giga-fren.release2.fixed.fr.ids40000
328K    newstest2013.en
228K    newstest2013.en.ids40000
388K    newstest2013.fr
264K    newstest2013.fr.ids40000
2.5G    training-giga-fren.tar
336K    vocab40000.from
372K    vocab40000.to

vocab40000.from 是英文词表,vocab40000.to是法语词表。

giga-fren.release2.fixed.en.ids40000 和 iga-fren.release2.fixed.fr.ids40000分别是id化的文件。

训练的时候读取到2900000行进程就被killed了,可能是占用内存太大了。于是取部分数据进行训练,放在目录test中:

cd test
head -n 2500000 ../data/giga-fren.release2.fixed.en > en.txt
head -n 2500000 ../data/giga-fren.release2.fixed.fr > fr.txt
cp ../data/newstest2013.en ./
cp ../data/newstest2013.fr .
cd ..

重新训练:

python translate.py \
--data_dir ./test \
--train_dir ./train_data \
--en_vocab_size=40000 \
--fr_vocab_size=40000 \
--from_train_data ./test/en.txt \
--to_train_data ./test/fr.txt \
--from_dev_data ./test/newstest2013.en \
--to_dev_data ./test/newstest2013.fr

如果是新的数据,则需要指定train_data和dev_data的路径,会调用data_utils.py中的prepare_data函数构建词表,然后id化。

训练过程中进程还是总被killed,于是降低模型参数,用更少数据训练:

python translate.py \
    --data_dir ./test \
    --train_dir ./model \
    --from_vocab_size=40000 \
    --to_vocab_size=40000 \
    --from_train_data ./test/en.txt \
    --to_train_data ./test/fr.txt \
    --from_dev_data ./test/newstest2013.en \
    --to_dev_data ./test/newstest2013.fr \
    --size=256 \
    --num_layers=2 \
    --max_train_data_size=1000000

翻译case如下:

Reading model parameters from ./model/translate.ckpt-6200
> States than the number of people killed by lightning.
Les répondants ont été de plus de plus de plus de plus de plus de
> One thing is certain: these new provisions will have a negative impact on voter turn-out.
Il est pas de plus de renseignements sur les questions de ces pays .
> These restrictions are not without consequence.
Les répondants sont pas de plus de plus .
> Who is the president of the United States?
Le _UNK est le _UNK

代码分析

官网教程中说tf有seq2seq的库,包含各种seq2seq模型。如基本模型basic_rnn_seq2seq,输入就是词语。而embedding模型,输入变为词的embedding表示。而例子中用的是加入attention机制的embedding模型,如下

def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
  return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
      encoder_inputs,
      decoder_inputs,
      mycell(),
      num_encoder_symbols=source_vocab_size,
      num_decoder_symbols=target_vocab_size,
      embedding_size=size,
      output_projection=output_projection,
      feed_previous=do_decode,
      dtype=dtype)

encoder_inputs  对应源句子中的词
decoder_inputs  对应目标句子中的词
cell 为RNNCell实例,如GRUCell,LSTMCell
num_encoder_symbols encoder输入词表大小
num_decoder_symbols decoder输入词表大小
output_projection 不设定的话输出维数可能很大(取决于词表大小),设定的话投影到一个低维向量(by sampled softmax loss)
feed_previous   True表示使用上一时刻decoder的输出,Faslse表示使用输入(仅训练时有效)

返回值为(outputs, states)
outputs 对应decoder的输出,与decoder_inputs长度相同
states  对应decoder每个输出的状态

需要注意的点:

翻译模型基础

补充一些基础知识材料,有时间再深入看看

  1. 基础的seq2seq模型

Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (EMNLP 2014)

有两个RNN结构,一个编码器,一个解码器,编码器输入为源语言,解码器输入为目标语言和编码器的输出,并且解码器输出翻译结果。

在基本模型中,所有输入都编码为一个固定长度的状态向量,传递给解码器。

  1. 注意力模型

Neural Machine Translation by Jointly Learning to Align and Translate

引入注意力机制后,解码器每一步解码的时候都会瞥一眼输入(peek into the input)

中英翻译模型

找到一份联合国的中英平行语料,正好跑跑中英翻译的模型,毕竟法语看不懂,也不知道翻译的如何。

语料准备

联合国语料:https://conferences.unite.un.org/UNCorpus/

点击Download,填下表就可以下载,有两个文件:

1.0G    UNv1.0.en-zh.tar.gz.00
299M    UNv1.0.en-zh.tar.gz.01

看官网介绍,有91,028个文件,一共15886041行。两个文件cat到一份文件,然后解压,解压后有如下文件:

4.0K    DISCLAIMER
4.0K    README
2.3G    UNv1.0.en-zh.en
558M    UNv1.0.en-zh.ids
1.8G    UNv1.0.en-zh.zh
144K    UNv1.0.pdf

UNv1.0.en-zh.en和UNv1.0.en-zh.zh就是我们想要的平行语料了。对中文分词,然后就可以复用之前的代码跑模型了。

模型训练和预测

由于数据太大,选择了100万行训练数据,3000行验证数据,如下:

     3000 dev.en
     3000 dev.zh
  1000000 part.en
  1000000 part.zh

中文为分词结果,词语之间用空格分隔。所有文件放在目录zh_data下,词表大小设定为50000,训练模型放在./model2下,运行脚本如下:

python translate.py \
    --data_dir ./zh_data \
    --train_dir ./model2 \
    --from_vocab_size=50000 \
    --to_vocab_size=50000 \
    --from_train_data ./zh_data/part.zh \
    --to_train_data ./zh_data/part.en \
    --from_dev_data ./zh_data/dev.zh \
    --to_dev_data ./zh_data/dev.en \
    --size=256 \
    --num_layers=2 \
    --steps_per_checkpoint=100 \
    --max_train_data_size=0

训练后数据目录zh_data变为:

444K    dev.en
256K    dev.en.ids50000
408K    dev.zh
252K    dev.zh.ids50000
147M    part.en
88M     part.en.ids50000
132M    part.zh
83M     part.zh.ids50000
444K    vocab50000.from
460K    vocab50000.to

现在进行测试,测试一定要将参数设置为和训练的一样,不然加载会报错,测试命令:

python translate.py --decode \
    --data_dir ./zh_data \
    --train_dir ./model2 \
    --from_vocab_size=50000 \
    --to_vocab_size=50000 \
    --size=256 \
    --num_layers=2

迭代3400次终止:

global step 3400 learning rate 0.4950 step-time 0.18 perplexity 36.18
  eval: bucket 0 perplexity 9.97
  eval: bucket 1 perplexity 24.35
  eval: bucket 2 perplexity 50.14
  eval: bucket 3 perplexity 74.99

翻译的结果如下:

> 他 在 干 什么
_UNK
> 通过 2015年 第 一 届 常会 报告
Report of the report of the Conference of the General Assembly
> 审计 咨询 委员 会 的 报告
Report of the Committee of the Committee
> 世卫组织 的 指导 面向 卫生保健 工作者 和 设施
_UNK _UNK _UNK _UNK _UNK _UNK
> 编制 一 份 指导 文件 草案
A decision of the draft resolution
> 第六届 会议 的 时间 和 形式
_UNK and the Conference of the Conference of the Conference of the Republic of the
> 我 鼓励 诸位 在 可能 的 情况 下 参加 这些 活动
I to be be to be to be be in the work of the Republic of the Republic of the Republic of the Republic of
> 中国
China
> 美国
United Republic of the Republic of the United Republic of
> 英国
Total
> 联合国
United Nations
> 葡萄牙
Travel
> 意大利
Australia
> 马来西亚
Malaysia
> 法国
Germany
> 德国
Germany
> 问题
General Assembly
> 组织
United Nations

翻译有很多不准的地方,但是还是有一些效果的,毕竟训练的数据不是太通用,主要是联合国官方文件,而且训练的词表也设定的很小,训练数据也只用了很小一部分,另外训练数据也是需要清洗的。

其他参考

上一篇下一篇

猜你喜欢

热点阅读