Text Classification with simpletransformers

2021-05-19  三方斜阳

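Using an English spam-message classification competition as the example, this post shows how a few lines of simpletransformers code can call a pretrained model for a classification task, along with some related usage notes.
Dataset download: https://static.leiphone.com/sms_spam.zip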

Data processing

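import pandas as pd
from sklearn.preprocessing import LabelBinarizer
from simpletransformers.classification import ClassificationModel

# Each training line has the form "<label><TAB><message>"
encoder = LabelBinarizer()
label, sentence = [], []
with open('sms_train.txt', 'r', encoding='utf-8') as Inp:
    for line in Inp:
        head = line.strip().split('\t', 1)
        label.append(head[0])
        sentence.append(head[1])

# For two classes LabelBinarizer returns an (n, 1) array; flatten it to a list of ints
target = encoder.fit_transform(label).ravel().tolist()
# Example output from the original run: target: [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

# simpletransformers expects the columns 'text' and 'labels'
dic1 = {'text': sentence, 'labels': target}
train = pd.DataFrame(dic1)

# The test file holds one unlabeled message per line
test = []
with open('sms_test.txt', 'r', encoding='utf-8') as Inp:
    for line in Inp:
        test.append(line.strip())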
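
The parsing loop above assumes that every training line contains a tab. As a minimal standard-library sketch, the helper below skips malformed lines and maps the sorted label names to integers, which matches what LabelBinarizer produces for two classes. The name `parse_labeled_lines` and the `ham`/`spam` label strings are assumptions made for illustration, not taken from the dataset.

```python
def parse_labeled_lines(lines):
    """Split "<label><TAB><message>" lines into integer labels and messages."""
    labels, sentences = [], []
    for line in lines:
        parts = line.strip().split("\t", 1)
        if len(parts) != 2:
            continue  # skip blank lines and lines without a tab
        labels.append(parts[0])
        sentences.append(parts[1])
    # Sorted label names receive consecutive integer ids
    label_to_id = {name: i for i, name in enumerate(sorted(set(labels)))}
    targets = [label_to_id[name] for name in labels]
    return targets, sentences, label_to_id


# Hypothetical sample lines; the real file's label strings may differ
sample = ["ham\tSee you at lunch?", "spam\tWIN a free prize now!", "", "ham\tOk"]
targets, sentences, label_to_id = parse_labeled_lines(sample)
print(targets)
print(label_to_id)
```

Returning the label mapping alongside the targets makes it possible to translate predicted integers back into label names later.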

Loading the model

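# Load the classification model
import pandas as pd
from simpletransformers.classification import ClassificationModel
model = ClassificationModel('roberta', 'roberta-base', num_labels=2)
model.train_model(train)
# predict() takes a list of strings and returns (predictions, raw_outputs)
predictions, _ = model.predict(test)
# Write one "ID,label" row per test message, with no header and no index column
pd.DataFrame({'ID': range(len(predictions)), 'labels': predictions}).to_csv('submission.csv', index=False, header=False)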
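
To make the submission format concrete, here is a minimal standard-library sketch of the rows that the to_csv call above writes: one `ID,label` pair per test message, with no header. The function name `format_submission` and the sample predictions are illustrative assumptions, and the required format is taken to be whatever the original script produces.

```python
import csv
import io


def format_submission(predictions):
    """Return CSV text with one "ID,label" row per prediction and no header."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    for idx, label in enumerate(predictions):
        writer.writerow([idx, int(label)])
    return buffer.getvalue()


# Hypothetical predictions for three test messages
print(format_submission([0, 1, 0]))
```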
simpletransformers
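  1. This creates a ClassificationModel, which is used for training, evaluation, and prediction. The first argument is model_type, the second is model_name, and the third is the number of labels in the data.
  2. model_type can be, for example, one of ['bert','xlnet','xlm','roberta','distilbert'].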
from simpletransformers.classification import ClassificationModel
# Create a ClassificationModel
model = ClassificationModel('roberta', 'roberta-base', num_labels=3)
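  3. Instead of a default pretrained model, you can also load a previously saved model: change model_name to the path of the directory that contains the saved model.
model = ClassificationModel('xlnet', 'path_to_model/', num_labels=4)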
  4. Use ClassificationArgs to pass model arguments:
from simpletransformers.classification import ClassificationModel, ClassificationArgs

model_args = ClassificationArgs()
model_args.num_train_epochs = 5
model_args.learning_rate = 1e-4
model = ClassificationModel("bert", "bert-base-cased", args=model_args)

Using the alternative dict format

from simpletransformers.classification import ClassificationModel

model_args = {
    "num_train_epochs": 5,
    "learning_rate": 1e-4,
}
model = ClassificationModel("bert", "bert-base-cased", args=model_args)
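
Both formats carry the same information: named settings that override the defaults. The sketch below illustrates that idea with a hypothetical stand-in dataclass, `SketchArgs`, and a helper, `apply_overrides`. Neither is part of simpletransformers, and the default values shown are placeholders rather than the library's actual defaults.

```python
from dataclasses import dataclass, fields


@dataclass
class SketchArgs:
    # Placeholder defaults, NOT the library's real defaults
    num_train_epochs: int = 1
    learning_rate: float = 4e-5


def apply_overrides(args, overrides):
    """Copy dict entries onto matching attributes and reject unknown keys."""
    known = {f.name for f in fields(args)}
    for key, value in overrides.items():
        if key not in known:
            raise KeyError(f"unknown argument: {key}")
        setattr(args, key, value)
    return args


# Attribute style
a = SketchArgs()
a.num_train_epochs = 5
a.learning_rate = 1e-4

# Dict style
b = apply_overrides(SketchArgs(), {"num_train_epochs": 5, "learning_rate": 1e-4})

print(a == b)
```

Rejecting unknown keys is a design choice of this sketch; it surfaces misspelled option names early instead of silently ignoring them.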
  5. Once the model is defined, training, evaluation, and prediction each take a single line of code:
# Train the model
model.train_model(train_df)

# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df)

# Make predictions with the model (predict() expects a list of strings, not a DataFrame)
predictions, raw_outputs = model.predict(test_df["text"].tolist())
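
The second value returned by predict holds the model's raw outputs. Assuming each row contains one unnormalized score per class (the values below are made up for illustration), a softmax turns a row into probabilities, and the index of the largest score gives the predicted label. This is a standard-library sketch, not code from simpletransformers.

```python
import math


def softmax(scores):
    """Convert a list of raw scores into probabilities that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_label(scores):
    """Return the index of the highest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Hypothetical raw outputs for two messages and two classes
raw_outputs = [[2.0, -1.0], [-0.5, 1.5]]
for row in raw_outputs:
    probs = softmax(row)
    print(to_label(row), [round(p, 3) for p in probs])
```

The probabilities can be useful when a decision threshold other than "largest score wins" is wanted, for example to flag only high-confidence spam.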

For more details and the usage of other models, see the official documentation on GitHub: https://github.com/ThilinaRajapakse/simpletransformers
