The Seq2Seq Model

2019-03-20  _Megamind_
[Figure: Seq2Seq encoder-decoder architecture]
  • [GO] starts the decoding of the sequence; decoding ends when [EOS] is emitted (a greedy-decoding sketch of this protocol follows the code below)
  • h0 is the initial hidden state, and s0 is the initial decoder output
  ##########
  # encoder
  import tensorflow as tf

  # padded batches of token ids: [batch_size, max_time]
  inputs = tf.placeholder(tf.int32, [None, None])
  # with variable-length input sequences, each sequence's true length must be given
  inputs_length = tf.placeholder(tf.int32, [None])
  encoder_embedding = tf.Variable(tf.random_uniform([input_vocab_size, embedding_dim]))
  encoder_inputs_embedded = tf.nn.embedding_lookup(encoder_embedding, inputs)
  # use an LSTM cell for the encoder
  encoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_dim)
  encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
      cell=encoder_cell, inputs=encoder_inputs_embedded,
      sequence_length=inputs_length, dtype=tf.float32)

  #########
  # decoder
  labels = tf.placeholder(tf.int32, [None, None])
  labels_length = tf.placeholder(tf.int32, [None])
  decoder_cell = tf.nn.rnn_cell.LSTMCell(hidden_dim)

  # build the teacher-forcing decoder inputs: prepend [GO] and drop the last target token
  GO = tf.zeros([batch_size], dtype=tf.int32)  # assumes token id 0 is [GO]
  decoder_embedding = tf.Variable(tf.random_uniform([output_vocab_size, embedding_dim]))
  decoder_inputs_embedded = tf.nn.embedding_lookup(
      decoder_embedding,
      tf.concat([tf.reshape(GO, [-1, 1]), labels[:, :-1]], 1))

  helper = tf.contrib.seq2seq.TrainingHelper(decoder_inputs_embedded, labels_length)
  # the decoder is initialized with the encoder's final state
  decoder_initial_state = encoder_final_state
  decoder = tf.contrib.seq2seq.BasicDecoder(
      decoder_cell, helper, decoder_initial_state,
      output_layer=tf.layers.Dense(output_vocab_size))
  decoder_outputs, decoder_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
      decoder, maximum_iterations=tf.reduce_max(labels_length))
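
The graph above stops at the decoder outputs. A minimal training objective can be attached with tf.contrib.seq2seq.sequence_loss, masking the padded positions. This is a sketch, assuming labels is padded to tf.reduce_max(labels_length) and that learning_rate is defined elsewhere:

  # decoder_outputs.rnn_output holds the logits: [batch_size, max_time, output_vocab_size]
  logits = decoder_outputs.rnn_output
  # mask padding positions so they do not contribute to the loss
  mask = tf.sequence_mask(labels_length, tf.reduce_max(labels_length), dtype=tf.float32)
  loss = tf.contrib.seq2seq.sequence_loss(logits, labels, mask)
  train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)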
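
At inference time there are no labels to feed, so the [GO]/[EOS] protocol from the bullets above is driven by tf.contrib.seq2seq.GreedyEmbeddingHelper, which feeds each step's argmax prediction back in as the next input, starting from [GO] and stopping at [EOS]. A sketch, assuming GO_TOKEN and EOS_TOKEN are the (hypothetical) vocabulary ids of [GO] and [EOS] and max_decode_len caps the output length; in a real graph the output projection layer would be shared with the training decoder:

  infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      decoder_embedding,
      start_tokens=tf.fill([batch_size], GO_TOKEN),  # GO_TOKEN: hypothetical id of [GO]
      end_token=EOS_TOKEN)                           # EOS_TOKEN: hypothetical id of [EOS]
  infer_decoder = tf.contrib.seq2seq.BasicDecoder(
      decoder_cell, infer_helper, encoder_final_state,
      output_layer=tf.layers.Dense(output_vocab_size))  # share with the training layer in practice
  infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      infer_decoder, maximum_iterations=max_decode_len)
  predictions = infer_outputs.sample_id  # [batch_size, decoded_time] token ids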