PyTorch Learning Notes - Seq2Seq Packed Padded Sequences, Masking and Inference

2019-04-26  我的昵称违规了

PyTorch Learning Notes - torchtext and PyTorch Example 4

0. Introduction to the PyTorch Seq2Seq Project

After working through the basics of torchtext, I found this tutorial series, "Understanding and implementing seq2seq models with PyTorch and torchtext".
The project consists of six sub-projects:

  1. Sequence to Sequence Learning with Neural Networks
  2. Learning Phrase Representations using an RNN Encoder-Decoder for Statistical Machine Translation
  3. Neural Machine Translation by Jointly Learning to Align and Translate
  4. Packed Padded Sequences, Masking and Inference
  5. Convolutional Seq2Seq
  6. Transformer

4. Packed Padded Sequences, Masking and Inference

This tutorial builds on the previous model by adding packed padded sequences and masking, as sketched below.
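
To make this concrete, here is a minimal sketch (toy tensors and a made-up pad index, not the tutorial's own code) of what packing and masking look like in PyTorch: pack_padded_sequence lets the encoder RNN skip the pad positions, and a boolean mask built from the pad index lets attention ignore them.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# toy batch: 3 source sentences of lengths 5, 3 and 2, padded to length 5
emb = torch.randn(5, 3, 8)                        # [src len, batch size, emb dim]
src_len = torch.tensor([5, 3, 2])                 # true lengths, sorted descending

rnn = nn.GRU(8, 16, bidirectional=True)
packed = pack_padded_sequence(emb, src_len)       # the RNN will skip pad positions
packed_outputs, hidden = rnn(packed)
outputs, _ = pad_packed_sequence(packed_outputs)  # back to [src len, batch size, 2*16],
                                                  # with zeros at the pad positions

# masking: build a boolean mask from the (hypothetical) pad index so attention
# can ignore pad tokens, e.g. scores.masked_fill(mask == 0, -1e10)
PAD = 1                                           # stand-in for SRC.vocab.stoi['<pad>']
src = torch.tensor([[ 2,   4,   9],
                    [ 5,   7,   3],
                    [ 6,   8,   PAD],
                    [10,  PAD,  PAD],
                    [11,  PAD,  PAD]])            # [src len, batch size]
mask = (src != PAD).permute(1, 0)                 # [batch size, src len]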

The tutorial also covers inference: given a sentence, we can look at the translation the model produces and find out exactly which words the attention mechanism attends to.

4.1 Imports and Data Preprocessing

4.2 Building the Model

4.3 Training the Model

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
PAD_IDX = SRC.vocab.stoi['<pad>']
SOS_IDX = TRG.vocab.stoi['<sos>']
EOS_IDX = TRG.vocab.stoi['<eos>']

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)
model = Seq2Seq(enc, dec, PAD_IDX, SOS_IDX, EOS_IDX, device).to(device)
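
Note that the pad, <sos> and <eos> indices are passed into the Seq2Seq wrapper because this version of the model needs them beyond computing the loss: the pad index is used to build the attention mask over the source, while the start- and end-of-sentence indices let the model start and stop decoding on its own when translating without a target at inference time.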

def init_weights(m):
    # initialise all weights from N(0, 0.01) and all biases to zero
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)
Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(7855, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
    )
    (embedding): Embedding(5893, 256)
    (rnn): GRU(1280, 512)
    (out): Linear(in_features=1792, out_features=5893, bias=True)
    (dropout): Dropout(p=0.5)
  )
)
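
(model.apply returns the model itself, which is why the architecture is printed above.) The layer sizes follow directly from the hyperparameters: the bidirectional encoder GRU produces 2 × 512 = 1024-dimensional encoder outputs, so the attention layer's input is 1024 + 512 = 1536, the decoder GRU's input is 1024 + 256 = 1280 (the weighted encoder context concatenated with the embedded token), and the final linear layer's input is 1024 + 512 + 256 = 1792.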
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 20,518,917 trainable parameters
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
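
Setting ignore_index=PAD_IDX makes the cross-entropy loss skip padded target positions, so padding contributes nothing to the loss or the gradients. A tiny illustration with made-up numbers (1 stands in for the pad index here):

loss_fn = nn.CrossEntropyLoss(ignore_index=1)
logits = torch.randn(4, 10)              # 4 target positions, vocabulary of size 10
targets = torch.tensor([3, 7, 1, 1])     # the last two positions are padding
print(loss_fn(logits, targets))          # averaged over the two real tokens only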
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        # batch.src is a (tensor, lengths) pair; the lengths are needed for packing
        src, src_len = batch.src
        trg = batch.trg
        optimizer.zero_grad()

        output, attention = model(src, src_len, trg)

        # drop the first time step (the <sos> position) and flatten for the loss
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)

        loss = criterion(output, trg)
        loss.backward()
        # clip the gradient norm to stabilise training
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src, src_len = batch.src
            trg = batch.trg

            output, attention = model(src, src_len, trg, 0) #turn off teacher forcing

            #trg = [trg sent len, batch size]
            #output = [trg sent len, batch size, output dim]

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)

            #trg = [(trg sent len - 1) * batch size]
            #output = [(trg sent len - 1) * batch size, output dim]

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
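
Compared with train, evaluate switches the model to eval mode (disabling dropout), wraps the loop in torch.no_grad(), and passes a teacher forcing ratio of 0, so the decoder always feeds back its own previous prediction, just as it would at inference time.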
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
N_EPOCHS = 5
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut4-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
Epoch: 01 | Time: 25m 14s
    Train Loss: 3.859 | Train PPL:  47.433
     Val. Loss: 3.552 |  Val. PPL:  34.896
Epoch: 02 | Time: 25m 24s
    Train Loss: 2.656 | Train PPL:  14.246
     Val. Loss: 3.220 |  Val. PPL:  25.036
Epoch: 03 | Time: 25m 24s
    Train Loss: 2.177 | Train PPL:   8.819
     Val. Loss: 3.158 |  Val. PPL:  23.516
Epoch: 04 | Time: 25m 25s
    Train Loss: 1.854 | Train PPL:   6.383
     Val. Loss: 3.301 |  Val. PPL:  27.134
Epoch: 05 | Time: 25m 29s
    Train Loss: 1.636 | Train PPL:   5.137
     Val. Loss: 3.362 |  Val. PPL:  28.838
model.load_state_dict(torch.load('tut4-model.pt'))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
| Test Loss: 3.206 | Test PPL:  24.676 |

Epoch: 01 | Time: 0m 46s
Train Loss: 5.062 | Train PPL: 157.870
Val. Loss: 4.775 | Val. PPL: 118.542
Epoch: 02 | Time: 0m 48s
Train Loss: 4.159 | Train PPL: 64.034
Val. Loss: 4.282 | Val. PPL: 72.351
Epoch: 03 | Time: 0m 47s
Train Loss: 3.450 | Train PPL: 31.493
Val. Loss: 3.649 | Val. PPL: 38.430
Epoch: 04 | Time: 0m 48s
Train Loss: 2.913 | Train PPL: 18.418
Val. Loss: 3.405 | Val. PPL: 30.110
Epoch: 05 | Time: 0m 48s
Train Loss: 2.511 | Train PPL: 12.312
Val. Loss: 3.275 | Val. PPL: 26.450
Epoch: 06 | Time: 0m 48s
Train Loss: 2.218 | Train PPL: 9.184
Val. Loss: 3.264 | Val. PPL: 26.166
Epoch: 07 | Time: 0m 48s
Train Loss: 1.975 | Train PPL: 7.204
Val. Loss: 3.174 | Val. PPL: 23.892
Epoch: 08 | Time: 0m 48s
Train Loss: 1.766 | Train PPL: 5.848
Val. Loss: 3.266 | Val. PPL: 26.204
Epoch: 09 | Time: 0m 48s
Train Loss: 1.616 | Train PPL: 5.035
Val. Loss: 3.282 | Val. PPL: 26.619
Epoch: 10 | Time: 0m 48s
Train Loss: 1.508 | Train PPL: 4.518
Val. Loss: 3.271 | Val. PPL: 26.341
