[TF] 变量和层的重用

2019-01-13  本文已影响0人  VanJordan
class LSTM(object):
  """LSTM layer built on `tf.nn.dynamic_rnn` that shares its variables.

  The first call creates the layer's variables inside a variable scope named
  `name` (with `reuse=None`). Every later call re-enters that same scope with
  `reuse=True`, so all calls use the same weights.

  After the first call, the variables of that scope are exposed through the
  `trainable_weights` attribute; before the first call it is `None`.
  """

  def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
    """Store the layer configuration; no variables are created here.

    Args:
      cell_size: number of units passed to each `BasicLSTMCell`.
      num_layers: number of `BasicLSTMCell`s stacked in a `MultiRNNCell`.
      keep_prob: keep probability for dropout on the LSTM outputs; dropout
        is only applied when this is < 1.
      name: name of the variable scope that holds the layer's variables.
    """
    self.cell_size = cell_size
    self.num_layers = num_layers
    self.keep_prob = keep_prob
    # None -> the next call creates variables; True -> calls reuse them.
    self.reuse = None
    # Populated with the scope's variables during the first call.
    self.trainable_weights = None
    self.name = name
    print("You got one LSTM")

  def __call__(self, x, initial_state, seq_length):
    """Run the stacked LSTM over `x` and return its outputs and final state.

    Args:
      x: input batch fed to `tf.nn.dynamic_rnn`.
      initial_state: initial state for the stacked `MultiRNNCell`.
      seq_length: per-example sequence lengths passed to `dynamic_rnn`.

    Returns:
      Tuple `(lstm_out, next_state)` from `tf.nn.dynamic_rnn`, with dropout
      applied to `lstm_out` when `keep_prob < 1`.
    """
    with tf.variable_scope(self.name, reuse=self.reuse) as vs:
      # Cells inherit the scope's reuse flag so the second and later calls
      # bind to the variables created by the first call.
      cell = tf.contrib.rnn.MultiRNNCell([
          tf.contrib.rnn.BasicLSTMCell(
              self.cell_size,
              forget_bias=0.0,
              reuse=tf.get_variable_scope().reuse)
          # `range` works on both Python 2 and 3 (the Py2-only builtin
          # previously used here raises NameError on Python 3).
          for _ in range(self.num_layers)
      ])

      # shape(x) = (batch_size, num_timesteps, embedding_dim)

      lstm_out, next_state = tf.nn.dynamic_rnn(
          cell, x, initial_state=initial_state, sequence_length=seq_length)

      # shape(lstm_out) = (batch_size, timesteps, cell_size)

      if self.keep_prob < 1.:
        lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)

      # Only the first (variable-creating) call records the scope's variables.
      if self.reuse is None:
        self.trainable_weights = vs.global_variables()

    # All subsequent calls must reuse the variables created above.
    self.reuse = True

    return lstm_out, next_state
上一篇 下一篇

猜你喜欢

热点阅读