tf2 HuggingFace Transformer2.0 b

2020-02-17  阿东7

When I searched online I found very little material on using HuggingFace Transformers 2.0 with TF2, so I wrote these notes for myself.

How word embeddings work is not covered here.

How BERT works is not covered here.

BERT input parameters

  1. input_ids
  2. token_type_ids
  3. attention_mask
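
As a rough illustration of how these three inputs relate to each other, the sketch below builds them by hand for one short sentence. The toy vocabulary and the function name are made up for this example; the real ids come from the vocab.txt of the pretrained model and from the tokenizer.

```python
# Toy vocabulary (an assumption for illustration); a real run uses the model's vocab.txt.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "中": 4, "国": 5}

def build_bert_inputs(chars, max_length):
    """Return input_ids, token_type_ids and attention_mask, each of length max_length."""
    # Reserve two positions for [CLS] and [SEP]
    tokens = ["[CLS]"] + list(chars)[: max_length - 2] + ["[SEP]"]
    input_ids = [TOY_VOCAB.get(t, TOY_VOCAB["[UNK]"]) for t in tokens]
    padding = max_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * padding  # 1 = real token, 0 = padding
    token_type_ids = [0] * max_length                      # single sentence: segment 0 everywhere
    input_ids = input_ids + [TOY_VOCAB["[PAD]"]] * padding
    return input_ids, token_type_ids, attention_mask

ids, types, mask = build_bert_inputs("中国人", max_length=8)
print(ids)
print(types)
print(mask)
```

All three lists have the same length, which is what lets them be stacked into tensors of shape (batch, max_length).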

BERT output

  1. Per-character vectors for the sentence

Python environment

  1. tensorflow2
  2. torch1.4.1
  3. transformers2.4.1
    torch is not actually used by the program



Data preparation

  1. Files required by HuggingFace Transformers
    For the BERT config and vocabulary files, see
    https://www.cnblogs.com/lian1995/p/11947522.html
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
    'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
    'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
    }

PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
}
}

However, I use the TF2 weight files here
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/tf_model.h5",
}
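
The training code further down loads everything from a local directory, ./my-bert-base-chinese. A minimal sketch of how the remote files could map to that directory is given here; it assumes the URL pattern that the maps above show for bert-base-chinese, and the local name tf_model.h5 for the weights. The helper names are hypothetical and nothing is downloaded.

```python
import os

BASE_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/"

def files_for(model_name):
    """Map each remote file to the local file name inside the model directory."""
    return {
        BASE_URL + model_name + "-config.json": "config.json",
        BASE_URL + model_name + "-vocab.txt": "vocab.txt",
        BASE_URL + model_name + "-tf_model.h5": "tf_model.h5",
    }

def download_plan(model_name, target_dir):
    """Return (url, local_path) pairs; fetching the files is left to the reader."""
    return [(url, os.path.join(target_dir, name))
            for url, name in files_for(model_name).items()]

plan = download_plan("bert-base-chinese", "./my-bert-base-chinese")
for url, path in plan:
    print(url, "->", path)
```

Some models in the map (the Japanese, Finnish and Dutch ones) use a different URL layout, so this pattern does not cover them.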


Training and test data for NER

https://codeload.github.com/liushaoweihua/keras-bert-ner/zip/master
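
Judging from the comments in the training code, the data files hold one character and its label per line, with a blank line between sentences. A small plain-Python sketch of reading that format follows; the function name is made up, and it works on a list of lines rather than a real file.

```python
def read_ner_sentences(lines):
    """Group 'char label' lines into (chars, labels) pairs, one pair per sentence."""
    sentences, chars, labels = [], [], []
    for line in lines:
        fields = line.split()
        if len(fields) >= 2:
            chars.append(fields[0])
            labels.append(fields[1])
        elif not fields and chars:       # a blank line closes the current sentence
            sentences.append((chars, labels))
            chars, labels = [], []
        # lines with a single field are malformed and skipped
    if chars:                            # the file may end without a trailing blank line
        sentences.append((chars, labels))
    return sentences

sample = ["在 O", "中 B-ORG", "国 I-ORG", "", "各 O", "位 O"]
parsed = read_ner_sentences(sample)
for chars, labels in parsed:
    print("".join(chars), labels)
```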

Training code

import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import tensorflow as tf
import numpy as np
from transformers import *

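# NER label set
label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
# Number of labels
label_category_total_num = len(label_category)
# Sentence length in tokens
max_length = 128
# Sentences shorter than 128 tokens are padded, and the padded positions get mask value 0
mask_padding_with_zero=True
pad_token=0
pad_token_segment_id=0
# Dummy inputs that the HuggingFace Transformers 2.0 BERT model uses when it is instantiated. The default shape is (3, 5); here Keras has to build all 128 classifier layers, so the shape is (3, 128)
DUMMY_INPUTS = [np.random.randint(0,128,size=128), np.random.randint(0,128,size=128), np.random.randint(0,128,size=128)]

# This loss is essentially copied from tf.keras to make debugging easier.
# The model returns raw logits, so from_logits defaults to True and is passed through.
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=True, axis=-1):
    return tf.keras.backend.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits, axis=axis)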

# Subclass the HuggingFace Transformers 2.0 base class so that the BERT outputs are available and the model is easy to extend
class TFMyBertModel(TFBertPreTrainedModel):
    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
        # One classifier layer per token position, 128 in total
        self.classifiers = []
        for i in range(max_length):
            self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))

    def call(self, inputs, **kwargs):
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        print(sequence_output.shape)

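        # sequence_output has shape [None, 128, 768]; split it into a list of 128 pieces, because every character has to be classified into one of the NER labels in label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
        words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
        # The length of words_output equals the sentence length
        logits = []
        for i,o in enumerate(words_output):
            # Drop the extra dimension: after tf.split along the token axis each piece is still 3-D, [None, 1, 768]
            ot = tf.squeeze(o, axis=1)
            # Compute the classification logits for this position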
            logit = self.classifiers[i](ot)
            logits.append(logit)
        return logits

train_path = './data_ner/train.txt'
valid_path = './data_ner/dev.txt'

# Directory with the BERT files that HuggingFace Transformers 2.0 needs; the TF2 h5 weights are used here
pretrained_path = './my-bert-base-chinese'
config_path = os.path.join(pretrained_path, 'config.json')
vocab_path = os.path.join(pretrained_path, 'vocab.txt')

tokenizer = BertTokenizer.from_pretrained(vocab_path)
# Load the config
config = BertConfig.from_json_file(config_path)
# Load the pretrained TF weights
model = TFMyBertModel.from_pretrained(pretrained_path,from_pt=False, config=config) # from_pt says whether the weights come from PyTorch; these are TF weights, so it is False

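# File format: one character and its label per line
# 中 B-ORG
# 共 I-ORG
# 中 I-ORG
# 央 I-ORG
# Collect the set of labels that occur in a file
def nerCategory(pathfile):
    srctext = tf.data.TextLineDataset(pathfile)
    label_category = set([])
    for lineText in srctext:
        # lineText looks like: 中 B-ORG
        char_label = tf.strings.split(lineText)
        # char_label looks like: ['中', 'B-ORG']
        if len(char_label) >= 2: # a line with no fields marks the end of a sentence
            # decode the bytes so the set holds str labels, like the global label_category
            label_category.add(char_label[1].numpy().decode(encoding='UTF-8'))
    return label_category

# Look up the index of a label in label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
# The index is the class id; labels that are not in the list map to the last entry, 'UNK'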
def labelNum(label):
    i = label_category.index(label) if label in label_category else (len(label_category) -1)
    return float(i)

# 各 O
# 位 O
# 代 O
# 表 O
# 、 O
# 各 O
# 位 O
# 同 O
# 志 O
# : O
#
# 在 O
# 中 B-ORG
# 国 I-ORG
# 致 I-ORG
# Build the dataset
def fromNer(pathfile):
    input_ids_dataset = []
    attention_mask_dataset = []
    token_type_ids_datasest = []
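    label_dataset = [[] for _ in range(max_length)] # one label column per token position, 128 columns for a sentence length of 128

    tmp_sentence = [] # characters of the current sentence, e.g. ['各','位']
    tmp_label = [] # label ids of the current sentence, e.g. [6,6]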
    srctext = tf.data.TextLineDataset(pathfile)
    k = 0
    word_i = 0
    for lineText in srctext:
        # lineText looks like: 中 B-ORG
        word_i += 1
        char_label = tf.strings.split(lineText)
        # char_label looks like: ['中', 'B-ORG']
        if (len(char_label) <= 0): # no fields: the sentence is complete
            # The maximum length is 128, but BERT's CLS and SEP need two positions, so a sentence can hold at most 126 characters
            if len(tmp_sentence) > (max_length-2):
                tmp_sentence = tmp_sentence[:(max_length-2)]
            if len(tmp_label) > (max_length-2):
                tmp_label = tmp_label[:(max_length-2)]

            # Build input_ids, attention_mask and token_type_ids in the form BERT expects
            inputs = tokenizer.encode_plus("".join(tmp_sentence), add_special_tokens=True, max_length=max_length, )
            input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
            padding_length = max_length - len(input_ids)
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

            # Pad the labels with CLS, SEP and PAD as well
            label_ids = [labelNum('CLS')] \
                        + tmp_label + [labelNum('SEP')] \
                        + [labelNum('PAD')] * (max_length-len(tmp_label)-2)

            for i, l in enumerate(label_ids):
                label_dataset[i].append(l)

            # TensorFlow needs the inputs as tensors
            input_ids = tf.constant(input_ids)
            attention_mask = tf.constant(attention_mask)
            token_type_ids = tf.constant(token_type_ids)

            input_ids_dataset.append(input_ids)
            attention_mask_dataset.append(attention_mask)
            token_type_ids_datasest.append(token_type_ids)

            tmp_sentence.clear()
            tmp_label.clear()
            k += 1
            if k % 100 == 0:
                print('line: %d; word: %d'%(k, word_i))
        else:
            tmp_sentence.append(char_label[0].numpy().decode(encoding='UTF-8'))
            tmp_label.append(labelNum(char_label[1].numpy().decode(encoding='UTF-8')))

    for i, ls in enumerate(label_dataset):
        label_dataset[i] = tf.cast(ls, dtype=tf.float32) # the labels have to be tensors as well

    return (
        tf.convert_to_tensor(input_ids_dataset),
        tf.convert_to_tensor(attention_mask_dataset),
        tf.convert_to_tensor(token_type_ids_datasest),
        label_dataset
    )

# Build the BERT inputs
input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest, label_t_dataset = fromNer(train_path)
input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest, label_v_dataset = fromNer(valid_path)
#
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = [sparse_categorical_crossentropy for _ in range(max_length)]
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()

model.fit(x={'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest},
          y=label_t_dataset,
          epochs=2,
          steps_per_epoch=1080,
          validation_steps=7,
          validation_data=([input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest], label_v_dataset))

model.save_pretrained("./mysavener")
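
The Dense classifier layers in the model have no activation, so the model returns raw logits, which is why the loss is computed from logits. As a rough illustration of what that means for a single token, here is the same quantity in plain Python rather than TensorFlow; the function name is made up for this sketch.

```python
import math

def sparse_ce_from_logits(label, logits):
    """Cross-entropy of one integer label against unnormalised scores."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    # -log(softmax(logits)[label]) == log_sum_exp - logits[label]
    return log_sum_exp - logits[label]

logits = [2.0, 0.5, -1.0]
low = sparse_ce_from_logits(0, logits)   # class 0 has the highest score
high = sparse_ce_from_logits(2, logits)  # class 2 has the lowest score
print(low, high)
```

The loss is small when the true label has the highest score and grows as its score falls behind the others.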

Notes on the code that are rarely covered elsewhere online

  1. The weights are in the TF2 Keras h5 format, so pass from_pt=False
model = TFMyBertModel.from_pretrained(pretrained_path,from_pt=False, config=config) # from_pt says whether the weights come from PyTorch; these are TF weights, so it is False
  2. The training inputs can be passed as a dict, {'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest}, or as a list, [input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest]; see the HuggingFace Transformers source code


  3. Structure of the training targets: there are 128 of them, because I defined a sentence as at most 128 tokens
    Structure of label_t_dataset



label_t_dataset is converted to tensors, the format a Keras model needs
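
To make the layout concrete, the sketch below turns per-sentence label lists into one column per token position, which is how label_dataset is organised in fromNer. It uses plain lists and a short max_length; the function name is hypothetical, and the ids 7, 8 and 9 stand for CLS, SEP and PAD as in label_category.

```python
def to_position_columns(sentence_labels, max_length, pad_id):
    """Turn per-sentence label lists into max_length columns, one per token position."""
    padded = [labels[:max_length] + [pad_id] * (max_length - len(labels))
              for labels in sentence_labels]
    # zip(*rows) transposes [num_sentences][max_length] into [max_length][num_sentences]
    return [list(column) for column in zip(*padded)]

# Two sentences: CLS O SEP, and CLS B-PER I-PER SEP
columns = to_position_columns([[7, 6, 8], [7, 0, 4, 8]], max_length=5, pad_id=9)
print(len(columns))  # one list per token position
print(columns[1])    # labels at position 1 across all sentences
```

Each column then becomes the target for the classifier at the same position.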


model.fit(x={'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest},
          y=label_t_dataset,
          epochs=2,
          steps_per_epoch=1080,
          validation_steps=7,
          validation_data=([input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest], label_v_dataset))
  4. Each target in label_t_dataset corresponds to one entry of the logits returned by the model
# Subclass the HuggingFace Transformers 2.0 base class so that the BERT outputs are available and the model is easy to extend
class TFMyBertModel(TFBertPreTrainedModel):
    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
        # One classifier layer per token position, 128 in total
        self.classifiers = []
        for i in range(max_length):
            self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))

    def call(self, inputs, **kwargs):
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        print(sequence_output.shape)

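        # sequence_output has shape [None, 128, 768]; split it into a list of 128 pieces, because every character has to be classified into one of the NER labels in label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
        words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
        # The length of words_output equals the sentence length
        logits = []
        for i,o in enumerate(words_output):
            # Drop the extra dimension: after tf.split along the token axis each piece is still 3-D, [None, 1, 768]
            ot = tf.squeeze(o, axis=1)
            # Compute the classification logits for this position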
            logit = self.classifiers[i](ot)
            logits.append(logit)
        return logits
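
The shapes in call() can be followed with a small NumPy stand-in. This is only a sketch with made-up sizes and random numbers in place of the BERT output and the Dense weights; it mirrors the split, squeeze and classify steps and is not the model itself.

```python
import numpy as np

batch, seq_len, hidden, num_labels = 2, 4, 3, 5
rng = np.random.default_rng(0)
sequence_output = rng.normal(size=(batch, seq_len, hidden))  # stand-in for the BERT output
# One weight matrix and one bias per token position, like the list of Dense layers
weights = rng.normal(size=(seq_len, hidden, num_labels))
biases = np.zeros((seq_len, num_labels))

logits = []
for i in range(seq_len):
    token_vectors = sequence_output[:, i, :]               # [batch, hidden]
    logits.append(token_vectors @ weights[i] + biases[i])  # [batch, num_labels]

print(len(logits), logits[0].shape)
```

The result is a list with one [batch, num_labels] array per position, which is why the targets are also organised as one column per position.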

Prediction code

import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import tensorflow as tf
import numpy as np
from transformers import *

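# NER label set
label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
# Number of labels
label_category_total_num = len(label_category)
# Sentence length in tokens
max_length = 128
# Sentences shorter than 128 tokens are padded, and the padded positions get mask value 0
mask_padding_with_zero=True
pad_token=0
pad_token_segment_id=0
# Dummy inputs that the HuggingFace Transformers 2.0 BERT model uses when it is instantiated. The default shape is (3, 5); here Keras has to build all 128 classifier layers, so the shape is (3, 128)
DUMMY_INPUTS = [np.random.randint(0,128,size=128), np.random.randint(0,128,size=128), np.random.randint(0,128,size=128)]

# This loss is essentially copied from tf.keras to make debugging easier.
# The model returns raw logits, so from_logits defaults to True and is passed through.
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=True, axis=-1):
    return tf.keras.backend.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits, axis=axis)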

# Subclass the HuggingFace Transformers 2.0 base class so that the BERT outputs are available and the model is easy to extend
class TFMyBertModel(TFBertPreTrainedModel):
    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
        # One classifier layer per token position, 128 in total
        self.classifiers = []
        for i in range(max_length):
            self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))

    def call(self, inputs, **kwargs):
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        print(sequence_output.shape)

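        # sequence_output has shape [None, 128, 768]; split it into a list of 128 pieces, because every character has to be classified into one of the NER labels in label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
        words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
        # The length of words_output equals the sentence length
        logits = []
        for i,o in enumerate(words_output):
            # Drop the extra dimension: after tf.split along the token axis each piece is still 3-D, [None, 1, 768]
            ot = tf.squeeze(o, axis=1)
            # Compute the classification logits for this position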
            logit = self.classifiers[i](ot)
            logits.append(logit)
        return logits

train_path = './data_ner/train.txt'
valid_path = './data_ner/dev.txt'

# Directory with the BERT files that HuggingFace Transformers 2.0 needs; the TF2 h5 weights are used here
pretrained_path = './my-bert-base-chinese'
config_path = os.path.join(pretrained_path, 'config.json')
vocab_path = os.path.join(pretrained_path, 'vocab.txt')

tokenizer = BertTokenizer.from_pretrained(vocab_path)
# Load the config
config = BertConfig.from_json_file(config_path)
# Load the fine-tuned model saved by the training script
model = TFMyBertModel.from_pretrained('./mysavener/',from_pt=False) # from_pt says whether the weights come from PyTorch; these are TF weights, so it is False

# Sentence to run prediction on
text = '中国的华先生,我和他谈笑风生。'

# Build the BERT inputs: input_ids, attention_mask and token_type_ids
inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length, )
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
padding_length = max_length - len(input_ids)
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

input_ids = tf.convert_to_tensor(input_ids)[None, :]
attention_mask = tf.convert_to_tensor(attention_mask)[None, :]

pred = model(inputs={"input_ids":input_ids, 'attention_mask':attention_mask})
print(pred[0].numpy().argmax().item())
print(pred[1].numpy().argmax().item())
print(pred[2].numpy().argmax().item())
print(pred[3].numpy().argmax().item())
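
The printed numbers are indices into label_category. A minimal sketch of mapping them back to label names for every character is shown here, using hand-made scores in place of real model output. The helper names are hypothetical, and it assumes one token per character, which may not hold for digits or Latin text.

```python
LABELS = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']

def decode(text, per_position_scores):
    """Pair each character with the highest-scoring label at its position."""
    result = []
    for i, ch in enumerate(text):
        scores = per_position_scores[i + 1]  # position 0 belongs to CLS
        best = max(range(len(scores)), key=scores.__getitem__)
        result.append((ch, LABELS[best]))
    return result

def one_hot(label):
    """Fake score vector that puts all weight on one label."""
    return [1.0 if name == label else 0.0 for name in LABELS]

fake_scores = [one_hot(l) for l in ['CLS', 'B-LOC', 'I-LOC', 'O', 'SEP']]
decoded = decode("中国的", fake_scores)
print(decoded)
```

With the real model, per_position_scores would be something like [p.numpy()[0] for p in pred], one score vector per position for the single sentence in the batch.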

The complete code is on GitHub

https://github.com/wengmingdong/tf2-stu/tree/master/bert4huggingface4tran

The required data is on Baidu Netdisk

Link: https://pan.baidu.com/s/1dvAMo59FffwC4nKDKG-5zQ
Extraction code: heeh
