NER

BERT+CRF命名实体识别做预测

2021-04-06  本文已影响0人  陶_306c

使用BERT+CRF做命名实体识别

    if FLAGS.do_predict:
        with open(FLAGS.middle_output+'/label2id.pkl', 'rb') as rf:
            label2id = pickle.load(rf)
            id2label = {value: key for key, value in label2id.items()}
   
        predict_examples = processor.get_test_examples(FLAGS.data_dir)

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        batch_tokens,batch_labels = filed_based_convert_examples_to_features(predict_examples, label_list,
                                                 FLAGS.max_seq_length, tokenizer,
                                                 predict_file)

        logging.info("***** Running prediction*****")
        logging.info("  Num examples = %d", len(predict_examples))
        logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)

        result = estimator.predict(input_fn=predict_input_fn)
        output_predict_file = os.path.join(FLAGS.output_dir, "label_test1.txt")
        #here if the tag is "X" means it belong to its before token, here for convenient evaluate use
        # conlleval.pl we  discarding it directly
        Writer(output_predict_file,result,batch_tokens,batch_labels,id2label)

output_predict_file = os.path.join(FLAGS.output_dir, "label_test1.txt")是预测的输出,每个token对应一个标签。

数据集:conll03

有4类标签:person,location,organization,miscellaneous(混杂的)实体标签

输入:test.txt

#第一列是单词,第二列是词性,第三列是语法块,第四列是实体标签。
Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER

输出:label_text.txt

#第一列是单词,第二列是真实标签,第三列是预测标签
United  B-LOC   B-LOC
Arab    I-LOC   I-LOC
Emirates    I-LOC   I-LOC

预测

运行BERT_NER.py文件,修改以下参数,其他不变。

--do_train=False   \
--do_eval=True   \
--do_predict=True 

预测结果:


processed 40610 tokens with 4671 phrases; found: 4557 phrases; correct: 4104.
accuracy:  98.02%; precision:  90.06%; recall:  87.86%; FB1:  88.95
              LOC: precision:  92.07%; recall:  91.34%; FB1:  91.71  1387
             MISC: precision:  82.19%; recall:  77.54%; FB1:  79.80  668
              ORG: precision:  87.41%; recall:  82.49%; FB1:  84.88  1191
              PER: precision:  94.36%; recall:  94.93%; FB1:  94.64  1311

输入的句子:
"When it comes to cyberspace, the \color{red}{United States} is the most technologically-advanced nation. \color{red}{Ted Koppel} talks with cybersecurity experts about the national security implications of the suspected \color{red}{Russian} hacking, and the dangers it poses to the \color{red}{US}."
分别将上述红字识别出来,疑惑是将United States和US识别为LOC,将Russian识别为MISC

实体准确率计算

这里的accuracy是对测试集做预测,然后计算的准确率。
acc = 预测正确的(包括‘O’)/所有样本
而毕竟O占大多数,实体占小部分,我单独计算了一下实体的准确率只有91.42%。以下是代码

with open('./label_test.txt','r',encoding='utf-8')as f:
    sents = [line.strip() for line in f.readlines()]

totals = len(sents)
print(totals)

count=0
total = 0
for sent in sents:
    words = sent.split()
    if words[-1]!='O':
        total+=1
        if words[-1]==words[-2]:
            count+=1
# print(total)
# print(count)
print('Accuracy:%.4f'%(count/total))
上一篇下一篇

猜你喜欢

热点阅读