TensorFlow seq2seq模型实战
中法翻译模型
教程: https://www.tensorflow.org/versions/r0.12/tutorials/seq2seq/
目标: 训练一个端到端的英语到法语的翻译模型
下面分如下几个部分讲述:
- 准备代码
- 准备数据
- 运行
- 代码分析
准备代码
代码地址: https://github.com/tensorflow/models/tree/master/tutorials/rnn/translate
有3个文件:
data_utils.py
seq2seq_model.py
translate.py
其中,translate.py是主脚本,运行python translate.py -h 可查看参数。seq2seq_model.py是seq2seq+attention的翻译模型实现,data_utils.py是处理数据的脚本,包括下载、解压、分词、构建词表、文档id化等预处理流程。
运行python translate.py --self_test可以测试代码是否可以正常运行。如果不报错就是ok的。
但是由于TensorFlow框架目前更新很快,代码很有可能不能运行。我的机器安装的是tensorflow-1.1.0,测试上面的代码会报错。查看github issue区的一些讨论,定位到是cell的定义有问题,解决方法时修改seq2seq_model.py中的关于cell的定义:
def single_cell():
return tf.contrib.rnn.GRUCell(size, reuse=tf.get_variable_scope().reuse)
if use_lstm:
def single_cell():
return tf.contrib.rnn.BasicLSTMCell(size, reuse=tf.get_variable_scope().reuse)
def mycell():
return single_cell();
if num_layers > 1:
def mycell():
return tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layers)])
# The seq2seq function: we use embedding for the input and attention.
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
encoder_inputs,
decoder_inputs,
mycell(),
num_encoder_symbols=source_vocab_size,
num_decoder_symbols=target_vocab_size,
embedding_size=size,
output_projection=output_projection,
feed_previous=do_decode,
dtype=dtype)
修改后,再测试就是ok的了。
准备数据
数据有两份,一份是训练数据,一份是验证数据。数据都是平行语料,即英语到法语的句子对。
训练数据: training-giga-fren.tar 下载地址
验证数据: dev-v2.tgz 下载地址
在脚本data_utils.py中调用相关函数会自动下载上述数据,但是如果训练的机器没有网络或者下载很慢,可以单独下载之后再训练。下载后解压的数据情况如下:
60M dev-v2.tar
3.6G giga-fren.release2.fixed.en
4.3G giga-fren.release2.fixed.fr
2.5G training-giga-fren.tar
文件行数为2千万行,每个文件都是一行一个句子:
22520376 giga-fren.release2.fixed.en
22520376 giga-fren.release2.fixed.fr
dev-v2.tar解压后有许多文件,只需要其中的两个文件,每个3000行:
newstest2013.en
newstest2013.fr
运行
代码和数据准备好之后,接下来就是运行了。
按照官网教程,第一步是预处理加训练:
python translate.py --data_dir ./data --train_dir ./train_data --en_vocab_size=40000 --fr_vocab_size=40000
其中, ./data目录存放的是下载的训练语料和验证语料,./train_data存在训练的模型,英文和法文词表大小都设置为40000,这样其他的词会用UNK代替,这样设定也是为了节省训练时间。
预处理较长,需要慢慢等待。所有文件处理后+原文件有18G,如下:
21M dev-v2.tgz
3.6G giga-fren.release2.fixed.en
1.2G giga-fren.release2.fixed.en.gz
2.3G giga-fren.release2.fixed.en.ids40000
4.3G giga-fren.release2.fixed.fr
1.3G giga-fren.release2.fixed.fr.gz
2.7G giga-fren.release2.fixed.fr.ids40000
328K newstest2013.en
228K newstest2013.en.ids40000
388K newstest2013.fr
264K newstest2013.fr.ids40000
2.5G training-giga-fren.tar
336K vocab40000.from
372K vocab40000.to
vocab40000.from 是英文词表,vocab40000.to是法语词表。
giga-fren.release2.fixed.en.ids40000 和 iga-fren.release2.fixed.fr.ids40000分别是id化的文件。
训练的时候读取到2900000行进程就被killed了,可能是占用内存太大了。于是取部分数据进行训练,放在目录test中:
cd test
head -n 2500000 ../data/giga-fren.release2.fixed.en > en.txt
head -n 2500000 ../data/giga-fren.release2.fixed.fr > fr.txt
cp ../data/newstest2013.en ./
cp ../data/newstest2013.fr .
cd ..
重新训练:
python translate.py \
--data_dir ./test \
--train_dir ./train_data \
--en_vocab_size=40000 \
--fr_vocab_size=40000 \
--from_train_data ./test/en.txt \
--to_train_data ./test/fr.txt \
--from_dev_data ./test/newstest2013.en \
--to_dev_data ./test/newstest2013.fr
如果是新的数据,则需要指定train_data和dev_data的路径,会调用data_utils.py中的prepare_data函数构建词表,然后id化。
训练过程中进程还是总被killed,于是降低模型参数,用更少数据训练:
python translate.py \
--data_dir ./test \
--train_dir ./model \
--from_vocab_size=40000 \
--to_vocab_size=40000 \
--from_train_data ./test/en.txt \
--to_train_data ./test/fr.txt \
--from_dev_data ./test/newstest2013.en \
--to_dev_data ./test/newstest2013.fr \
--size=256 \
--num_layers=2 \
--max_train_data_size=1000000
翻译case如下:
Reading model parameters from ./model/translate.ckpt-6200
> States than the number of people killed by lightning.
Les répondants ont été de plus de plus de plus de plus de plus de
> One thing is certain: these new provisions will have a negative impact on voter turn-out.
Il est pas de plus de renseignements sur les questions de ces pays .
> These restrictions are not without consequence.
Les répondants sont pas de plus de plus .
> Who is the president of the United States?
Le _UNK est le _UNK
代码分析
官网教程中说tf有seq2seq的库,包含各种seq2seq模型。如基本模型basic_rnn_seq2seq,输入就是词语。而embedding模型,输入变为词的embedding表示。而例子中用的是加入attention机制的embedding模型,如下
def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
encoder_inputs,
decoder_inputs,
mycell(),
num_encoder_symbols=source_vocab_size,
num_decoder_symbols=target_vocab_size,
embedding_size=size,
output_projection=output_projection,
feed_previous=do_decode,
dtype=dtype)
encoder_inputs 对应源句子中的词
decoder_inputs 对应目标句子中的词
cell 为RNNCell实例,如GRUCell,LSTMCell
num_encoder_symbols encoder输入词表大小
num_decoder_symbols decoder输入词表大小
output_projection 不设定的话输出维数可能很大(取决于词表大小),设定的话投影到一个低维向量(by sampled softmax loss)
feed_previous True表示使用上一时刻decoder的输出,Faslse表示使用输入(仅训练时有效)
返回值为(outputs, states)
outputs 对应decoder的输出,与decoder_inputs长度相同
states 对应decoder每个输出的状态
需要注意的点:
- Sampled softmax and output projection: 输出要投影到低维,不然输出会很大
- Bucketing and padding: 设定两个语言长度对齐的方法,比如一种语言长度为5的句子一般翻译为另一种语言长度为10的句子
翻译模型基础
补充一些基础知识材料,有时间再深入看看
- 基础的seq2seq模型
Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (EMNLP 2014)
有两个RNN结构,一个编码器,一个解码器,编码器输入为源语言,解码器输入为目标语言和编码器的输出,并且解码器输出翻译结果。
在基本模型中,所有输入都编码为一个固定长度的状态向量,传递给解码器。
- 注意力模型
Neural Machine Translation by Jointly Learning to Align and Translate
引入注意力机制后,解码器每一步解码的时候都会瞥一眼输入(peek into the input)
中英翻译模型
找到一份联合国的中英平行语料,正好跑跑中英翻译的模型,毕竟法语看不懂,也不知道翻译的如何。
语料准备
联合国语料:https://conferences.unite.un.org/UNCorpus/
点击Download,填下表就可以下载,有两个文件:
1.0G UNv1.0.en-zh.tar.gz.00
299M UNv1.0.en-zh.tar.gz.01
看官网介绍,有91,028个文件,一共15886041行。两个文件cat到一份文件,然后解压,解压后有如下文件:
4.0K DISCLAIMER
4.0K README
2.3G UNv1.0.en-zh.en
558M UNv1.0.en-zh.ids
1.8G UNv1.0.en-zh.zh
144K UNv1.0.pdf
UNv1.0.en-zh.en和UNv1.0.en-zh.zh就是我们想要的平行语料了。对中文分词,然后就可以复用之前的代码跑模型了。
模型训练和预测
由于数据太大,选择了100万行训练数据,3000行验证数据,如下:
3000 dev.en
3000 dev.zh
1000000 part.en
1000000 part.zh
中文为分词结果,词语之间用空格分隔。所有文件放在目录zh_data下,词表大小设定为50000,训练模型放在./model2下,运行脚本如下:
python translate.py \
--data_dir ./zh_data \
--train_dir ./model2 \
--from_vocab_size=50000 \
--to_vocab_size=50000 \
--from_train_data ./zh_data/part.zh \
--to_train_data ./zh_data/part.en \
--from_dev_data ./zh_data/dev.zh \
--to_dev_data ./zh_data/dev.en \
--size=256 \
--num_layers=2 \
--steps_per_checkpoint=100 \
--max_train_data_size=0
训练后数据目录zh_data变为:
444K dev.en
256K dev.en.ids50000
408K dev.zh
252K dev.zh.ids50000
147M part.en
88M part.en.ids50000
132M part.zh
83M part.zh.ids50000
444K vocab50000.from
460K vocab50000.to
现在进行测试,测试一定要将参数设置为和训练的一样,不然加载会报错,测试命令:
python translate.py --decode \
--data_dir ./zh_data \
--train_dir ./model2 \
--from_vocab_size=50000 \
--to_vocab_size=50000 \
--size=256 \
--num_layers=2
迭代3400次终止:
global step 3400 learning rate 0.4950 step-time 0.18 perplexity 36.18
eval: bucket 0 perplexity 9.97
eval: bucket 1 perplexity 24.35
eval: bucket 2 perplexity 50.14
eval: bucket 3 perplexity 74.99
翻译的结果如下:
> 他 在 干 什么
_UNK
> 通过 2015年 第 一 届 常会 报告
Report of the report of the Conference of the General Assembly
> 审计 咨询 委员 会 的 报告
Report of the Committee of the Committee
> 世卫组织 的 指导 面向 卫生保健 工作者 和 设施
_UNK _UNK _UNK _UNK _UNK _UNK
> 编制 一 份 指导 文件 草案
A decision of the draft resolution
> 第六届 会议 的 时间 和 形式
_UNK and the Conference of the Conference of the Conference of the Republic of the
> 我 鼓励 诸位 在 可能 的 情况 下 参加 这些 活动
I to be be to be to be be in the work of the Republic of the Republic of the Republic of the Republic of
> 中国
China
> 美国
United Republic of the Republic of the United Republic of
> 英国
Total
> 联合国
United Nations
> 葡萄牙
Travel
> 意大利
Australia
> 马来西亚
Malaysia
> 法国
Germany
> 德国
Germany
> 问题
General Assembly
> 组织
United Nations
翻译有很多不准的地方,但是还是有一些效果的,毕竟训练的数据不是太通用,主要是联合国官方文件,而且训练的词表也设定的很小,训练数据也只用了很小一部分,另外训练数据也是需要清洗的。