Running Prediction with a BERT+CRF Named Entity Recognition Model
2021-04-06
陶_306c
Named entity recognition with BERT+CRF
```python
if FLAGS.do_predict:
    # Invert label2id so that predicted label ids can be mapped back to tag strings
    with open(FLAGS.middle_output + '/label2id.pkl', 'rb') as rf:
        label2id = pickle.load(rf)
    id2label = {value: key for key, value in label2id.items()}
    predict_examples = processor.get_test_examples(FLAGS.data_dir)
    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
    batch_tokens, batch_labels = filed_based_convert_examples_to_features(
        predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file)
    logging.info("***** Running prediction*****")
    logging.info(" Num examples = %d", len(predict_examples))
    logging.info(" Batch size = %d", FLAGS.predict_batch_size)
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=False)
    result = estimator.predict(input_fn=predict_input_fn)
    output_predict_file = os.path.join(FLAGS.output_dir, "label_test1.txt")
    # A token tagged "X" belongs to the token before it (a WordPiece continuation piece).
    # To make evaluation with conlleval.pl convenient, such tokens are discarded here.
    Writer(output_predict_file, result, batch_tokens, batch_labels, id2label)
```
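The project's Writer function is not shown here. The sketch below is a rough illustration of what the "X" comment means. It takes one sentence's tokens, gold labels and predicted label ids and turns them into the word / gold / predicted lines used for conlleval. It is a hypothetical helper, not the repo's Writer: the name format_predictions and the toy id2label are my own. It assumes the gold labels mark WordPiece continuation pieces with "X" and special tokens with "[CLS]"/"[SEP]". Instead of dropping a continuation piece outright, it glues the piece back onto the preceding word, so each output line is a whole word.

```python
from typing import Dict, List, Sequence

SPECIAL = {"[CLS]", "[SEP]"}  # assumed labels of the special tokens

def format_predictions(tokens: Sequence[str], gold: Sequence[str],
                       pred_ids: Sequence[int],
                       id2label: Dict[int, str]) -> List[str]:
    """Hypothetical helper: build 'word gold pred' lines for one sentence."""
    words: List[List[str]] = []
    for tok, g, pid in zip(tokens, gold, pred_ids):
        if g in SPECIAL:
            continue  # [CLS]/[SEP] are not real words
        if g == "X" and words:
            # Continuation piece: append its text to the previous word and
            # keep that word's gold and predicted labels.
            piece = tok[2:] if tok.startswith("##") else tok
            words[-1][0] += piece
            continue
        words.append([tok, g, id2label[pid]])
    return [" ".join(w) for w in words]

# Toy label mapping; the real label2id.pkl may use different ids.
id2label = {0: "O", 1: "B-PER", 2: "I-PER", 3: "X", 4: "[CLS]", 5: "[SEP]"}
lines = format_predictions(
    ["[CLS]", "Nadim", "La", "##dki", "[SEP]"],
    ["[CLS]", "B-PER", "I-PER", "X", "[SEP]"],
    [4, 1, 2, 3, 5],
    id2label)
print("\n".join(lines))
# Nadim B-PER B-PER
# Ladki I-PER I-PER
```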
The file given by output_predict_file (label_test1.txt in FLAGS.output_dir) holds the prediction output, with one label per token.
Dataset: CoNLL-2003 (conll03)
There are 4 entity types: person, location, organization, and miscellaneous.
Input: test.txt
# Column 1 is the word, column 2 the part-of-speech tag, column 3 the syntactic chunk tag, column 4 the entity tag.
Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER
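To make the input format concrete, here is a minimal reader for the four-column file. It assumes the usual CoNLL-2003 layout: one token per line, sentences separated by blank lines, and -DOCSTART- lines marking document boundaries. read_conll is an illustrative name and is not part of BERT_NER.py. It keeps only the first column (the word) and the last column (the entity tag).

```python
from typing import List, Tuple

def read_conll(text: str) -> List[List[Tuple[str, str]]]:
    """Split CoNLL-2003 style text into sentences of (word, entity_tag) pairs."""
    sentences, current = [], []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("-DOCSTART-"):
            # A blank line or document marker closes the current sentence
            if current:
                sentences.append(current)
                current = []
            continue
        cols = line.split()
        current.append((cols[0], cols[-1]))  # word = column 1, NER tag = column 4
    if current:
        sentences.append(current)
    return sentences

sample = ("-DOCSTART- -X- -X- O\n\n"
          "Nadim NNP B-NP B-PER\nLadki NNP I-NP I-PER\n\n"
          "United NNP B-NP B-LOC\n")
print(read_conll(sample))
# [[('Nadim', 'B-PER'), ('Ladki', 'I-PER')], [('United', 'B-LOC')]]
```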
Output: label_text.txt
# Column 1 is the word, column 2 the gold (true) label, column 3 the predicted label.
United B-LOC B-LOC
Arab I-LOC I-LOC
Emirates I-LOC I-LOC
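The three lines United / Arab / Emirates form a single LOC entity. B- opens a phrase, and I- of the same type continues it. conlleval scores at this phrase level: an entity counts as correct only if both its type and its boundaries match. The sketch below shows how a BIO tag column can be turned into spans. bio_to_spans is a hypothetical helper. It mirrors how conlleval handles an I- tag that does not continue a phrase of the same type, treating that tag as the start of a new phrase.

```python
from typing import List, Tuple

def bio_to_spans(tags: List[str]) -> List[Tuple[str, int, int]]:
    """Return (entity_type, start, end_exclusive) spans from BIO tags."""
    spans = []
    etype, start = None, None
    for i, tag in enumerate(tags + ["O"]):  # trailing "O" closes the last span
        prefix, _, label = tag.partition("-")
        continues = prefix == "I" and label == etype
        if etype is not None and not continues:
            spans.append((etype, start, i))  # close the open phrase
            etype, start = None, None
        if prefix == "B" or (prefix == "I" and not continues):
            etype, start = label, i  # open a new phrase
    return spans

print(bio_to_spans(["B-LOC", "I-LOC", "I-LOC"]))  # [('LOC', 0, 3)]
```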
Prediction
Run BERT_NER.py with the following flags changed, leaving everything else unchanged:
--do_train=False \
--do_eval=True \
--do_predict=True
Prediction results:
processed 40610 tokens with 4671 phrases; found: 4557 phrases; correct: 4104.
accuracy: 98.02%; precision: 90.06%; recall: 87.86%; FB1: 88.95
LOC: precision: 92.07%; recall: 91.34%; FB1: 91.71 1387
MISC: precision: 82.19%; recall: 77.54%; FB1: 79.80 668
ORG: precision: 87.41%; recall: 82.49%; FB1: 84.88 1191
PER: precision: 94.36%; recall: 94.93%; FB1: 94.64 1311
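The overall scores follow directly from the three counts in the first line. Precision is correct / found, recall is correct / phrases in the gold data, and FB1 is their harmonic mean. A quick check in Python reproduces the reported numbers; prf1 is just an illustrative helper.

```python
from typing import Tuple

def prf1(correct: int, found: int, gold: int) -> Tuple[float, float, float]:
    """Phrase-level precision, recall and F1 (F-beta with beta = 1)."""
    precision = correct / found if found else 0.0
    recall = correct / gold if gold else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

# Counts from the first line of the conlleval output above
p, r, f = prf1(correct=4104, found=4557, gold=4671)
print(f"precision: {p:.2%}; recall: {r:.2%}; FB1: {f * 100:.2f}")
# precision: 90.06%; recall: 87.86%; FB1: 88.95
```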
Input sentence:
"When it comes to cyberspace, the is the most technologically-advanced nation. talks with cybersecurity experts about the national security implications of the suspected hacking, and the dangers it poses to the ."
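Each of the entities highlighted in red above was recognized. What puzzled me is that United States and US were labeled LOC, while Russian was labeled MISC. This does match the CoNLL-2003 annotation convention, in which country names are tagged LOC and nationality adjectives such as "Russian" are tagged MISC.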
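Computing entity accuracy
The accuracy reported above is computed from the predictions on the test set:
acc = correctly predicted tokens (including 'O') / all tokens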
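Every token counts toward this accuracy, including the many 'O' tokens, so it can stay high even when entities are often missed. A made-up ten-token example shows the effect. The names accuracy and skip_o are illustrative. With skip_o=True, only tokens whose gold tag is not 'O' are counted.

```python
from typing import Sequence

def accuracy(gold: Sequence[str], pred: Sequence[str], skip_o: bool = False) -> float:
    """Share of tokens whose predicted tag equals the gold tag.

    With skip_o=True, tokens whose gold tag is 'O' are ignored.
    """
    pairs = [(g, p) for g, p in zip(gold, pred) if not (skip_o and g == "O")]
    return sum(g == p for g, p in pairs) / len(pairs) if pairs else 0.0

# Toy data (made up): all 8 'O' tokens are right, 1 of the 2 entity tokens is wrong.
gold = ["O"] * 8 + ["B-PER", "I-PER"]
pred = ["O"] * 8 + ["B-PER", "O"]
print(accuracy(gold, pred))               # 0.9
print(accuracy(gold, pred, skip_o=True))  # 0.5
```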
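Since 'O' makes up the vast majority of tokens and entities only a small fraction, I separately computed the accuracy on entity tokens (tokens whose predicted label is not 'O'). It is only 91.42%. The code is below: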
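```python
# Token-level accuracy over tokens whose predicted label (last column) is not 'O'.
# Each line of the file is: word, gold label, predicted label.
with open('./label_test.txt', 'r', encoding='utf-8') as f:
    sents = [line.strip() for line in f.readlines()]
totals = len(sents)
print(totals)
count = 0
total = 0
for sent in sents:
    words = sent.split()
    if len(words) < 3:
        continue  # skip blank separator lines, which would raise IndexError
    if words[-1] != 'O':  # predicted label is an entity tag
        total += 1
        if words[-1] == words[-2]:  # prediction matches the gold label
            count += 1
# print(total)
# print(count)
print('Accuracy: %.4f' % (count / total if total else 0.0))
```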
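The script filters on words[-1], which is the predicted label. It therefore measures how many tokens predicted as entities carry the right tag, which is closer to a token-level precision. Filtering on the gold column instead gives a recall-like figure, and the two can differ. The sketch below computes both from 'word gold pred' lines and skips blank lines. entity_token_accuracy is a hypothetical name.

```python
from typing import Dict, Iterable

def entity_token_accuracy(lines: Iterable[str]) -> Dict[str, float]:
    """Token accuracy restricted to entity tokens, filtered two ways.

    'by_pred' keeps tokens whose predicted tag is not 'O' (like the script above);
    'by_gold' keeps tokens whose gold tag is not 'O'.
    """
    hits = {"by_pred": 0, "by_gold": 0}
    totals = {"by_pred": 0, "by_gold": 0}
    for line in lines:
        cols = line.split()
        if len(cols) < 3:
            continue  # blank or malformed line
        gold, pred = cols[-2], cols[-1]
        for key, tag in (("by_pred", pred), ("by_gold", gold)):
            if tag != "O":
                totals[key] += 1
                hits[key] += int(gold == pred)
    return {k: hits[k] / totals[k] if totals[k] else 0.0 for k in hits}

# Toy lines (made up): "Emirates" is missed, so only the gold-based figure drops.
toy = ["United B-LOC B-LOC", "Arab I-LOC I-LOC", "Emirates I-LOC O",
       "Reuters B-ORG B-ORG", ""]
print(entity_token_accuracy(toy))  # {'by_pred': 1.0, 'by_gold': 0.75}
```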