Pytorch学习记录-更深的TorchText学习01
Pytorch学习记录-更深的TorchText学习01
简单实现torchtext之后,我希望能够进一步学习torchtext。找到两个教程
1. practical-torchtext简介
有效使用torchtext的教程,包括两个部分
- 文本分类
- 词级别的语言模型
1.1 目标
torchtext的文档仍然相对不完整,目前有效地使用torchtext需要阅读相当多的代码。这组教程旨在提供使用torchtext的工作示例,以使更多用户能够充分利用这个风扇库。
1.2 使用
torchtext的当前pip版本存在一些错误,这些错误会导致某些代码运行不正确。这些错误目前只在torchtext的github存储库的主分支上修复。因此,教程建议使用以下命令从github存储库安装torchtext:
pip install --upgrade git+https://github.com/pytorch/text
2. 基于torchtext处理的文本分析
我看了一下,第一课是基于这两天反复操作的那个教程,但是作者进行了丰富和解释,就放在这里再跑一次了
2.0 简介
前面都是一致的,加载数据,预处理之后生成dataset,输入模型。
使用的数据集还是之前的Kaggle垃圾信息数据。
import pandas as pd
import numpy as np
import torch
from torch.nn import init
from torchtext.data import Field
2.1 声明Fields
Field类用于确定数据是如何预处理并转化为数字格式。
很简单。标签的预处理更加容易,因为它们已经转换为二进制编码。我们需要做的就是告诉Field类标签已经处理完毕。我们通过将use_vocab = False关键字传递给构造函数来完成此操作
tokenize=lambda x: x.split()
TEXT=Field(sequential=True, tokenize=tokenize, lower=True)
LABEL=Field(sequential=False, use_vocab=False)
2.2 创建Dataset
我们将使用TabularDataset类来读取我们的数据,因为它是csv格式(截至目前,TabularDataset处理csv,tsv和json文件)
对于列车和验证数据,我们需要处理标签。我们传入的字段必须与列的顺序相同。对于我们不使用的字段,我们传入一个元组,其中第二个元素是None
%%time
from torchtext.data import TabularDataset
tv_datafields=[
('id',None),
('comment_text',TEXT),
("toxic", LABEL),
("severe_toxic", LABEL),
("threat", LABEL),
("obscene", LABEL),
("insult", LABEL),
("identity_hate", LABEL)
]
trn,vld=TabularDataset.splits(
path=r'C:\Users\jwc19\Desktop\2001_2018jszyfz\code\data\torchtextdata',
train='train.csv',
validation='valid.csv',
format='csv',
skip_header=True,
fields=tv_datafields
)
Wall time: 4.99 ms
%%time
tst_datafields=[
('id',None),
('comment_text',TEXT)
]
tst=TabularDataset(
path=r'C:\Users\jwc19\Desktop\2001_2018jszyfz\code\data\torchtextdata\test.csv',
format='csv',
skip_header=True,
fields=tst_datafields
)
Wall time: 3.01 ms
2.3 构建字典
对于TEXT字段将单词转换为整数,需要告诉整个词汇是什么。为此,我们运行TEXT.build_vocab,传入数据集以构建词汇表。
%%time
TEXT.build_vocab(trn)
TEXT.vocab.freqs.most_common(10)
print(TEXT.vocab.freqs.most_common(10))
[('the', 78), ('to', 41), ('you', 33), ('of', 30), ('and', 26), ('a', 26), ('is', 24), ('that', 22), ('i', 20), ('if', 19)]
Wall time: 3.99 ms
在这里,dataset中每一个元素都是一个Example对象,包含有若干独立数据
# 查看trn这个field的标签
print(trn[0].__dict__.keys())
# 查看某一行中的文本,在结果中可以看到,已有的文本是已经被分好词的
print(trn[10].comment_text)
dict_keys(['comment_text', 'toxic', 'severe_toxic', 'threat', 'obscene', 'insult', 'identity_hate'])
['"', 'fair', 'use', 'rationale', 'for', 'image:wonju.jpg', 'thanks', 'for', 'uploading', 'image:wonju.jpg.', 'i', 'notice', 'the', 'image', 'page', 'specifies', 'that', 'the', 'image', 'is', 'being', 'used', 'under', 'fair', 'use', 'but', 'there', 'is', 'no', 'explanation', 'or', 'rationale', 'as', 'to', 'why', 'its', 'use', 'in', 'wikipedia', 'articles', 'constitutes', 'fair', 'use.', 'in', 'addition', 'to', 'the', 'boilerplate', 'fair', 'use', 'template,', 'you', 'must', 'also', 'write', 'out', 'on', 'the', 'image', 'description', 'page', 'a', 'specific', 'explanation', 'or', 'rationale', 'for', 'why', 'using', 'this', 'image', 'in', 'each', 'article', 'is', 'consistent', 'with', 'fair', 'use.', 'please', 'go', 'to', 'the', 'image', 'description', 'page', 'and', 'edit', 'it', 'to', 'include', 'a', 'fair', 'use', 'rationale.', 'if', 'you', 'have', 'uploaded', 'other', 'fair', 'use', 'media,', 'consider', 'checking', 'that', 'you', 'have', 'specified', 'the', 'fair', 'use', 'rationale', 'on', 'those', 'pages', 'too.', 'you', 'can', 'find', 'a', 'list', 'of', "'image'", 'pages', 'you', 'have', 'edited', 'by', 'clicking', 'on', 'the', '""my', 'contributions""', 'link', '(it', 'is', 'located', 'at', 'the', 'very', 'top', 'of', 'any', 'wikipedia', 'page', 'when', 'you', 'are', 'logged', 'in),', 'and', 'then', 'selecting', '""image""', 'from', 'the', 'dropdown', 'box.', 'note', 'that', 'any', 'fair', 'use', 'images', 'uploaded', 'after', '4', 'may,', '2006,', 'and', 'lacking', 'such', 'an', 'explanation', 'will', 'be', 'deleted', 'one', 'week', 'after', 'they', 'have', 'been', 'uploaded,', 'as', 'described', 'on', 'criteria', 'for', 'speedy', 'deletion.', 'if', 'you', 'have', 'any', 'questions', 'please', 'ask', 'them', 'at', 'the', 'media', 'copyright', 'questions', 'page.', 'thank', 'you.', '(talk', '•', 'contribs', '•', ')', 'unspecified', 'source', 'for', 'image:wonju.jpg', 'thanks', 'for', 'uploading', 'image:wonju.jpg.', 'i', 'noticed', 'that', 'the', "file's", 'description', 'page', 'currently', "doesn't", 'specify', 'who', 'created', 'the', 'content,', 'so', 'the', 'copyright', 'status', 'is', 'unclear.', 'if', 'you', 'did', 'not', 'create', 'this', 'file', 'yourself,', 'then', 'you', 'will', 'need', 'to', 'specify', 'the', 'owner', 'of', 'the', 'copyright.', 'if', 'you', 'obtained', 'it', 'from', 'a', 'website,', 'then', 'a', 'link', 'to', 'the', 'website', 'from', 'which', 'it', 'was', 'taken,', 'together', 'with', 'a', 'restatement', 'of', 'that', "website's", 'terms', 'of', 'use', 'of', 'its', 'content,', 'is', 'usually', 'sufficient', 'information.', 'however,', 'if', 'the', 'copyright', 'holder', 'is', 'different', 'from', 'the', "website's", 'publisher,', 'then', 'their', 'copyright', 'should', 'also', 'be', 'acknowledged.', 'as', 'well', 'as', 'adding', 'the', 'source,', 'please', 'add', 'a', 'proper', 'copyright', 'licensing', 'tag', 'if', 'the', 'file', "doesn't", 'have', 'one', 'already.', 'if', 'you', 'created/took', 'the', 'picture,', 'audio,', 'or', 'video', 'then', 'the', 'tag', 'can', 'be', 'used', 'to', 'release', 'it', 'under', 'the', 'gfdl.', 'if', 'you', 'believe', 'the', 'media', 'meets', 'the', 'criteria', 'at', 'wikipedia:fair', 'use,', 'use', 'a', 'tag', 'such', 'as', 'or', 'one', 'of', 'the', 'other', 'tags', 'listed', 'at', 'wikipedia:image', 'copyright', 'tags#fair', 'use.', 'see', 'wikipedia:image', 'copyright', 'tags', 'for', 'the', 'full', 'list', 'of', 'copyright', 'tags', 'that', 'you', 'can', 'use.', 'if', 'you', 'have', 'uploaded', 'other', 'files,', 'consider', 'checking', 'that', 'you', 'have', 'specified', 'their', 'source', 'and', 'tagged', 'them,', 'too.', 'you', 'can', 'find', 'a', 'list', 'of', 'files', 'you', 'have', 'uploaded', 'by', 'following', '[', 'this', 'link].', 'unsourced', 'and', 'untagged', 'images', 'may', 'be', 'deleted', 'one', 'week', 'after', 'they', 'have', 'been', 'tagged,', 'as', 'described', 'on', 'criteria', 'for', 'speedy', 'deletion.', 'if', 'the', 'image', 'is', 'copyrighted', 'under', 'a', 'non-free', 'license', '(per', 'wikipedia:fair', 'use)', 'then', 'the', 'image', 'will', 'be', 'deleted', '48', 'hours', 'after', '.', 'if', 'you', 'have', 'any', 'questions', 'please', 'ask', 'them', 'at', 'the', 'media', 'copyright', 'questions', 'page.', 'thank', 'you.', '(talk', '•', 'contribs', '•', ')', '"']
2.4 构建迭代器
在训练期间,将使用一种称为BucketIterator的特殊迭代器。当数据传递到神经网络时,我们希望将数据填充为相同的长度,以便我们可以批量处理它们。
如果序列的长度差异很大,则填充将消耗大量浪费的内存和时间。BucketIterator将每个批次的相似长度的序列组合在一起,以最小化填充。
from torchtext.data import Iterator, BucketIterator
# sort_key就是告诉BucketIterator使用哪个key值去进行组合,很明显,在这里是comment_text
# repeat设定为False是因为之后要打包这个迭代层
train_iter, val_iter=BucketIterator.splits(
(trn,vld),
batch_sizes=(64,64),
device=-1,
sort_key=lambda x:len(x.comment_text),
sort_within_batch=False,
repeat=False
)
# 现在就可以看一下输出的BucketIterator是怎样的。
batch=next(train_iter.__iter__());batch
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
[torchtext.data.batch.Batch of size 25]
[.comment_text]:[torch.LongTensor of size 494x25]
[.toxic]:[torch.LongTensor of size 25]
[.severe_toxic]:[torch.LongTensor of size 25]
[.threat]:[torch.LongTensor of size 25]
[.obscene]:[torch.LongTensor of size 25]
[.insult]:[torch.LongTensor of size 25]
[.identity_hate]:[torch.LongTensor of size 25]
batch.__dict__.keys()
dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'comment_text', 'toxic', 'severe_toxic', 'threat', 'obscene', 'insult', 'identity_hate'])
test_iter = Iterator(tst, batch_size=64, device=-1, sort=False, sort_within_batch=False, repeat=False)
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
2.5 打包迭代器
目前,迭代器返回一个名为torchtext.data.Batch的自定义数据类型。这使得代码重用变得困难(因为每次列名更改时,我们都需要修改代码),并且使得torchtext很难与其他库一起用于某些用例(例如torchsample和fastai)。
这里教程将写了一个简单的包装器,使批量易于使用。具体地说,我们将批处理转换为元素形式(x,y),其中x是自变量(模型的输入),y是因变量(监督数据)。
class BatchWrapper:
def __init__(self, dl, x_var, y_vars):
self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # we pass in the list of attributes for x and y
def __iter__(self):
for batch in self.dl:
x = getattr(batch, self.x_var) # we assume only one input in this wrapper
if self.y_vars is not None: # we will concatenate y into a single tensor
y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
else:
y = torch.zeros((1))
yield (x, y)
def __len__(self):
return len(self.dl)
train_dl = BatchWrapper(train_iter, "comment_text", ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"])
valid_dl = BatchWrapper(val_iter, "comment_text", ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"])
test_dl = BatchWrapper(test_iter, "comment_text", None)
验证一下,这里有一个理解,iter方法是用来迭代出tensor的?似乎这样是可以的。
next(train_dl.__iter__())
(tensor([[ 63, 66, 354, ..., 334, 453, 778],
[ 4, 82, 63, ..., 55, 523, 650],
[664, 2, 4, ..., 520, 30, 22],
...,
[ 1, 1, 1, ..., 1, 1, 1],
[ 1, 1, 1, ..., 1, 1, 1],
[ 1, 1, 1, ..., 1, 1, 1]]),
tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[1., 1., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]]))
2.6 训练一个文本分类器
依旧是LSTM
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
class LSTM(nn.Module):
def __init__(self, hidden_dim, emb_dim=300, num_linear=1):
super().__init__()
self.embedding = nn.Embedding(len(TEXT.vocab), emb_dim)
self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=1)
self.linear_layers = []
for _ in range(num_linear - 1):
self.linear_layers.append(nn.Linear(hidden_dim, hidden_dim))
self.linear_layer = nn.ModuleList(self.linear_layers)
self.predictor = nn.Linear(hidden_dim, 6)
def forward(self, seq):
hdn, _ = self.encoder(self.embedding(seq))
feature = hdn[-1, :, :]
for layer in self.linear_layers:
feature = layer(feature)
preds = self.predictor(feature)
return preds
em_sz = 100
nh = 500
nl = 3
model = LSTM(nh, emb_dim=em_sz)
%%time
import tqdm
opt=optim.Adam(model.parameters(),lr=1e-2)
loss_func=nn.BCEWithLogitsLoss()
epochs=2
Wall time: 0 ns
for epoch in range(1, epochs + 1):
running_loss = 0.0
running_corrects = 0
model.train()
for x, y in tqdm.tqdm(train_dl):
opt.zero_grad()
preds = model(x)
loss = loss_func(y, preds)
loss.backward()
opt.step()
running_loss += loss.item()* x.size(0)
epoch_loss = running_loss / len(trn)
val_loss = 0.0
model.eval() # 评估模式
for x, y in valid_dl:
preds = model(x)
loss = loss_func(y, preds)
val_loss += loss.item()* x.size(0)
val_loss /= len(vld)
print('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, epoch_loss, val_loss))
test_preds = []
for x, y in tqdm.tqdm(test_dl):
preds = model(x)
preds = preds.data.numpy()
# 模型的实际输出是logit,所以再经过一个sigmoid函数
preds = 1 / (1 + np.exp(-preds))
test_preds.append(preds)
test_preds = np.hstack(test_preds)
print(test_preds)
100%|██████████| 1/1 [00:06<00:00, 6.28s/it]
Epoch: 1, Training Loss: 14.2130, Validation Loss: 4.4170
100%|██████████| 1/1 [00:04<00:00, 4.20s/it]
Epoch: 2, Training Loss: 10.5315, Validation Loss: 3.3947
100%|██████████| 1/1 [00:00<00:00, 2.87it/s]
[[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99978834 0.99761593 0.5279695 0.9961003 0.9957486 0.3662841 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]
[0.99982786 0.99812 0.53367174 0.99682033 0.9966144 0.3649216 ]]
2.7 测试数据并查看
test_preds = []
for x, y in tqdm.tqdm(test_dl):
preds = model(x)
# if you're data is on the GPU, you need to move the data back to the cpu
# preds = preds.data.cpu().numpy()
preds = preds.data.numpy()
# the actual outputs of the model are logits, so we need to pass these values to the sigmoid function
preds = 1 / (1 + np.exp(-preds))
test_preds.append(preds)
test_preds = np.hstack(test_preds)
100%|██████████| 1/1 [00:00<00:00, 2.77it/s]
df = pd.read_csv("./data/torchtextdata/test.csv")
for i, col in enumerate(["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]):
df[col] = test_preds[:, i]
df.head(3)
image.png