Using BERT for Multi-Label Text Classification
2023-03-31 · 万州客
This code throws an OOM error on my low-spec machine, but getting the earlier parts to run still cost me quite a bit of time.
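A common workaround (I have not verified it on this exact setup) is to shrink the batch size, or to enable mixed precision, which roughly halves activation memory on GPUs that support float16:

import tensorflow as tf

# Assumption: TF 2.4+ and a float16-capable GPU; must run before building the model.
tf.keras.mixed_precision.set_global_policy("mixed_float16")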
"""
{
"text": "世界百科大全总编彭友定义本词条为 人物总类 董事长分类概述 1朱明宏的基本情况男 汉族 1968年6月生 浙江义乌人11现任 金华市发展和改革委员会副主任1拟任 金华市现代服务业投资发展有限公司董事长",
"new_spo_list": [
{"s": {"entity": "朱明宏", "type": "people"},
"p": {"entity": "民族", "type": "_rel"},
"o": {"entity": "汉族", "type": "property"}},
{"s": {"entity": "朱明宏", "type": "people"},
"p": {"entity": "出生地", "type": "_rel"},
"o": {"entity": "浙江义乌", "type": "property"}},
{"s": {"entity": "朱明宏", "type": "people"},
"p": {"entity": "出生日期", "type": "_rel"},
"o": {"entity": "1968年6月", "type": "property"}}
]
}
# Using a custom feature extractor instead of BERT
import tensorflow as tf
import idcnn  # the author's own feature-extraction model (not shown)

input_token = tf.keras.Input(shape=(300,), dtype=tf.int32)
# embedding table sized to the vocabulary (21128 = BERT-base-chinese vocab size)
embedding = tf.keras.layers.Embedding(input_dim=21128, output_dim=256)(input_token)
embedding = idcnn.IDCNN()(embedding)
embedding = tf.keras.layers.BatchNormalization()(embedding)
embedding = tf.keras.layers.Flatten()(embedding)
embedding = tf.keras.layers.Dropout(0.217)(embedding)
output = tf.keras.layers.Dense(32)(embedding)  # 32 logits, one per label
model = tf.keras.Model(input_token, output)
model.compile(optimizer=tf.keras.optimizers.Adam(2.17e-5),
              loss=tf.nn.sigmoid_cross_entropy_with_logits,
              metrics=["accuracy"])

import get_data

batch_size = 256
for i in range(3):
    model.fit(get_data.generator(batch_size),
              steps_per_epoch=get_data.train_length // batch_size,
              epochs=2,
              validation_data=(get_data.val_token_list, get_data.val_p_entity_label_list))
    model.save_weights("./saver/model.h5")
"""
# Multi-label text classification with a pretrained BERT
import tensorflow as tf
from transformers import AutoTokenizer, TFBertModel

bert_model = "../bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(bert_model)
bert = TFBertModel.from_pretrained(bert_model)  # was `model`; renamed so the Keras model below does not shadow it

input_token = tf.keras.Input(shape=(300,), dtype=tf.int32)
embedding = bert(input_token)[0]  # last_hidden_state, shape (batch, 300, 768)
embedding = tf.keras.layers.Flatten()(embedding)
output = tf.keras.layers.Dense(32)(embedding)  # 32 logits, one per label
model = tf.keras.Model(input_token, output)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss=tf.nn.sigmoid_cross_entropy_with_logits,  # per-label sigmoid cross-entropy for multi-label targets
    metrics=["accuracy"])
import get_data

batch_size = 10
# save_freq=1 writes a checkpoint after every single batch
saver = tf.keras.callbacks.ModelCheckpoint(filepath="../saver/model.h5",
                                           save_freq=1, save_weights_only=True)
model.fit(
    get_data.generator(batch_size),
    steps_per_epoch=get_data.train_length // batch_size,
    epochs=1024,
    validation_data=(get_data.val_token_list, get_data.val_p_entity_label_list),
    callbacks=[saver])
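get_data is also a local module that is not shown. Judging from the record format at the top, it presumably tokenizes each text to 300 ids and turns the predicates in new_spo_list into a 32-dimensional multi-hot vector (one slot per relation type), which is exactly what the Dense(32) plus sigmoid cross-entropy setup consumes. A hypothetical sketch of that preprocessing; the signature differs slightly from get_data.generator, and every name here is assumed:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("../bert-base-chinese")

# Hypothetical: the 32 relation types in a fixed order, so each predicate
# maps to one slot of the Dense(32) output. Only the three predicates
# from the sample record are listed here.
predicate_list = ["民族", "出生地", "出生日期"]  # in practice, all 32 types

def encode_sample(sample, max_len=300):
    """Turn one JSON record into (token ids, multi-hot label vector)."""
    token_ids = tokenizer.encode(sample["text"], max_length=max_len,
                                 padding="max_length", truncation=True)
    label = np.zeros(len(predicate_list), dtype=np.float32)
    for spo in sample["new_spo_list"]:
        p = spo["p"]["entity"]
        if p in predicate_list:
            label[predicate_list.index(p)] = 1.0  # a text can trigger several labels
    return np.array(token_ids, dtype=np.int32), label

def generator(samples, batch_size):
    """Endless batch generator, as model.fit with steps_per_epoch expects."""
    while True:
        np.random.shuffle(samples)
        for i in range(0, len(samples) - batch_size + 1, batch_size):
            batch = [encode_sample(s) for s in samples[i:i + batch_size]]
            tokens, labels = zip(*batch)
            yield np.stack(tokens), np.stack(labels)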
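Because the output layer has no activation, the saved weights produce raw logits; at inference time you apply the sigmoid yourself and threshold each of the 32 labels independently. A minimal sketch, reusing model and tokenizer from above:

import numpy as np
import tensorflow as tf

model.load_weights("../saver/model.h5")
token_batch = np.array([tokenizer.encode("朱明宏,男,汉族,1968年6月生,浙江义乌人",
                                         max_length=300, padding="max_length",
                                         truncation=True)], dtype=np.int32)
logits = model.predict(token_batch)           # raw scores, shape (1, 32)
probs = tf.sigmoid(logits).numpy()            # independent per-label probabilities
predicted = (probs > 0.5).astype(np.int32)    # one text may carry several labels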