论文阅读“k-Nearest Neighbor Augmente

2022-10-06  本文已影响0人  掉了西红柿皮_Kee

Wang Z, Hamza W, Song L. k-Nearest Neighbor Augmented Neural Networks for Text Classification[J]. arXiv preprint arXiv:1708.07863, 2017.

摘要导读

近年来,许多基于深度学习的模型被用于文本分类。然而,在训练的过程中缺乏对训练集中实例级信息的利用。在本文中,作者建议通过利用输入文本的k-nearest neighbor(kNN)信息来加强神经网络模型对文本嵌入的学习以更好的辅助分类任务。具体来说,提出的模型采用了一个神经网络,将文本编码为嵌入表示。此外,该模型还利用输入文本的k-近邻作为外部存储器,并利用它来捕捉训练集中的实例级信息。最终的分类预测是基于神经网络编码器和kNN memory的特征进行的。实验结果显示,提出的模型在所有数据集上都优于基线模型,甚至在几个数据集上击败了29层的神经网络模型;并且在训练实例稀少和训练集严重不平衡的情况下也显示出卓越的性能;该模型甚至可以很好的利用在半监督训练和转移学习等技术中。

模型浅析

提出的模型主要从训练集中提取全局和实例级信息来进行文本分类任务。为捕捉全局信息,训练了一个神经网络编码器,根据所有的训练实例及其类别信息将文本编码到一个嵌入空间。为了捕捉实例级的信息,对于每个输入的文本,从训练集中搜索其对应的k个近邻样本,然后将其作为外部存储器来增强神经网络。

上图中蓝色的data flow是传统的文本分类流程。余下的部分即为本文提出的kNN memory,即使用注意力机制来抽取实例级别的信息用于预测。可以形式化为如下:给定样本x和样本x的kNN\{x'_1,\cdots,x'_k,\cdots, x'_K\}以及对应的正确标签y\{y'_1,\cdots,y'_k,\cdots, y'_K\}。因此本文的任务是基于训练集估计一个条件概率Pr=(y|x,x'_1,\cdots, x'_K,y'_1,\cdots,y'_K),然后可以用于测试样本的标签: \mathcal{A}(y)是所有可能标签的集合。
序号 符号表示 描述
1 h 当前文本对应的text embedding
2 \hat{y} 当前文本对应的kNN所生成的基于关注度的标签分布
3 \hat{h} 当前文本对应的kNN所生成的基于关注度的text embedding

最后的步骤时将这些特征向量都拼接起来,用于学习分类预测。即将[h,\hat{y},\hat{h}]输入最终的分类器得出预测结果。在测试集上,则还是在训练集的样本中构造kNN memory,用于预测。


将训练集中的信息用的很全面。简单且有效。
上一篇 下一篇

猜你喜欢

热点阅读