Attention Mechanism Heatmap

2017-12-08  是neinei啊

1. An example 🌰

import matplotlib.pyplot as plt

def plot_attention(data, X_label=None, Y_label=None):
  '''
    Plot the attention model heatmap
    Args:
      data: attn_matrix with shape [ty, tx], cut before 'PAD'
      X_label: list of size tx, encoder tags
      Y_label: list of size ty, decoder tags
  '''
  fig, ax = plt.subplots(figsize=(20, 8))  # set figure size
  heatmap = ax.pcolor(data, cmap=plt.cm.Blues, alpha=0.9)

  # Set axis labels
  if X_label is not None and Y_label is not None:
    # labels should be unicode; decode only if they are still bytes
    X_label = [x.decode('utf-8') if isinstance(x, bytes) else x for x in X_label]
    Y_label = [y.decode('utf-8') if isinstance(y, bytes) else y for y in Y_label]

    # place major ticks at cell centers so each label sits on its cell
    xticks = [i + 0.5 for i in range(len(X_label))]
    ax.set_xticks(xticks, minor=False)
    ax.set_xticklabels(X_label, minor=False, rotation=45)

    yticks = [i + 0.5 for i in range(len(Y_label))]
    ax.set_yticks(yticks, minor=False)
    ax.set_yticklabels(Y_label, minor=False)

    ax.grid(True)
  plt.show()
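
The docstring assumes the attention matrix has already been cut before the 'PAD' token. A minimal sketch of such a trimming step (trim_pad is a hypothetical helper name; it assumes the matrix is a NumPy array and that both token lists are padded with the literal string 'PAD'):

def trim_pad(attn, src_tokens, tgt_tokens, pad='PAD'):
  # drop the padded tail of each sentence and the matching
  # rows/columns of the attention matrix (attn assumed to be a NumPy array)
  tx = src_tokens.index(pad) if pad in src_tokens else len(src_tokens)
  ty = tgt_tokens.index(pad) if pad in tgt_tokens else len(tgt_tokens)
  return attn[:ty, :tx], src_tokens[:tx], tgt_tokens[:ty]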

2. Parameters

X_label: the list of words making up the encoder (source) sentence;
Y_label: the list of words making up the decoder (target) sentence.
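
Putting it together, a minimal usage sketch (the sentences and attention weights below are made up purely for illustration):

import numpy as np

src = ['the', 'cat', 'sat', 'down']   # encoder words, tx = 4
tgt = ['le', 'chat', 'assis']         # decoder words, ty = 3

# fake attention matrix of shape [ty, tx]; each row sums to 1
attn = np.array([[0.70, 0.10, 0.10, 0.10],
                 [0.10, 0.80, 0.05, 0.05],
                 [0.05, 0.05, 0.60, 0.30]])

plot_attention(attn, X_label=src, Y_label=tgt)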

3. Results

[Figure: the resulting attention heatmap, 热度图.png]