tf2 HuggingFace Transformer2.0 b

2020-02-17  阿东7

There is little material online about using HuggingFace Transformers 2.0 with TF2, so I am writing these notes for myself.

The theory behind word embeddings is not covered here.

The theory behind BERT is not covered here either.

BERT inputs

  1. input_ids
  2. token_type_ids
  3. attention_mask
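To make the three inputs concrete, here is a minimal pure-Python sketch for a sentence pair. The helper `build_bert_inputs` and the token IDs passed to it are hypothetical. The special-token IDs 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) are an assumption based on the usual BERT `vocab.txt` layout, so in practice you should read them from your own tokenizer.

```python
# Special-token IDs: an assumption based on the usual BERT vocab.txt layout.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def build_bert_inputs(ids_a, ids_b=None, max_length=128):
    """Hypothetical helper: build input_ids / token_type_ids / attention_mask
    from already-tokenized IDs for one sentence or a sentence pair."""
    # [CLS] sentence A [SEP] -> segment 0
    input_ids = [CLS_ID] + list(ids_a) + [SEP_ID]
    token_type_ids = [0] * len(input_ids)
    if ids_b is not None:
        ids_b = list(ids_b)
        # sentence B [SEP] -> segment 1
        input_ids += ids_b + [SEP_ID]
        token_type_ids += [1] * (len(ids_b) + 1)
    if len(input_ids) > max_length:
        # A real tokenizer would truncate the sentences first; the sketch just refuses.
        raise ValueError("sequence longer than max_length")
    # 1 marks real tokens, 0 marks padding (mask_padding_with_zero=True)
    attention_mask = [1] * len(input_ids)
    padding_length = max_length - len(input_ids)
    input_ids += [PAD_ID] * padding_length
    attention_mask += [0] * padding_length
    token_type_ids += [0] * padding_length
    return {"input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask}


print(build_bert_inputs([7, 8], [9], max_length=8))
# {'input_ids': [101, 7, 8, 102, 9, 102, 0, 0],
#  'token_type_ids': [0, 0, 0, 0, 1, 1, 0, 0],
#  'attention_mask': [1, 1, 1, 1, 1, 1, 0, 0]}
```

With a single sentence, as in the sentiment task below, `token_type_ids` is all zeros. It only becomes 1 for the second segment of a pair.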

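BERT outputs

  1. A vector for each character (token) of the sentence (sequence_output, shape [batch, seq_len, hidden_size])
  2. A pooled vector for the whole sentence (pooled_output, shape [batch, hidden_size]), which the classifier in the code below uses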
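Here is a rough NumPy sketch of how the two outputs relate. BERT's pooler takes the [CLS] position of the per-token output and passes it through a dense layer with tanh. The random weights are placeholders for the trained pooler, and hidden_size 768 assumes a base-size model. Only the shapes and the formula are meant to be illustrative.

```python
import numpy as np

batch, seq_len, hidden = 2, 128, 768  # base-size BERT, max_length 128
rng = np.random.default_rng(0)

# Stand-in for sequence_output: one hidden vector per token
sequence_output = rng.standard_normal((batch, seq_len, hidden)).astype(np.float32)

# Placeholder pooler weights (random, NOT the trained ones)
W = (rng.standard_normal((hidden, hidden)) * 0.02).astype(np.float32)
b = np.zeros(hidden, dtype=np.float32)

cls_vectors = sequence_output[:, 0, :]        # (batch, hidden): the [CLS] position
pooled_output = np.tanh(cls_vectors @ W + b)  # (batch, hidden), values in [-1, 1]

print(sequence_output.shape, pooled_output.shape)
```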

Python environment

  1. tensorflow 2
  2. torch 1.4.1
  3. transformers 2.4.1
    (torch is not actually used by the program)
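Here is a small stdlib-only helper (hypothetical, Python 3.8+) that checks which of these packages are installed before you run the scripts. In the TF 2.0/2.1 era the GPU build was also shipped as a separate `tensorflow-gpu` package, so that name is included as an assumption.

```python
from importlib import metadata


def installed_versions(packages=("tensorflow", "tensorflow-gpu", "torch", "transformers")):
    """Hypothetical helper: map each package name to its installed version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```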



Data preparation

  1. Files needed by HuggingFace Transformers
    For the BERT-related files, see
    https://www.cnblogs.com/lian1995/p/11947522.html
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
    'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
    'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
    }

PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
}
}

However, I used the TF2 weight files here:
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/tf_model.h5",
}
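The training code points `from_pretrained` at a local directory (`./my-bert-base-chinese`). In transformers 2.x that directory needs `config.json`, `vocab.txt` and, for TF weights, `tf_model.h5`. The sketch below maps the URLs above onto those file names. It assumes the `<name>-config.json` / `<name>-vocab.txt` / `<name>-tf_model.h5` pattern, which holds for `bert-base-chinese` but not for every entry: the German vocab and the Japanese models, for example, live at other paths. These legacy S3 links may also no longer be served. `plan_downloads` and `download` are hypothetical helpers.

```python
import os
import urllib.request

# Assumption: the naming pattern seen in the maps above (legacy, possibly offline).
BASE_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/"

# Local file name expected by from_pretrained(<dir>) -> remote file-name pattern
FILES = {
    "config.json": "{name}-config.json",
    "vocab.txt": "{name}-vocab.txt",
    "tf_model.h5": "{name}-tf_model.h5",
}


def plan_downloads(model_name, dest_dir):
    """Return (url, local_path) pairs without touching the network."""
    return [(BASE_URL + remote.format(name=model_name), os.path.join(dest_dir, local))
            for local, remote in FILES.items()]


def download(model_name, dest_dir):
    """Fetch any missing files into dest_dir (requires network access)."""
    os.makedirs(dest_dir, exist_ok=True)
    for url, path in plan_downloads(model_name, dest_dir):
        if not os.path.exists(path):
            urllib.request.urlretrieve(url, path)


# download("bert-base-chinese", "./my-bert-base-chinese")
```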


Training and test data for sentiment analysis:

https://codeload.github.com/pengming617/bert_classification/zip/master
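The training code reads `train.tsv` and `dev.tsv` with `tf.data.experimental.make_csv_dataset`, selecting the `label` and `text_a` columns. If you want a quick look at the data without TensorFlow, here is a stdlib sketch. It assumes the same layout (UTF-8, tab-separated, header row, no quoting). `read_tsv` is a hypothetical helper, not part of the original code.

```python
import csv


def read_tsv(path):
    """Hypothetical helper: return (texts, labels) from a label/text_a TSV file."""
    texts, labels = [], []
    with open(path, encoding="utf-8", newline="") as f:
        # QUOTE_NONE mirrors use_quote_delim=False: quotes are kept as ordinary text
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            labels.append(int(row["label"]))
            texts.append(row["text_a"])
    return texts, labels


# texts, labels = read_tsv("./data/train.tsv")
```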


Training code

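import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # uncomment to force CPU
import tensorflow as tf
from transformers import BertConfig, BertTokenizer, TFBertMainLayer, TFBertPreTrainedModel

# This loss function is copied from tf.keras so it is easy to step into while debugging.
# The classifier outputs raw logits (no softmax), so from_logits defaults to True.
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=True, axis=-1):
    return tf.keras.backend.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=from_logits, axis=axis)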

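# Subclass the HuggingFace Transformers 2.0 base class so BERT's outputs are
# available directly and the model is easy to extend
class TFMyBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
        # Data format (label, text_a), e.g.:
        # 1 這間酒店環境和服務態度亦算不錯

        # Custom classification head. There are only two classes (0 and 1),
        # hence tf.keras.layers.Dense(2, name="classifier").
        # sparse_categorical_crossentropy takes integer labels directly
        # (0 or 1 here), so no one-hot encoding is needed.
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    def call(self, inputs, **kwargs):
        # With the default config, TFBertMainLayer returns (sequence_output, pooled_output)
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        print(sequence_output.shape)  # (batch, seq_len, hidden_size)
        print(pooled_output.shape)    # (batch, hidden_size)
        # For learning purposes only: no dropout is applied before the classifier,
        # although adding one would help prevent overfitting
        logits = self.classifier(pooled_output)
        return logits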

# Maximum sentence length in tokens (including [CLS] and [SEP])
max_length = 128
mask_padding_with_zero = True
pad_token = 0
pad_token_segment_id = 0

train_path = './data/train.tsv'
valid_path = './data/dev.tsv'

# Local directory holding config.json, vocab.txt and tf_model.h5 for bert-base-chinese
pretrained_path = './my-bert-base-chinese'
config_path = os.path.join(pretrained_path, 'config.json')
vocab_path = os.path.join(pretrained_path, 'vocab.txt')

tokenizer = BertTokenizer.from_pretrained(vocab_path)
# Load the config
config = BertConfig.from_json_file(config_path)
# Load the original TF weights. from_pt says whether the checkpoint comes from
# PyTorch; these are TF weights, so it is False.
model = TFMyBertModel.from_pretrained(pretrained_path, from_pt=False, config=config)

def build_dataset(pathfile):
    input_ids_dataset = []
    attention_mask_dataset = []
    token_type_ids_datasest = []
    label_dataset = []
    # File format (tab-separated, with a header row):
    # label text_a
    # 1 這間酒店環境和服務態度亦算不錯
    # The 1 is the class label; the sentence after it is the text.
    dataset = tf.data.experimental.make_csv_dataset(
        pathfile,
        100,                      # batch size used only while reading the file
        label_name='label',
        select_columns=['label', 'text_a'],
        field_delim='\t',
        header=True,
        use_quote_delim=False,    # treat quotes inside the text as ordinary characters
        shuffle=False,
        num_epochs=1)

    def build_src_dataset(examples, labels):
        # Decode one batch of UTF-8 byte strings into (text, label) pairs
        btextarr = examples['text_a'].numpy()
        text_label_arr = [(b.decode(encoding='UTF-8'), l) for b, l in zip(btextarr, labels.numpy())]
        return text_label_arr

    for examples, labels in dataset:
        text_label_arr = build_src_dataset(examples, labels)

        for text, label in text_label_arr:
            # Build BERT-style inputs: input_ids, attention_mask, token_type_ids.
            # In transformers 2.x, encode_plus adds [CLS]/[SEP] and truncates to max_length.
            inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length)
            input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
            # Pad every sequence on the right up to max_length
            padding_length = max_length - len(input_ids)
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

            # Convert to tensors
            input_ids = tf.constant(input_ids)
            attention_mask = tf.constant(attention_mask)
            token_type_ids = tf.constant(token_type_ids)

            input_ids_dataset.append(input_ids)
            attention_mask_dataset.append(attention_mask)
            token_type_ids_datasest.append(token_type_ids)
            label_dataset.append(label)

    # Stack into tensors of shape (num_examples, max_length); labels shape (num_examples,)
    return (
        tf.convert_to_tensor(input_ids_dataset),
        tf.convert_to_tensor(attention_mask_dataset),
        tf.convert_to_tensor(token_type_ids_datasest),
        tf.cast(label_dataset, dtype=tf.float32))

# Build the training and validation (dev) data
input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest, label_t_dataset = build_dataset(train_path)
input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest, label_v_dataset = build_dataset(valid_path)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = sparse_categorical_crossentropy
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()

# No batch_size is passed: with in-memory tensors, Keras infers the batch size as
# ceil(num_samples / steps_per_epoch) (and from validation_steps for the dev data)
model.fit({'input_ids': input_ids_t_dataset, 'attention_mask': attention_mask_t_dataset, 'token_type_ids': token_type_ids_t_datasest},
          label_t_dataset,
          epochs=2,
          steps_per_epoch=1080,
          validation_steps=7,
          validation_data=([input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest], label_v_dataset))

# Save the model (config.json + tf_model.h5) for use at prediction time
model.save_pretrained("./mysave")
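Here is why the loss uses `from_logits=True`. The `Dense(2)` classifier has no activation, so it outputs raw logits, and each label is a plain integer class index. Under those assumptions, the following pure-Python sketch shows what sparse categorical cross-entropy computes for one example. The helper name is hypothetical.

```python
import math


def sparse_ce_from_logits(label, logits):
    """Hypothetical helper: cross-entropy of one example given raw logits and an
    integer class index (what from_logits=True asks Keras to compute)."""
    # log(sum(exp(z))) computed stably by subtracting the max logit
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log(softmax(logits)[label])
    return log_sum_exp - logits[label]


print(sparse_ce_from_logits(1, [0.0, 0.0]))  # log(2), about 0.693
```

If `from_logits` were False, Keras would treat the logits as probabilities. That gives a wrong loss for this model without raising any error.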

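Points to note about the code (few resources online cover them):

  1. The weights used here are in TF2 Keras h5 format, so the parameter is from_pt=False:
model = TFMyBertModel.from_pretrained(pretrained_path, from_pt=False, config=config)  # from_pt: whether the checkpoint comes from PyTorch; these are TF weights, so False
  2. The training inputs can be passed either as a dict {'input_ids': input_ids_t_dataset, 'attention_mask': attention_mask_t_dataset, 'token_type_ids': token_type_ids_t_datasest} or as a list [input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest], and the list must keep this order. The model.fit call in the training code uses the dict form for the training data and the list form for validation_data. See the HuggingFace Transformers source (TFBertMainLayer.call) for details.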


  3. Custom model: in self.classifier = tf.keras.layers.Dense(2, name="classifier"), the 2 is the number of classes.
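On passing the inputs as a dict or a list: as I read the transformers 2.x source, `TFBertMainLayer.call` unpacks a list or tuple by position in the order below, looks up a dict by key, and treats a bare tensor as `input_ids`. The following is a simplified pure-Python imitation, not the library code. `normalize_inputs` is a hypothetical name.

```python
# Positional order used when inputs is a list/tuple (my reading of transformers 2.x).
POSITIONAL_ORDER = ("input_ids", "attention_mask", "token_type_ids",
                    "position_ids", "head_mask", "inputs_embeds")


def normalize_inputs(inputs):
    """Hypothetical imitation of how TFBertMainLayer.call interprets `inputs`."""
    if isinstance(inputs, (list, tuple)):
        if len(inputs) > len(POSITIONAL_ORDER):
            raise ValueError("Too many inputs.")
        # Matched purely by position: order matters
        return dict(zip(POSITIONAL_ORDER, inputs))
    if isinstance(inputs, dict):
        # Matched by key: order does not matter; missing keys stay unset
        return {k: inputs[k] for k in POSITIONAL_ORDER if k in inputs}
    # Anything else is treated as input_ids alone
    return {"input_ids": inputs}
```

In practice, if you use the list form and swap `attention_mask` and `token_type_ids`, no error is raised but the model receives the wrong tensors. The dict form avoids this.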

Prediction code

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # uncomment to force CPU
import tensorflow as tf
from transformers import BertTokenizer, TFBertMainLayer, TFBertPreTrainedModel

# The same model class as in training; it must be defined to load the saved weights
class TFMyBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")
        # Data format (label, text_a), e.g.:
        # 1 這間酒店環境和服務態度亦算不錯

        # Custom classification head. There are only two classes (0 and 1),
        # hence tf.keras.layers.Dense(2, name="classifier").
        # sparse_categorical_crossentropy takes integer labels directly
        # (0 or 1 here), so no one-hot encoding is needed.
        self.classifier = tf.keras.layers.Dense(2, name="classifier")

    def call(self, inputs, **kwargs):
        # With the default config, TFBertMainLayer returns (sequence_output, pooled_output)
        sequence_output, pooled_output = self.bert(inputs, **kwargs)
        print(sequence_output.shape)  # (batch, seq_len, hidden_size)
        print(pooled_output.shape)    # (batch, hidden_size)
        # For learning purposes only: no dropout is applied before the classifier,
        # although adding one would help prevent overfitting
        logits = self.classifier(pooled_output)
        return logits

max_length = 128
mask_padding_with_zero = True
pad_token = 0
pad_token_segment_id = 0

# The tokenizer comes from the original pretrained directory; the fine-tuned
# config and weights come from ./mysave (written by save_pretrained during training)
pretrained_path = './my-bert-base-chinese'
vocab_path = os.path.join(pretrained_path, 'vocab.txt')
tokenizer = BertTokenizer.from_pretrained(vocab_path)
model = TFMyBertModel.from_pretrained('./mysave/', from_pt=False)

# Sentence to classify (a Chinese book review)
text = '我个人觉得这本书是我读过最好的一本书,书中采用为什么要这样,动机是什么,他的出现有什么背景知识,以及一个概念属于的多重含义,在不同场景的不同运用,这样的方式来进行的,我个人认为翻译的还不错,这本书的作者,真的是不愧为大师啊,每每看下几页纸,就感触颇多,那是很有体会啊,一本书能够勾起你很多的想法时,我想你读书的目的已经达到,如果你还在犹豫,那么请你果决点,绝对没有错,强烈推荐!'

# Build BERT inputs (input_ids, attention_mask, token_type_ids) exactly as in training
inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
padding_length = max_length - len(input_ids)
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

# Add a batch dimension: shape (1, max_length)
input_ids = tf.convert_to_tensor(input_ids)[None, :]
attention_mask = tf.convert_to_tensor(attention_mask)[None, :]
token_type_ids = tf.convert_to_tensor(token_type_ids)[None, :]

pred_1 = model({"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids})
# pred_1 holds the logits, shape (1, 2); argmax gives the predicted class (0 or 1)
print(pred_1.numpy()[0].argmax().item())
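The model returns logits, so `argmax` gives only the class. To get a confidence score as well, apply softmax; in the script itself, `tf.nn.softmax(pred_1, axis=-1)` does this. Below is a pure-Python sketch of the same computation (hypothetical helper). Judging by the sample row `1 這間酒店環境和服務態度亦算不錯` in the data comments, label 1 appears to be the positive class, but check that against your own data.

```python
import math


def predict_label(logits):
    """Hypothetical helper: softmax over one row of logits, then argmax."""
    m = max(logits)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    label = max(range(len(probs)), key=probs.__getitem__)  # first index on ties
    return label, probs


# e.g. with the logits row from the script: predict_label(pred_1.numpy()[0].tolist())
print(predict_label([-1.0, 2.0]))
```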

The full code is on GitHub:

https://github.com/wengmingdong/tf2-stu/tree/master/bert4huggingface4tran

Baidu Netdisk link for all the required data:

Link: https://pan.baidu.com/s/1dvAMo59FffwC4nKDKG-5zQ
Extraction code: heeh
