自然语言处理N天-实现Transformer加载数据方法
2019-03-05 本文已影响4人
我的昵称违规了
新建 Microsoft PowerPoint 演示文稿 (2).jpg
这个算是在课程学习之外的探索,不过希望能尽快用到项目实践中。在文章里会引用较多的博客,文末会进行reference。
搜索Transformer机制,会发现高分结果基本上都源于一篇论文Jay Alammar的《The Illustrated Transformer》(图解Transformer),提到最多的Attention是Google的《Attention Is All You Need》。
对于Transformer的运行机制了解即可,所以会基于这篇论文来学习Transformer,结合《Sklearn+Tensorflow》中Attention注意力机制一章完成基本的概念学习;- 找一个基于Transformer的项目练手
5.代码实现
构建data_load
import tensorflow as tf
from utils import calc_num_batches
def load_vocab(vocab_fpath):
'''
加载词文件,返回一个idx<->token的图
:param vocab_fpath: 字符串,词文件的地址 0: <pad>, 1: <unk>, 2: <s>, 3: </s>
:return: 两个字典
'''
vocab=[line.split() for line in open(vocab_fpath,'r',encoding='utf-8').read().splitlines()]
token2idx={token:idx for idx, token in enumerate(vocab)}
idx2token={idx:token for idx, token in enumerate(vocab)}
return token2idx, idx2token
def load_data(fpath1,fpath2,maxlen1,maxlen2):
'''
加载源语和目标语数据,筛选出最长的样例,用于生成掩码
:param fpath1: 源语地址
:param fpath2: 目标语地址
:param maxlen1: 源语句子中最长的长度
:param maxlen2: 目标语句子中最长的长度
:return:
'''
sents1, sents2 = [], []
with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
for sent1, sent2 in zip(f1, f2):
if len(sent1.split()) + 1 > maxlen1: continue # 1: </s>
if len(sent2.split()) + 1 > maxlen2: continue # 1: </s>
sents1.append(sent1.strip())
sents2.append(sent2.strip())
return sents1, sents2
def encode(inp, type, dict):
'''
将字符串转为数字,用于generator_fn。
:param inp: 一维
:param type: x表示源语,y表示目标语
:param dict: token2idx字典
:return: 数字列表
'''
inp_str=inp.decode("utf-8")
if type=='x': tokens=inp_str.split()+["</s>"]
else: tokens=["<s>"]+inp_str.split()+["</s>"]
x=[dict.get(t,dict["<unk>"])for t in tokens]
return x
def generator_fn(sents1, sents2, vocab_fpath):
'''
生成训练和评价数据
:param sents1: 源语句子列表
:param sents2: 目标句子列表
:param vocab_fpath: 字符串,词文件地址
'''
token2idx, _ = load_vocab(vocab_fpath)
for sent1, sent2 in zip(sents1, sents2):
x = encode(sent1, "x", token2idx)
y = encode(sent2, "y", token2idx)
decoder_input, y = y[:-1], y[1:]
x_seqlen, y_seqlen = len(x), len(y)
yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
'''
批量化数据
:param sents1: 源语句子列表
:param sents2: 目标句子列表
:param vocab_fpath: 字符串,词文件地址
batch_size: scalar
shuffle: boolean
Returns
xs: tuple of
x: int32 tensor. (N, T1)
x_seqlens: int32 tensor. (N,)
sents1: str tensor. (N,)
ys: tuple of
decoder_input: int32 tensor. (N, T2)
y: int32 tensor. (N, T2)
y_seqlen: int32 tensor. (N, )
sents2: str tensor. (N,)
'''
shapes = (([None], (), ()),
([None], [None], (), ()))
types = ((tf.int32, tf.int32, tf.string),
(tf.int32, tf.int32, tf.int32, tf.string))
paddings = ((0, 0, ''),
(0, 0, 0, ''))
dataset = tf.data.Dataset.from_generator(
generator_fn,
output_shapes=shapes,
output_types=types,
args=(sents1, sents2, vocab_fpath)) # <- arguments for generator_fn. converted to np string arrays
if shuffle: # for training
dataset = dataset.shuffle(128*batch_size)
dataset = dataset.repeat() # iterate forever
dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)
return dataset
def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
'''
获取训练和评价小型数据
fpath1: source file path. string.
fpath2: target file path. string.
maxlen1: source sent maximum length. scalar.
maxlen2: target sent maximum length. scalar.
vocab_fpath: string. vocabulary file path.
batch_size: scalar
shuffle: boolean
Returns
batches
num_batches: number of mini-batches
num_samples
'''
sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
num_batches = calc_num_batches(len(sents1), batch_size)
return batches, num_batches, len(sents1)