对京东评论的文本分类学习笔记(二)——LSTM
2018-08-10 本文已影响118人
DataArk
一、整体思路
整体思路很简单,就是读入数据集建立好词典后使用LSTM和全连接得到最后的分类结果。
具体流程:
- 读入数据,并对数据进行清洗(数据集按好评、中评和差评分成三类,已经用 jieba分好词)
- 建立词典
- 建立分类模型(主要就是LSTM和全连接层)
- 训练得到结果
从上面可以看出,基本与前面的方法类似,只是把卷积改成LSTM而已。
下面的代码主要来自ChenZhongFu 大佬
二、模型构建
循环神经网络是一种具有记忆功能的神经网络,每次计算时,利用了上一个时刻的记忆值,特别适合序列数据分析。网络接受的是一个序列数据,即一组向量,依次把它们输入网络,计算每个时刻的输出值。
LSTM是一种循环神经网络,其与普通的RNN不同的地方主要在于引入了门的机制就,具体就不展开了。
1. LSTM
data:image/s3,"s3://crabby-images/d0a8d/d0a8d94e7cba9ddb49972dffb82016c06766001f" alt=""
data:image/s3,"s3://crabby-images/c4996/c4996bd35a523f6c2649d3c9c5bb5b0c763ea6f2" alt=""
2. 模型定义
#定义模型
class LSTM_model(nn.Module):
def __init__(self,len_dic,emb_dim):
super(LSTM_model,self).__init__()
self.embed=nn.Embedding(len_dic,emb_dim) #b,64,128 -> 64,b,128
self.lstm1=nn.LSTM(input_size=emb_dim,hidden_size=256,dropout=0.2)#64,b,256
self.lstm2=nn.LSTM(input_size=256,hidden_size=256,dropout=0.2)#64,b,256 -> b,256
self.classify=nn.Linear(256,3)#b,3
def forward(self, x):
x=self.embed(x)
# print(x.size())
x=x.permute(1,0,2)
out,_=self.lstm1(x)
out,_=self.lstm2(out)
out=out[-1,:,:]
# print(out.size())
out=out.view(-1,256)
out=self.classify(out)
# print(out.size())
return out