global attetion学习

2020-09-30  本文已影响0人  锦绣拾年

global attention的理解

attention:

给定一个查询q ,通过计算与key的注意力分布,附加到value上,最后得到attention value,可以作为后续计算的参考。

对应seq2seq,机器翻译来说:

查询q是当下的隐向量,key和value相同,是原语句的多个隐向量,当下隐向量和原语句向量交互得到注意力权重,然后得到attention value,在和 下一个输入 concat一起,来进行预测。

参考:

https://guillaumegenthial.github.io/sequence-to-sequence.html

对于seq2seq解码过程

h_t = LSTM(h_{t-1},[w_{it-1},c_t])

s_t = g(t)

p_t = softmax(s_t)

i_t = argmax(p_t)

然后i_t得到预测单词的词向量,作为w_{it}c_{t+1}输入下一个神经元

c_t 就是得到的 attention vector,或者是context vector。

c_t 的计算过程如下:

\alpha{_\acute{t}}=f(h_{t-1},e_{\acute{t}})\in R,for\quad all\quad \acute{t} 得到注意力权重

\hat{\alpha}=softmax(\alpha)

c_t = \sum_{\acute{t}}^n\hat{\alpha_{\acute{t}}}e{_\acute{t}} 加权求和

其中:

e{_\acute{t}} 就是key,也是value,就是机器翻译中原语句词得到的隐向量。

f是得到注意力权重的方法,一般有以下几种:

f(h_{t-1},e{_\acute{t}})= \begin{cases} h_{t-1}^Te{_\acute{t}}& \text{dot}\\ h_{t-1}^TWe{_\acute{t}} & \text{general} \\ v^T tanh(W[h_{t-1},e{_\acute{t}}]) & \text{concat}\end{cases}

对于文本分类来说,

FEED-FORWARD NETWORKS WITH ATTENTION CAN SOLVE SOME LONG-TERM MEMORY PROBLEMS

论文中,注意力向量

e_t = a(h_t)

\alpha_t = \frac{exp(e_t)}{\sum_{k=1}^Texp(e_k)}

c = \sum_{t=1}^T\alpha_t h_t

即通过对隐向量的加权得到注意力向量。

隐向量的权值通过a函数学习得到。论文中:

a(h_t) = tanh(W_{hc}h_t+b_{hc})

Hierarchical Attention Networks for Document Classification

论文中,主要思想是设置了可训练的全局向量,作为q,key和value则是文章的单词和句子。
在求句子中每个单词的注意力分布时,这个全局向量可以视为:查询那个是有用的词?
在求一篇文章中每个单词的注意力分布时,这个全局向量可以视为:查询哪个是有用的句子。

具体可以看

https://www.jianshu.com/p/2a97d171e424

其余学习参考资料
https://guillaumegenthial.github.io/sequence-to-sequence.html

cs224n
https://github.com/philipperemy/keras-attention-mechanism/blob/master/attention/attention.py

上一篇 下一篇

猜你喜欢

热点阅读