Bi-LSTM with Attention (PyTorch)
2021-05-16
Jarkata
In this post we use a Bi-LSTM plus an attention mechanism to implement a simple sentence-classification task.
First, import the required packages:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.utils.data as Data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Define the parameters for the Bi-LSTM (Attention) model:
batch_size = 3
embedding_dim = 2
n_hidden = 5 # number of hidden units in one cell
num_classes = 2 # 0 or 1
Prepare the data and construct the labels:
# 3-word sentences (sequence_length = 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good.
vocab = list(set(" ".join(sentences).split()))
word2idx = {w: i for i, w in enumerate(vocab)}
vocab_size = len(word2idx)
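As a quick sanity check (my addition, not in the original post), you can print the vocabulary: the six sentences contain 16 unique words, so vocab_size is 16 (the exact index assigned to each word depends on Python's set ordering and varies between runs):
print(vocab_size)  # 16
print(word2idx)    # e.g. {'love': 0, 'baseball': 1, ...} -- order varies between runs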
Build the dataset and data loader:
def make_data(sentences):
    inputs = []
    for sen in sentences:
        inputs.append(np.asarray([word2idx[n] for n in sen.split()]))
    targets = []
    for out in labels:
        targets.append(out)  # class indices, as expected by nn.CrossEntropyLoss
    return torch.LongTensor(inputs), torch.LongTensor(targets)
inputs, targets = make_data(sentences)
dataset = Data.TensorDataset(inputs, targets)
loader = Data.DataLoader(dataset, batch_size, True)
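To see what the DataLoader yields, here is a minimal inspection sketch (my addition, not in the original post): each mini-batch is a pair of LongTensors of shape [batch_size, seq_len] and [batch_size].
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([3, 3]) torch.Size([3])
    break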
Now define the model. Note the attention_net layer: it takes lstm_output and final_state as input and uses the similarity between each time step of lstm_output and final_state to produce a context vector together with the attention-weight vector. The context vector is then passed through a fully connected layer to obtain the final classification.
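In equation form (my notation, not from the original post): for a single example, let H ∈ R^{T×2h} be the matrix of Bi-LSTM outputs over the T time steps (h = n_hidden) and h_f ∈ R^{2h} the concatenation of the final forward and backward hidden states. Then

a = H h_f,    α = softmax(a),    c = Hᵀ α

where a, α ∈ R^T are the unnormalized and normalized attention weights and c ∈ R^{2h} is the context vector; the bmm calls in attention_net below compute exactly this, batched over the batch dimension.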
class BiLSTM_Attention(nn.Module):
    def __init__(self):
        super(BiLSTM_Attention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
        self.out = nn.Linear(n_hidden * 2, num_classes)

    def attention_net(self, lstm_output, final_state):
        # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix
        # final_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        # concatenate the final forward and backward hidden states
        hidden = torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)
        # hidden : [batch_size, n_hidden * num_directions(=2), n_layer(=1)]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)
        # attn_weights : [batch_size, n_step]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # context : [batch_size, n_hidden * num_directions(=2)]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights

    def forward(self, X):
        '''
        :param X: [batch_size, seq_len]
        '''
        input = self.embedding(X)      # input : [batch_size, seq_len, embedding_dim]
        input = input.transpose(0, 1)  # input : [seq_len, batch_size, embedding_dim]
        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        # output : [seq_len, batch_size, n_hidden * num_directions(=2)]
        output, (final_hidden_state, final_cell_state) = self.lstm(input)
        output = output.transpose(0, 1)  # output : [batch_size, seq_len, n_hidden * num_directions(=2)]
        attn_output, attention = self.attention_net(output, final_hidden_state)
        # attn_output : [batch_size, num_classes], attention : [batch_size, n_step]
        return self.out(attn_output), attention
model = BiLSTM_Attention().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
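As a quick shape check (my addition, not from the original post), you can push random tensors with the documented shapes through attention_net and confirm that the context vector comes out as [batch_size, n_hidden * 2] and the attention weights as [batch_size, seq_len]:
dummy_output = torch.randn(3, 3, n_hidden * 2).to(device)  # [batch_size, seq_len, n_hidden * 2]
dummy_hidden = torch.randn(2, 3, n_hidden).to(device)      # [num_directions, batch_size, n_hidden]
ctx, weights = model.attention_net(dummy_output, dummy_hidden)
print(ctx.shape, weights.shape)  # torch.Size([3, 10]) torch.Size([3, 3])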
Training:
for epoch in range(5000):
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred, attention = model(x)
        loss = criterion(pred, y)
        if (epoch + 1) % 1000 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
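After training, a quick sanity check (my addition, not in the original post) is to measure accuracy on the six training sentences; a model that has fit this toy dataset should classify all of them correctly:
with torch.no_grad():
    logits, _ = model(inputs.to(device))
    acc = (logits.argmax(1).cpu() == targets).float().mean()
    print('train accuracy:', acc.item())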
Testing:
test_text = 'i hate me'
tests = [np.asarray([word2idx[n] for n in test_text.split()])]
test_batch = torch.LongTensor(tests).to(device)
# Predict
predict, _ = model(test_batch)
predict = predict.data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text, "has a negative meaning...")
else:
    print(test_text, "has a positive meaning!!")
Visualize the attention matrix:
fig = plt.figure(figsize=(6, 3))
ax = fig.add_subplot(1, 1, 1)
# `attention` here still holds the weights from the last training mini-batch,
# so it has batch_size (=3) rows and seq_len (=3) columns
ax.matshow(attention.cpu().data, cmap='viridis')
ax.set_xticklabels([''] + ['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)
ax.set_yticklabels([''] + ['batch_1', 'batch_2', 'batch_3'], fontdict={'fontsize': 14})
plt.show()