tf2+cnn+中文文本分类优化系列(1)
你在看我的日志
我记起你的名字
这是安静的交易
收获彼此的说词
1 前言
接着上篇关于英文的text-cnn,今天分享一篇基础级别的中文文本分类实践练习。数据集是复旦大学开源的文本数据集,label种类为20,该数据集有点久远,感兴趣可网上搜到。这次文本分类,主要基于字级别+cnn来实现的。相对于词级别,字级别的优势就是处理简单些,不用去考虑分词错误带来的误差;缺陷就是,字所带的语义含义没词丰富,此外同样长度限制下,词级别处理的文本长度要远远大于字级别。但操作方法的角度来看,二者本质是一致的。接下来详细介绍如何实现字级别的文本分类。
2 数据处理
从网上download下的资源,分train和test文件夹,里面分别包括20个label文件夹,每个label文件夹按.txt格式存储每条文本数据。为了方便训练,需要将所有文本数据整理到同个文件下,形成 train.txt和test.txt,具体格式如下图所示:
样本数据其中train.txt 数据集有9804条数据,test.txt数据集有9832。这个比例接近1:1,不是特别合理;
3 模型构建
#引入需要用的包;
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import collections
import matplotlib.pyplot as plt
import codecs
import re
import os
#定义一些超参数类cfg;
class TextConfig():
embedding_size=100 #词的维度
vocab_size=6000 #词表的大小
seq_length=300 #文本长度
num_classes=20 #类别数量
num_filters=128 #卷积核数量
filter_size=2 #卷积核大小
keep_prob=0.5 #dropout
lr= 1e-3 #学习率
num_epochs=10 #epochs
batch_size=64 #batch_size
train_dir=r'E:\data\train.txt' #train data
test_dir=r'E:\data\test.txt' #test data
vocab_dir=r'E:\data\vocab.txt' #vocabulary
#定义read_file,处理数据文件,将一条文本按字输出,但过滤掉类似标点符号,数字类型的字符;
def read_file(file_dir):
re_han = re.compile(u"([\u4E00-\u9FD5a-zA-Z]+)") #去掉标点符号和数字类型的字符
with codecs.open(file_dir,'r',encoding='utf-8') as f:
for line in f:
label,text=line.split('\t')
content=[]
for w in text:
if re_han.match(w):
content.append(w)
yield content,label
#利用train.txt,和test.txt文件生成词表
def build_vocab(file_dirs,vocab_dir,vocab_size=6000):
all_data = []
for filename in file_dirs:
for content,_ in read_file(filename):
all_data.extend(content)
counter=collections.Counter(all_data)
count_pairs=counter.most_common(vocab_size-1)
words,_=list(zip(*count_pairs))
words=['<PAD>']+list(words)
with codecs.open(vocab_dir,'w',encoding='utf-8') as f:
f.write('\n'.join(words)+'\n')
#利用词表,将每条数据转成预定长度的token形式,其中categories是具体label的种类;
def convert_examples_to_tokens(input_dir,vocab_dir,seq_length):
words=codecs.open(vocab_dir,'r',encoding='utf-8').read().strip().split('\n')
word_to_id=dict(zip(words,range(len(words))))
categories = ['Art', 'Literature', 'Education', 'Philosophy', 'History', 'Space', 'Energy', 'Electronics',
'Communication', 'Computer','Mine','Transport','Enviornment','Agriculture','Economy',
'Law','Medical','Military','Politics','Sports']
cat_to_id=dict(zip(categories,range(len(categories))))
input_ids,label_ids=[],[]
for content,label in read_file(input_dir):
input_ids.append([word_to_id[x] if x in word_to_id else 0 for x in content ])
label_ids.append(cat_to_id[label])
input_ids =tf.keras.preprocessing.sequence.pad_sequences(input_ids, value=0,padding='post', maxlen=seq_length)
label_ids=np.array(label_ids)
return (input_ids,label_ids)
#构建model,具体为大小2的卷积核(类似2-gram特征),接着最大池化,dropout,最后加个softmax分类层;
def cnn_model(cfg):
model = tf.keras.Sequential([
layers.Embedding(input_dim=cfg.vocab_size, output_dim=cfg.embedding_size,input_length=cfg.seq_length),
layers.Conv1D(filters=cfg.num_filters, kernel_size=cfg.filter_size, strides=1, padding='valid'),
layers.GlobalMaxPooling1D(),
layers.Flatten(),
layers.Dropout(rate=cfg.keep_prob,name='dropout'),
layers.Dense(cfg.num_classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.Adam(cfg.lr),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
print(model.summary())
return model
#模型训练,文本在文件存储是按类别集中存储的,为了更好学习特征,需要将数据集打乱。
if __name__=="__main__":
cfg = TextConfig()
file_dirs = [cfg.train_dir, cfg.test_dir]
if not os.path.exists(cfg.vocab_dir):
build_vocab(file_dirs, cfg.vocab_dir, cfg.vocab_size)
train_x,train_y=convert_examples_to_tokens(cfg.train_dir,cfg.vocab_dir,cfg.seq_length)
test_x,test_y=convert_examples_to_tokens(cfg.test_dir,cfg.vocab_dir,cfg.seq_length)
indices=np.random.permutation(np.arange(len(train_x)))
train_x=train_x[indices]
train_y=train_y[indices]
model = cnn_model(cfg)
history=model.fit(train_x,train_y,epochs=5,batch_size=64,verbose=1,validation_split=0.1)
model.evaluate(test_x,test_y)
4 训练结果
训练过程中,在验证集上最好的结果0.79,测试集上为0.78。因为设置的epochs为5,有点过小,这个结果不是训练的最佳结果。若有兴趣,可以自己去调整一些超参数,像句子长度(seq_length),学习率(lr)等。
训练结果不过这个数据集各个label的样本很不平衡,导致最终的结果不是特别好,可以从下面的各个label的统计结果看出。
每个label的统计5 结 语
上图看出,总体识别准确率不高,有些label的识别准确率为0,虽然一方面跟样本不平衡相关,但还是有一定优化空间的,接下来会在此基础利用各种trick或者框架进行优化提升。