基于LSTM的文本分类

2020-06-20  本文已影响0人  还闹不闹
#!usr/bin/python
# coding=utf-8
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from keras.models import Model
from keras.layers import LSTM, Activation, Dense, Dropout, Input, Embedding
from keras.optimizers import RMSprop
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence
from keras.callbacks import EarlyStopping
# ## 设置字体
# from matplotlib.font_manager import FontProperties
# # fonts = FontProperties(fname = "/Library/Fonts/华文细黑.ttf",size=14)
# # %config InlineBackend.figure_format = 'retina'
# %matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# 画图支持中文显示
from pylab import *
matplotlib.rcParams['font.sans-serif'] = ['SimHei']

# 显示所有列
pd.set_option('display.max_columns', None)
# 显示所有行
pd.set_option('display.max_rows', None)
# 设置value的显示长度为10000,默认为50
pd.set_option('display.width',10000)
pd.set_option('display.unicode.ambiguous_as_wide', True)
pd.set_option('display.unicode.east_asian_width', True)
#
np.set_printoptions(linewidth=1000)

## 读取测数据集
train_df = pd.read_csv("G:\\rnn\lstm\cnews-LSTM\cnews_train.csv")
val_df = pd.read_csv("G:\\rnn\lstm\cnews-LSTM\cnews_val.csv")
test_df = pd.read_csv("G:\\rnn\lstm\cnews-LSTM\cnews_test.csv")
# print(train_df.head())
# print(val_df.head())
# print(test_df.head())
print(train_df.iloc[1:2,[2]])
print(type(train_df.iloc[1:2,[2]]))
# print(len((train_df.iloc[1:2,[2]]).values), (train_df.iloc[1:2,[2]]).values)
print((train_df.iloc[1:2,[2]]).values.tolist())
print(len(set(((train_df.iloc[1:2,[2]]).values.tolist())[0][0])))
print(train_df.iloc[1:2,[3]]) # 第4列为对应词组的个数
# -------------------------------------------------------------
a = list(filter(None, (((train_df.iloc[0:1,[2]]).values.tolist())[0][0]).split(" ")))
b = list(filter(None, (((train_df.iloc[1:2,[2]]).values.tolist())[0][0]).split(" ")))
print(len(a), len(b), a, b)
a.extend(b)
print(a)
print(len(a), len(set(a)))
# 该数据集已经进行了处理,每个数据集包含4列数据,其中第一列为标签数据,第二列为新闻的原文数据,第三列为经过分词、去停用词等操作,并使用空格连接的分词后数据,第4列为对应词组的个数。

# 数据探索:查看训练集都有哪些标签
plt.figure()
sns.countplot(train_df.label)
# plt.xlabel('Label',fontproperties = fonts,size = 10)
# plt.xticks(fontproperties = fonts,size = 10)
plt.xlabel('Label',size = 10)
plt.xticks(size = 10)
plt.show()
# 分析训练集中词组数量的分布
print(train_df.cutwordnum.describe())
plt.figure()
plt.hist(train_df.cutwordnum,bins=100)
# plt.xlabel("词组长度",fontproperties = fonts,size = 12)
# plt.ylabel("频数",fontproperties = fonts,size = 12)
# plt.title("训练数据集",fontproperties = fonts)
plt.xlabel("词组长度",size = 12)
plt.ylabel("频数",size = 12)
plt.title("训练数据集")
plt.show()

# 接下来对数据集的标签数据进行编码,首先是LabelEncoder()编码,然后是进行OneHotEncoder()编码。-------------------------------编码标签
# 对数据集的标签数据进行编码
train_y = train_df.label
val_y = val_df.label
test_y = test_df.label
le = LabelEncoder()
train_y = le.fit_transform(train_y).reshape(-1,1)
val_y = le.transform(val_y).reshape(-1,1)
test_y = le.transform(test_y).reshape(-1,1)
# 对数据集的标签数据进行one-hot编码
ohe = OneHotEncoder()
train_y = ohe.fit_transform(train_y).toarray()
val_y = ohe.transform(val_y).toarray()
test_y = ohe.transform(test_y).toarray()

# 使用Tokenizer对词组进行编码--------------------------------------------------------------------------------------------编码文本
# 当我们创建了一个Tokenizer对象后,使用该对象的fit_on_texts()函数,以空格去识别每个词,
# 可以将输入的文本中的每个词编号,编号是根据词频的,词频越大,编号越小。
max_words = 5000
max_len = 600
tok = Tokenizer(num_words=max_words)  ## 使用的最大词语数为5000
tok.fit_on_texts(train_df.cutword)
# 使用word_index属性可以看到每次词对应的编码
# 使用word_counts属性可以看到每个词对应的频数
for ii,iterm in enumerate(tok.word_index.items()):
    if ii < 10:
        print(iterm)
    else:
        break
print("===================")
for ii,iterm in enumerate(tok.word_counts.items()):
    if ii < 10:
        print(iterm)
    else:
        break

# 使用tok.texts_to_sequences()将数据转化为序列,并使用sequence.pad_sequences()将每个序列调整为相同的长度
# 对每个词编码之后,每句新闻中的每个词就可以用对应的编码表示,即每条新闻可以转变成一个向量了:
train_seq = tok.texts_to_sequences(train_df.cutword)
val_seq = tok.texts_to_sequences(val_df.cutword)
test_seq = tok.texts_to_sequences(test_df.cutword)
## 将每个序列调整为相同的长度
train_seq_mat = sequence.pad_sequences(train_seq,maxlen=max_len)
val_seq_mat = sequence.pad_sequences(val_seq,maxlen=max_len)
test_seq_mat = sequence.pad_sequences(test_seq,maxlen=max_len)

print(train_seq_mat.shape)
print(val_seq_mat.shape)
print(test_seq_mat.shape)

# 定义LSTM模型
inputs = Input(name='inputs',shape=[max_len])
## Embedding(词汇表大小,batch大小,每个新闻的词长)
layer = Embedding(max_words+1,128,input_length=max_len)(inputs)
# layer = LSTM(128)(layer)
layer = LSTM(8)(layer)
layer = Dense(128,activation="relu",name="FC1")(layer)
layer = Dropout(0.5)(layer)
layer = Dense(10,activation="softmax",name="FC2")(layer)
model = Model(inputs=inputs,outputs=layer)
model.summary()
model.compile(loss="categorical_crossentropy",optimizer=RMSprop(),metrics=["accuracy"])

# 模型训练
# model_fit = model.fit(train_seq_mat,train_y,batch_size=128,epochs=10,
#                       validation_data=(val_seq_mat,val_y),
#                       callbacks=[EarlyStopping(monitor='val_loss',min_delta=0.0001)] ## 当val-loss不再提升时停止训练
#                      )
model_fit = model.fit(train_seq_mat,train_y,batch_size=128,epochs=1,
                      validation_data=(val_seq_mat,val_y),
                      callbacks=[EarlyStopping(monitor='val_loss',min_delta=0.0001)] ## 当val-loss不再提升时停止训练
                     )

# -----------------------------------------------------------------------------------------------------------------
# 对测试集进行预测
test_pre = model.predict(test_seq_mat)

## 评价预测效果,计算混淆矩阵
confm = metrics.confusion_matrix(np.argmax(test_pre,axis=1),np.argmax(test_y,axis=1))
## 混淆矩阵可视化
Labname = ["体育","娱乐","家居","房产","教育","时尚","时政","游戏","科技","财经"]
plt.figure(figsize=(8,8))
sns.heatmap(confm.T, square=True, annot=True,
            fmt='d', cbar=False,linewidths=.8,
            cmap="YlGnBu")
plt.xlabel('True label',size = 14)
plt.ylabel('Predicted label',size = 14)
plt.xticks(np.arange(10)+0.5,Labname,fontproperties = fonts,size = 12)
plt.yticks(np.arange(10)+0.3,Labname,fontproperties = fonts,size = 12)
plt.show()

print(metrics.classification_report(np.argmax(test_pre,axis=1),np.argmax(test_y,axis=1)))

参考:
https://www.cnblogs.com/BobHuang/p/11157489.html
https://kexue.fm/archives/3414
https://zhuanlan.zhihu.com/p/39884984
https://www.jianshu.com/p/caee648f6a1f
https://zhuanlan.zhihu.com/p/50657430

上一篇下一篇

猜你喜欢

热点阅读