Text Classification with torchtext



In this tutorial, we show how to use the torchtext library to build the dataset for text classification analysis; users can then work with it flexibly.

Access the raw dataset iterator

import torch
from torchtext.datasets import AG_NEWS
# train_iter = AG_NEWS(split='train')

# If the root argument is not given, the dataset is downloaded automatically.
train_iter = AG_NEWS(root='./data/ag_news_csv/', split='train')
train_iter
<torchtext.data.datasets_utils._RawTextIterableDataset at 0x7f125d5ba100>

::

next(train_iter)
>>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - 
Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green 
again.")

next(train_iter)
>>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private 
investment firm Carlyle Group,\\which has a reputation for making well-timed 
and occasionally\\controversial plays in the defense industry, has quietly 
placed\\its bets on another part of the market.')

next(train_iter)
>>> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring 
crude prices plus worries\\about the economy and the outlook for earnings are 
expected to\\hang over the stock market next week during the depth of 
the\\summer doldrums.")

Prepare data processing pipelines

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
"""
get_tokenizer函数的作用是创建一个分词器,将语料喂给相应的分词器,可以根据不同分词函数的规则完成分词,
分词器支持’basic_english’,‘spacy’,‘moses’,‘toktok’,‘revtok’,'subword’等规则
"""

tokenizer = get_tokenizer('basic_english')
# train_iter = AG_NEWS(split='train')
train_iter = AG_NEWS(root='./data/ag_news_csv/', split='train')

# Instantiate a counter
counter = Counter()
for (label, line) in train_iter:
    # Counter.update accepts an iterable or a mapping: if a key already exists its count is incremented, otherwise it is added
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)
[vocab[token] for token in ['here', 'is', 'an', 'example']]
[476, 22, 31, 5298]

The vocab object converts a list of tokens into integers.
::

[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]

Prepare the text processing pipeline with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterator.

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1

text_pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary; label_pipeline converts the label into an integer. For example,
::

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9

Generate data batches and iterator

torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>__
is recommended for PyTorch users (a tutorial is here <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>__).

from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
         label_list.append(label_pipeline(_label))
         processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
         text_list.append(processed_text)
         offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # cumsum returns the cumulative sum of the input elements along dim
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    # torch.cat concatenates the text tensors into a single flat tensor
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)    

# train_iter = AG_NEWS(split='train')
train_iter = AG_NEWS(root='./data/ag_news_csv/', split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

Define the model

The model is composed of an nn.EmbeddingBag <https://pytorch.org/docs/stable/nn.html?highlight=embeddingbag#torch.nn.EmbeddingBag>__ layer plus a linear layer for the classification purpose. nn.EmbeddingBag with the default mode of "mean" computes the mean value of a "bag" of embeddings. Although the text entries here have different lengths, the nn.EmbeddingBag module requires no padding, because the text lengths are saved in offsets.

Additionally, since nn.EmbeddingBag accumulates the average across the embeddings on the fly, it can enhance the performance and memory efficiency when processing a sequence of tensors.
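To make the offsets mechanism concrete, here is a minimal sketch with made-up values (not from the original post): two texts of lengths 3 and 2 are concatenated into one flat index tensor, and offsets marks where each text starts.

import torch
from torch import nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='mean')
# two "texts", [1, 2, 3] and [4, 5], concatenated into one flat tensor
flat = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)
# offsets: text 0 starts at position 0, text 1 starts at position 3
offsets = torch.tensor([0, 3], dtype=torch.int64)
out = bag(flat, offsets)   # shape (2, 4): one mean-pooled embedding per text
print(out.shape)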

(model architecture diagram: an EmbeddingBag layer followed by a linear classification layer)
from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        # vocab_size: total number of words in the corpus
        # embed_dim: dimension of the word embeddings
        # num_class: number of text classification categories
        super(TextClassificationModel, self).__init__()

        # Instantiate the EmbeddingBag layer; sparse=True means only the embeddings used in a batch receive gradient updates
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        # Instantiate the fully connected linear layer; the two arguments are the input and output dimensions
        self.fc = nn.Linear(embed_dim, num_class)
        # Initialize the weights of all defined layers
        self.init_weights()

    def init_weights(self):
        # Range used for weight initialization
        initrange = 0.5
        # Initialize each layer's weights with a uniform distribution
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        # text: tensor of token indices after the numerical mapping
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

Initiate an instance

The AG_NEWS dataset has four labels, therefore num_class = 4.
::

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

We build a model with an embedding dimension of 64. vocab_size equals the length of the vocabulary instance, and num_class equals the number of labels.

# train_iter = AG_NEWS(split='train')
train_iter = AG_NEWS(root='./data/ag_news_csv/', split='train')
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

Define functions to train the model and evaluate results.

import time

def train(dataloader):
    model.train()
    # Initialize the running accuracy statistics
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        # Step 1 of training: zero the optimizer's gradients
        optimizer.zero_grad()
        # Feed a batch of data into the model to get predictions
        predicted_label = model(text, offsets)
        # Compute the loss between the predictions and the true labels
        loss = criterion(predicted_label, label)
        # Backpropagate
        loss.backward()
        # Clip gradient norms to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        # Update the parameters
        optimizer.step()
        # Accumulate the number of correct predictions for this batch
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        # Print every 500 batches
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    '''
    Call model.eval() before running model(test_dataset);
    otherwise layers such as dropout or batch normalization still behave in training
    mode and alter the results (this model has neither, but it is good practice).
    '''
    model.eval()
    total_acc, total_count = 0, 0
    # Note: during evaluation the parameters must not change, so no gradients are computed
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            # Feed the validation data into the model to get predictions
            predicted_label = model(text, offsets)
            # Compute the loss
            loss = criterion(predicted_label, label)
            # Accumulate the number of correct predictions for this batch
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

Split the dataset and run the model

Since the original AG_NEWS has no validation dataset, we split the training dataset into
train/valid sets with a split ratio of 0.95 (train) and 0.05 (valid), using the
torch.utils.data.dataset.random_split <https://pytorch.org/docs/stable/data.html?highlight=random_split#torch.utils.data.random_split>__
function from the PyTorch core library.

The CrossEntropyLoss <https://pytorch.org/docs/stable/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss>__
criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class (a short check of this is sketched below).
It is useful when training a classification problem with C classes.

SGD <https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html>__
implements stochastic gradient descent as the optimizer. The initial
learning rate is set to 5.0.
StepLR <https://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#StepLR>__
is used here to adjust the learning rate through epochs.
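As a quick check of the CrossEntropyLoss claim above (a minimal sketch with made-up logits, not from the original post), CrossEntropyLoss gives the same value as LogSoftmax followed by NLLLoss:

import torch
from torch import nn

logits = torch.randn(8, 4)               # a made-up batch of 8 samples, 4 classes
targets = torch.randint(0, 4, (8,))      # made-up integer class labels
loss_ce = nn.CrossEntropyLoss()(logits, targets)
loss_nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print(torch.allclose(loss_ce, loss_nll))  # expected: True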

from torch.utils.data.dataset import random_split
# Hyperparameters
EPOCHS = 10 # number of training epochs
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training

# Define the loss function: cross-entropy loss
criterion = torch.nn.CrossEntropyLoss()
# Define the optimizer: stochastic gradient descent
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Define a scheduler on top of the optimizer, used for learning-rate decay
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
# train_iter, test_iter = AG_NEWS()
train_iter, test_iter = AG_NEWS(root='./data/ag_news_csv/')

train_dataset = list(train_iter)
test_dataset = list(test_iter)

# Use 95% of the training data for training and the remaining 5% for validation
num_train = int(len(train_dataset) * 0.95)
# subset_1, subset_2 = random_split(dataset, [length_1, length_2])
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)
# Train for 10 epochs
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    # If total_accu has been set and exceeds the current validation accuracy, decay the learning rate
    if total_accu is not None and total_accu > accu_val:
        # Step the scheduler to adjust the learning rate for the next epoch
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)
| epoch   1 |   500/ 1782 batches | accuracy    0.683
| epoch   1 |  1000/ 1782 batches | accuracy    0.852
| epoch   1 |  1500/ 1782 batches | accuracy    0.879
-----------------------------------------------------------
| end of epoch   1 | time: 10.02s | valid accuracy    0.886 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.895
| epoch   2 |  1000/ 1782 batches | accuracy    0.901
| epoch   2 |  1500/ 1782 batches | accuracy    0.904
-----------------------------------------------------------
| end of epoch   2 | time:  9.54s | valid accuracy    0.898 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.914
| epoch   3 |  1000/ 1782 batches | accuracy    0.915
| epoch   3 |  1500/ 1782 batches | accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time:  8.67s | valid accuracy    0.904 
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.923
| epoch   4 |  1000/ 1782 batches | accuracy    0.924
| epoch   4 |  1500/ 1782 batches | accuracy    0.924
-----------------------------------------------------------
| end of epoch   4 | time:  8.88s | valid accuracy    0.910 
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.931
| epoch   5 |  1000/ 1782 batches | accuracy    0.930
| epoch   5 |  1500/ 1782 batches | accuracy    0.929
-----------------------------------------------------------
| end of epoch   5 | time:  8.39s | valid accuracy    0.901 
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.940
| epoch   6 |  1000/ 1782 batches | accuracy    0.941
| epoch   6 |  1500/ 1782 batches | accuracy    0.944
-----------------------------------------------------------
| end of epoch   6 | time:  8.31s | valid accuracy    0.913 
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.943
| epoch   7 |  1000/ 1782 batches | accuracy    0.944
| epoch   7 |  1500/ 1782 batches | accuracy    0.941
-----------------------------------------------------------
| end of epoch   7 | time:  8.75s | valid accuracy    0.915 
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.944
| epoch   8 |  1000/ 1782 batches | accuracy    0.944
| epoch   8 |  1500/ 1782 batches | accuracy    0.944
-----------------------------------------------------------
| end of epoch   8 | time:  8.85s | valid accuracy    0.914 
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.946
| epoch   9 |  1000/ 1782 batches | accuracy    0.945
| epoch   9 |  1500/ 1782 batches | accuracy    0.947
-----------------------------------------------------------
| end of epoch   9 | time:  8.88s | valid accuracy    0.914 
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.946
| epoch  10 |  1000/ 1782 batches | accuracy    0.948
| epoch  10 |  1500/ 1782 batches | accuracy    0.946
-----------------------------------------------------------
| end of epoch  10 | time:  9.13s | valid accuracy    0.915 
-----------------------------------------------------------

Running the model on GPU with the following printout:

::

   | epoch   1 |   500/ 1782 batches | accuracy    0.684
   | epoch   1 |  1000/ 1782 batches | accuracy    0.852
   | epoch   1 |  1500/ 1782 batches | accuracy    0.877
   -----------------------------------------------------------
   | end of epoch   1 | time:  8.33s | valid accuracy    0.867
   -----------------------------------------------------------
   | epoch   2 |   500/ 1782 batches | accuracy    0.895
   | epoch   2 |  1000/ 1782 batches | accuracy    0.900
   | epoch   2 |  1500/ 1782 batches | accuracy    0.903
   -----------------------------------------------------------
   | end of epoch   2 | time:  8.18s | valid accuracy    0.890
   -----------------------------------------------------------
   | epoch   3 |   500/ 1782 batches | accuracy    0.914
   | epoch   3 |  1000/ 1782 batches | accuracy    0.914
   | epoch   3 |  1500/ 1782 batches | accuracy    0.916
   -----------------------------------------------------------
   | end of epoch   3 | time:  8.20s | valid accuracy    0.897
   -----------------------------------------------------------
   | epoch   4 |   500/ 1782 batches | accuracy    0.926
   | epoch   4 |  1000/ 1782 batches | accuracy    0.924
   | epoch   4 |  1500/ 1782 batches | accuracy    0.921
   -----------------------------------------------------------
   | end of epoch   4 | time:  8.18s | valid accuracy    0.895
   -----------------------------------------------------------
   | epoch   5 |   500/ 1782 batches | accuracy    0.938
   | epoch   5 |  1000/ 1782 batches | accuracy    0.935
   | epoch   5 |  1500/ 1782 batches | accuracy    0.937
   -----------------------------------------------------------
   | end of epoch   5 | time:  8.16s | valid accuracy    0.902
   -----------------------------------------------------------
   | epoch   6 |   500/ 1782 batches | accuracy    0.939
   | epoch   6 |  1000/ 1782 batches | accuracy    0.939
   | epoch   6 |  1500/ 1782 batches | accuracy    0.938
   -----------------------------------------------------------
   | end of epoch   6 | time:  8.16s | valid accuracy    0.906
   -----------------------------------------------------------
   | epoch   7 |   500/ 1782 batches | accuracy    0.941
   | epoch   7 |  1000/ 1782 batches | accuracy    0.939
   | epoch   7 |  1500/ 1782 batches | accuracy    0.939
   -----------------------------------------------------------
   | end of epoch   7 | time:  8.19s | valid accuracy    0.903
   -----------------------------------------------------------
   | epoch   8 |   500/ 1782 batches | accuracy    0.942
   | epoch   8 |  1000/ 1782 batches | accuracy    0.941
   | epoch   8 |  1500/ 1782 batches | accuracy    0.942
   -----------------------------------------------------------
   | end of epoch   8 | time:  8.16s | valid accuracy    0.904
   -----------------------------------------------------------
   | epoch   9 |   500/ 1782 batches | accuracy    0.942
   | epoch   9 |  1000/ 1782 batches | accuracy    0.941
   | epoch   9 |  1500/ 1782 batches | accuracy    0.942
   -----------------------------------------------------------
   | end of epoch   9 | time:  8.16s | valid accuracy    0.904
   -----------------------------------------------------------
   | epoch  10 |   500/ 1782 batches | accuracy    0.940
   | epoch  10 |  1000/ 1782 batches | accuracy    0.942
   | epoch  10 |  1500/ 1782 batches | accuracy    0.942
   -----------------------------------------------------------
   | end of epoch  10 | time:  8.15s | valid accuracy    0.904
   -----------------------------------------------------------

Evaluate the model with the test dataset

Checking the results of the test dataset…

print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))
Checking the results of test dataset.
test accuracy    0.908

::

   test accuracy    0.906

Test on a random news item

Use the best model so far and test it on a piece of golf news.

ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

model = model.to("cpu")

print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])
This is a Sports news

::

   This is a Sports news

Save and load the model

MODEL_PATH = './news_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print('The model saved epoch {}'.format(epoch))

# To reload the model later, instantiate the model first and then run:
model.load_state_dict(torch.load(MODEL_PATH))
The model saved epoch 10





<All keys matched successfully>
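To reload the weights in a fresh session, the model has to be constructed first. A minimal sketch, assuming the class and hyperparameters defined above (vocab_size, emsize, num_class, MODEL_PATH) are available:

reloaded_model = TextClassificationModel(vocab_size, emsize, num_class)
reloaded_model.load_state_dict(torch.load(MODEL_PATH))
reloaded_model.eval()  # switch to evaluation mode before inference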
