tf.keras.layers.Attention
2021-01-10
又双叒叕苟了一天
Dot-product attention layer, also known as Luong-style attention.
tf.keras.layers.Attention(
    use_scale=False, **kwargs
)
The query has shape [batch_size, Tq, dim], the value has shape [batch_size, Tv, dim], and the key has shape [batch_size, Tv, dim]. The computation proceeds as follows:
- Compute the dot-product attention scores, of shape [batch_size, Tq, Tv]: scores = tf.matmul(query, key, transpose_b=True)
- Compute the softmax: distribution = tf.nn.softmax(scores)
- Take the weighted sum of value: tf.matmul(distribution, value), which gives an output of shape [batch_size, Tq, dim].
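The three steps above can be reproduced outside TensorFlow to check the shapes. The following is a minimal NumPy sketch, not the layer's actual implementation; the function name `dot_product_attention` and the random inputs are made up for illustration.

```python
import numpy as np

def dot_product_attention(query, value, key=None):
    """NumPy sketch of the three steps; key defaults to value."""
    if key is None:
        key = value
    # Step 1: dot-product scores of shape [batch_size, Tq, Tv].
    scores = np.matmul(query, np.swapaxes(key, -1, -2))
    # Step 2: softmax over the Tv axis (shifted by the row max for stability).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    distribution = weights / weights.sum(axis=-1, keepdims=True)
    # Step 3: weighted sum of value, shape [batch_size, Tq, dim].
    return np.matmul(distribution, value), distribution

rng = np.random.default_rng(0)
query = rng.normal(size=(2, 3, 4))  # batch_size=2, Tq=3, dim=4
value = rng.normal(size=(2, 5, 4))  # batch_size=2, Tv=5, dim=4
output, distribution = dot_product_attention(query, value)
print(output.shape)        # (2, 3, 4)
print(distribution.shape)  # (2, 3, 5)
```

Each row of `distribution` sums to 1 over the Tv axis, so every output position is a convex combination of the value vectors.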
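| Argument | Description |
|---|---|
| use_scale | If True, creates a scalar variable that scales the attention scores. |
| causal | Boolean. Set to True for decoder self-attention. Adds a mask so that position i cannot see information from future positions. |
| dropout | Float between 0 and 1. Dropout rate applied to the attention scores. |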
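To make the effect of causal and use_scale concrete, here is a small NumPy sketch. It assumes that the scale multiplies the raw scores and that the causal mask is lower-triangular (position i attends only to positions j <= i); the function name `causal_attention_weights` and the `-1e9` fill value are illustrative choices, not taken from the layer's source.

```python
import numpy as np

def causal_attention_weights(scores, scale=None):
    """scores: [batch_size, T, T]. Returns softmax weights with the future masked."""
    if scale is not None:
        # Stands in for the learned scalar created when use_scale=True (assumption).
        scores = scores * scale
    t_q, t_v = scores.shape[-2:]
    # Lower-triangular mask: entry (i, j) is True when j <= i.
    causal_mask = np.tril(np.ones((t_q, t_v), dtype=bool))
    # Masked entries get a very negative score, so their softmax weight is ~0.
    masked = np.where(causal_mask, scores, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

scores = np.zeros((1, 3, 3))  # equal scores, so only the mask shapes the result
weights = causal_attention_weights(scores)
print(np.round(weights[0], 3))
```

With equal scores, row i spreads its weight evenly over positions 0..i and puts zero weight on every later position.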
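Call arguments:
inputs: a list of the following tensors:
- query: [batch_size, Tq, dim]
- value: [batch_size, Tv, dim]
- key: [batch_size, Tv, dim]. If not given, defaults to key=value.
mask: a list of the following tensors:
- query_mask: [batch_size, Tq]. If given, the output is 0 at positions where mask==False.
- value_mask: [batch_size, Tv]. If given, positions where mask==False do not contribute to the output.
training: whether to enable dropout.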
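The two masks act at different stages: value_mask removes value positions before the softmax, while query_mask zeroes out rows of the result. The NumPy sketch below illustrates that reading of the descriptions above; it is an approximation for intuition, not the layer's code, and `masked_attention` and the `-1e9` fill value are hypothetical.

```python
import numpy as np

def masked_attention(query, value, query_mask=None, value_mask=None):
    # key defaults to value, as in the call arguments above.
    scores = np.matmul(query, np.swapaxes(value, -1, -2))
    if value_mask is not None:
        # [batch_size, Tv] -> [batch_size, 1, Tv]; masked values get ~zero weight.
        scores = np.where(value_mask[:, np.newaxis, :], scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    distribution = weights / weights.sum(axis=-1, keepdims=True)
    output = np.matmul(distribution, value)
    if query_mask is not None:
        # [batch_size, Tq] -> [batch_size, Tq, 1]; masked query rows become 0.
        output = output * query_mask[:, :, np.newaxis]
    return output, distribution

rng = np.random.default_rng(0)
query = rng.normal(size=(1, 2, 4))            # Tq=2
value = rng.normal(size=(1, 3, 4))            # Tv=3
query_mask = np.array([[True, False]])        # second query position is padding
value_mask = np.array([[True, True, False]])  # third value position is padding
output, distribution = masked_attention(query, value, query_mask, value_mask)
print(output[0, 1])           # all zeros: masked query position
print(distribution[0, :, 2])  # all zeros: masked value position
```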
Example:
import tensorflow as tf

# Assumed placeholder values; the original leaves max_tokens and dimension undefined.
max_tokens = 1000  # vocabulary size
dimension = 64     # embedding size

# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')

# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(max_tokens, dimension)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)

# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
    filters=100,
    kernel_size=4,
    # Use 'same' padding so outputs have the same shape as inputs.
    padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)

# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.Attention()(
    [query_seq_encoding, value_seq_encoding])

# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
    query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
    query_value_attention_seq)

# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
    [query_encoding, query_value_attention])

# Add DNN layers, and create Model.
# ...