Pytorch学习记录-卷积Seq2Seq(模型实现)
Pytorch学习记录-torchtext和Pytorch的实例5
0. PyTorch Seq2Seq项目介绍
在完成基本的torchtext之后,找到了这个教程,《基于Pytorch和torchtext来理解和实现seq2seq模型》。
这个项目主要包括了6个子项目
- 使用神经网络训练Seq2Seq
- 使用RNN encoder-decoder训练短语表示用于统计机器翻译
- 使用共同学习完成NMT的对齐和翻译
- 打包填充序列、掩码和推理
- 卷积Seq2Seq
- Transformer
5. 卷积Seq2Seq
5.1 准备数据
还是老一套,使用torchtext对英-德语料进行处理。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import time
# Seed Python's and PyTorch's random number generators and request
# deterministic cuDNN kernels so that repeated runs are reproducible.
SEED = 1234
torch.manual_seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
# spaCy language pipelines; only their tokenizers are used (by tokenize_de /
# tokenize_en). NOTE: the 'de'/'en' shortcut names are a spaCy 2.x feature.
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')


def tokenize_de(text):
    """Split a German sentence into a list of token strings."""
    return list(map(lambda token: token.text, spacy_de.tokenizer(text)))


def tokenize_en(text):
    """Split an English sentence into a list of token strings."""
    return list(map(lambda token: token.text, spacy_en.tokenizer(text)))
# Source (German) and target (English) fields differ only in their tokenizer;
# both lower-case text, wrap sentences in <sos>/<eos> and put batch first.
SRC = Field(
    tokenize=tokenize_de,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)
TRG = Field(
    tokenize=tokenize_en,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    batch_first=True,
)

# Multi30k German->English splits; `exts` pairs each file extension with the
# field at the same position in `fields`.
train_data, valid_data, test_data = Multi30k.splits(
    fields=(SRC, TRG), exts=('.de', '.en'))

# Vocabularies come from the training split only; tokens seen fewer than two
# times are left out of the vocabulary.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)
# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy, which would select CUDA even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One iterator per split; BucketIterator groups examples of similar length
# together (per the torchtext docs) and yields tensors on `device`.
BATCH_SIZE = 128
(train_iterator,
 valid_iterator,
 test_iterator) = BucketIterator.splits((train_data, valid_data, test_data),
                                        device=device,
                                        batch_size=BATCH_SIZE)
5.2 构建模型
这个是教程的原图,但是没有做什么解读,我又找了一篇解读的教程看。


该模型依旧是 encoder-decoder + attention 模块的大框架:encoder 和 decoder 采用了相同的卷积结构,其中的非线性部分采用的是门控结构 gated linear units(GLU);attention 部分采用的是多跳注意力机制 multi-hop attention,也即在 decoder 的每一个卷积层都会进行 attention 操作,并将结果输入到下一层。

- 卷积块结构
encoder 和 decoder 都是由 $l$ 层卷积层构成,encoder 第 $l$ 层的输出为 $\mathbf{z}^l=(z_1^l,\dots,z_m^l)$,decoder 第 $l$ 层的输出为 $\mathbf{h}^l=(h_1^l,\dots,h_n^l)$。由于卷积网络是层级结构,通过层级叠加能够得到远距离的两个词之间的关系信息。这里把一次“卷积计算 + 非线性计算”看作一个单元 Convolutional Block,这个单元在一个卷积层内是共享的。
卷积块中包括了卷积计算、非线性计算(GLU)、残差连接和输出。
- 多步注意力
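原理与传统的 attention 相似:attention 权重由 decoder 的当前输出和 encoder 的所有输出共同决定,利用该权重对 encoder 的输出进行加权,得到表示输入句子信息的向量 $c_i^l$;再将 $c_i^l$ 和 $h_i^l$ 相加,组成新的 $h_i^l$。具体地:$d_i^l = W_d^l h_i^l + b_d^l + g_i$,$a_{ij}^l = \frac{\exp(d_i^l \cdot z_j^u)}{\sum_{t=1}^{m}\exp(d_i^l \cdot z_t^u)}$,$c_i^l = \sum_{j=1}^{m} a_{ij}^l\,(z_j^u + e_j)$。其中 $g_i$ 为 decoder 第 $i$ 个位置的输入嵌入,$z^u$ 为 encoder 最后一层的输出,$e_j$ 为源端第 $j$ 个词的嵌入。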
在每一个卷积层都会进行 attention 操作,得到的结果输入到下一层卷积层,这就是多跳注意力机制 multi-hop attention。这样做的好处是,模型在计算下一个注意力分布时,能够考虑到之前已经注意过的词。
5.2.1 Encoder
class Encoder(nn.Module):
    """Convolutional encoder of the ConvS2S model.

    Token and position embeddings are summed, projected to ``hid_dim`` and
    passed through ``n_layers`` GLU-gated 1-D convolutions with scaled
    residual connections; the result is projected back to ``emb_dim``.

    Args:
        input_dim: source vocabulary size.
        emb_dim: embedding size.
        hid_dim: number of convolution channels.
        n_layers: number of convolutional blocks.
        kernel_size: convolution width; must be odd so that symmetric padding
            keeps the sequence length unchanged.
        dropout: dropout probability.
        device: device on which position indices and the scale are created.
        max_length: size of the position-embedding table, i.e. the longest
            supported source length. Defaults to 100 (the original value).
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, kernel_size,
                 dropout, device, max_length=100):
        super(Encoder, self).__init__()
        assert kernel_size % 2 == 1, "Kernel size must be odd!"
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.max_length = max_length
        self.device = device
        # Residual sums are multiplied by sqrt(0.5), as in the ConvS2S paper,
        # to keep the variance of a two-term sum roughly constant.
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
        # Word embeddings.
        self.tok_embedding = nn.Embedding(input_dim, emb_dim)
        # Learned position embeddings, one row per position < max_length.
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        # Each conv emits 2*hid_dim channels, which GLU halves back to hid_dim;
        # padding (k-1)//2 on both sides preserves the sequence length.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=hid_dim, out_channels=2 * hid_dim,
                      kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source token ids.

        Args:
            src: token ids of shape [batch size, src sent len].

        Returns:
            ``(conved, combined)``, both [batch size, src sent len, emb dim]:
            ``conved`` is the last conv block's output projected to
            ``emb_dim``; ``combined`` is ``(conved + embedded) * scale``.
        """
        batch_size, src_len = src.shape
        # Position indices 0..src_len-1, repeated for every sentence.
        pos = torch.arange(0, src_len, device=self.device).unsqueeze(0).repeat(batch_size, 1)
        # pos = [batch size, src sent len]
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(pos)
        # tok_embedded = pos_embedded = [batch size, src sent len, emb dim]
        embedded = self.dropout(tok_embedded + pos_embedded)
        # Project to hid_dim, then move channels to dim 1 as Conv1d expects.
        conv_input = self.emb2hid(embedded).permute(0, 2, 1)
        # conv_input = [batch size, hid dim, src sent len]
        for conv in self.convs:
            conved = F.glu(conv(self.dropout(conv_input)), dim=1)
            # conved = [batch size, hid dim, src sent len]
            # Scaled residual connection feeds the next block.
            conv_input = (conved + conv_input) * self.scale
        # Back to [batch size, src sent len, emb dim].
        conved = self.hid2emb(conv_input.permute(0, 2, 1))
        combined = (conved + embedded) * self.scale
        return conved, combined
5.2.2 Decoder
Decoder部分包括了attention结构。看一下代码会发现,除了增加了attn_hid2emb和attn_emb2hid,其余与Encoder类似。
class Decoder(nn.Module):
    """Convolutional decoder with multi-step attention (ConvS2S).

    Each conv block is left-padded with ``kernel_size - 1`` columns (filled
    with ``pad_idx``), so the output at target position t depends only on
    target tokens up to t. After every block the decoder attends over the
    encoder outputs and adds the attended vector back to the block output.

    Args:
        output_dim: target vocabulary size.
        emb_dim: embedding size.
        hid_dim: number of convolution channels.
        n_layers: number of convolutional blocks.
        kernel_size: convolution width.
        dropout: dropout probability.
        pad_idx: value written into the causal left padding.
        device: device on which position indices and the scale are created.
        max_length: size of the position-embedding table, i.e. the longest
            supported target length. Defaults to 100 (the original value).
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, kernel_size,
                 dropout, pad_idx, device, max_length=100):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.pad_idx = pad_idx
        self.max_length = max_length
        self.device = device
        self.scale = torch.sqrt(torch.FloatTensor([0.5])).to(device)
        self.tok_embedding = nn.Embedding(output_dim, emb_dim)
        self.pos_embedding = nn.Embedding(max_length, emb_dim)
        self.emb2hid = nn.Linear(emb_dim, hid_dim)
        self.hid2emb = nn.Linear(hid_dim, emb_dim)
        # Projections between the conv (hid_dim) and attention (emb_dim) spaces.
        self.attn_hid2emb = nn.Linear(hid_dim, emb_dim)
        self.attn_emb2hid = nn.Linear(emb_dim, hid_dim)
        self.out = nn.Linear(emb_dim, output_dim)
        # No built-in padding: forward() pads on the left only (causal conv).
        self.convs = nn.ModuleList([nn.Conv1d(hid_dim, 2 * hid_dim, kernel_size)
                                    for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def calculate_attention(self, embedded, conved, encoder_conved, encoder_combined):
        """One attention hop.

        Args:
            embedded: target embeddings g, [batch size, trg sent len, emb dim].
            conved: current block output h, [batch size, hid dim, trg sent len].
            encoder_conved: z^u, [batch size, src sent len, emb dim]; the keys.
            encoder_combined: (z^u + e) * scale from the encoder,
                [batch size, src sent len, emb dim]; the values.

        Returns:
            ``(attention, attended_combined)`` with shapes
            [batch size, trg sent len, src sent len] and
            [batch size, hid dim, trg sent len].
        """
        conved_emb = self.attn_hid2emb(conved.permute(0, 2, 1))
        # conved_emb = [batch size, trg sent len, emb dim]
        combined = (conved_emb + embedded) * self.scale
        # combined (d) = [batch size, trg sent len, emb dim]
        energy = torch.matmul(combined, encoder_conved.permute(0, 2, 1))
        # energy = [batch size, trg sent len, src sent len]
        attention = F.softmax(energy, dim=2)
        # encoder_combined already holds z^u + e (scaled), i.e. the paper's
        # value vectors; adding encoder_conved again would count z^u twice.
        attended_encoding = torch.matmul(attention, encoder_combined)
        # attended_encoding = [batch size, trg sent len, emb dim]
        # Project back to hid_dim so it can be added to the conv output.
        attended_encoding = self.attn_emb2hid(attended_encoding)
        # attended_encoding = [batch size, trg sent len, hid dim]
        attended_combined = (conved + attended_encoding.permute(0, 2, 1)) * self.scale
        return attention, attended_combined

    def forward(self, trg, encoder_conved, encoder_combined):
        """Decode all target positions in one pass.

        Args:
            trg: target token ids, [batch size, trg sent len].
            encoder_conved, encoder_combined: encoder outputs, each
                [batch size, src sent len, emb dim].

        Returns:
            ``(output, attention)``: logits [batch size, trg sent len,
            output dim] and the last block's attention
            [batch size, trg sent len, src sent len] (``None`` if n_layers=0).
        """
        batch_size, trg_len = trg.shape
        pos = torch.arange(0, trg_len, device=self.device).unsqueeze(0).repeat(batch_size, 1)
        # pos = [batch size, trg sent len]
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(pos)
        embedded = self.dropout(tok_embedded + pos_embedded)
        # embedded = [batch size, trg sent len, emb dim]
        conv_input = self.emb2hid(embedded).permute(0, 2, 1)
        # conv_input = [batch size, hid dim, trg sent len]
        attention = None
        for conv in self.convs:
            conv_input = self.dropout(conv_input)
            # Left-pad by kernel_size-1 so the decoder can't "cheat" by
            # looking at future target tokens.
            padding = conv_input.new_full(
                (batch_size, conv_input.shape[1], self.kernel_size - 1), self.pad_idx)
            padded_conv_input = torch.cat((padding, conv_input), dim=2)
            conved = F.glu(conv(padded_conv_input), dim=1)
            # conved = [batch size, hid dim, trg sent len]
            attention, conved = self.calculate_attention(
                embedded, conved, encoder_conved, encoder_combined)
            # Scaled residual connection feeds the next block.
            conv_input = (conved + conv_input) * self.scale
        conved = self.hid2emb(conv_input.permute(0, 2, 1))
        # conved = [batch size, trg sent len, emb dim]
        output = self.out(self.dropout(conved))
        return output, attention
5.2.3 模型整合Seq2Seq
class Seq2Seq(nn.Module):
    """Glue module: run the convolutional encoder, then the decoder.

    Args:
        encoder: module mapping ``src`` to ``(encoder_conved, encoder_combined)``.
        decoder: module mapping ``(trg, encoder_conved, encoder_combined)``
            to ``(output, attention)``.
        device: stored on the instance; not used inside ``forward``.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder, self.device = encoder, decoder, device

    def forward(self, src, trg):
        """Return the decoder's ``(output, attention)`` pair.

        src: [batch size, src sent len]; trg: [batch size, trg sent len].
        Per the decoder, output is [batch size, trg sent len, output dim] and
        attention is [batch size, trg sent len, src sent len].
        """
        # z_u: final encoder block output (z^u); z_plus_e: z^u combined with
        # the source token + position embeddings (e).
        z_u, z_plus_e = self.encoder(src)
        # Predictions for every target position plus attention over the source.
        logits, attn = self.decoder(trg, z_u, z_plus_e)
        return logits, attn