RNN and LSTM TF Source Code

2019-06-20  小透明苞谷

RNN

class BasicRNNCell(RNNCell):
  """The most basic RNN cell.
  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
  """

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    if self._linear is None:
      self._linear = _Linear([inputs, state], self._num_units, True)

    output = self._activation(self._linear([inputs, state]))
    return output, output
  
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print(cell.state_size)
inputs = tf.placeholder(tf.float32, shape=[32, 100])
h0 = cell.zero_state(32, tf.float32)
output, h1 = cell(inputs=inputs, state=h0)
print(output.shape)  # (32, 128)
print(h1.shape)      # (32, 128)

# Here we first create a BasicRNNCell with 128 units, then construct a placeholder of shape [32, 100] as inputs, i.e. batch_size 32 and input dimension 100. We then obtain the initial hidden state by calling zero_state(), and finally invoke the cell's call() method to get output and h1.
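To make the math in call() concrete, here is a minimal NumPy sketch of one BasicRNNCell step. The separate weights W, U and bias b are illustrative assumptions; in the real cell, _Linear concatenates [inputs, state] and multiplies by a single kernel, which is equivalent.

import numpy as np

batch_size, input_size, num_units = 32, 100, 128

# Illustrative parameters; in TF they live inside the _Linear layer's kernel and bias.
W = np.random.randn(input_size, num_units) * 0.01   # input -> hidden
U = np.random.randn(num_units, num_units) * 0.01    # hidden -> hidden
b = np.zeros(num_units)

x = np.random.randn(batch_size, input_size)    # inputs, shape (32, 100)
h0 = np.zeros((batch_size, num_units))         # zero_state, shape (32, 128)

# output = new_state = tanh(W * input + U * state + B)
h1 = np.tanh(x @ W + h0 @ U + b)
print(h1.shape)  # (32, 128); output and new state are the same tensor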

LSTM

class BasicLSTMCell(RNNCell):
  def __init__(self, num_units, forget_bias=1.0,
                 state_is_tuple=True, activation=None, reuse=None):
      super(BasicLSTMCell, self).__init__(_reuse=reuse)
      if not state_is_tuple:
        logging.warn("%s: Using a concatenated state is slower and will soon be "
                     "deprecated.  Use state_is_tuple=True.", self)
      self._num_units = num_units
      self._forget_bias = forget_bias
      self._state_is_tuple = state_is_tuple
      self._activation = activation or math_ops.tanh
      self._linear = None
      
  @property
  def state_size(self):
      return (LSTMStateTuple(self._num_units, self._num_units)
          if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
      return self._num_units
    
  def call(self, inputs, state):
      """Long short-term memory cell (LSTM).

      Args:
        inputs: `2-D` tensor with shape `[batch_size x input_size]`.
        state: An `LSTMStateTuple` of state tensors, each shaped
          `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
          `True`.  Otherwise, a `Tensor` shaped
          `[batch_size x 2 * self.state_size]`.

      Returns:
        A pair containing the new hidden state, and the new state (either a
          `LSTMStateTuple` or a concatenated state, depending on
          `state_is_tuple`).
      """
      sigmoid = math_ops.sigmoid
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
          c, h = state
      else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

      if self._linear is None:
          self._linear = _Linear([inputs, h], 4 * self._num_units, True)
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(
          value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

      new_c = (
          c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
      else:
          new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
    
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(tf.float32, shape=(32, 100))
h0 = cell.zero_state(32, tf.float32)
output, h1 = cell(inputs=inputs, state=h0)
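Because state_is_tuple defaults to True, h1 here is an LSTMStateTuple whose h1.c and h1.h are each of shape (32, 128), while output equals h1.h. The following NumPy sketch mirrors the gate arithmetic in call(); the single kernel W and bias b are illustrative assumptions standing in for the _Linear layer applied to the concatenation of inputs and h.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

batch_size, input_size, num_units = 32, 100, 128
forget_bias = 1.0

# Illustrative parameters; the real cell keeps them in one _Linear kernel and bias.
W = np.random.randn(input_size + num_units, 4 * num_units) * 0.01
b = np.zeros(4 * num_units)

x = np.random.randn(batch_size, input_size)
c = np.zeros((batch_size, num_units))
h = np.zeros((batch_size, num_units))

# One multiply for all four gates, then split: i = input gate, j = new input,
# f = forget gate, o = output gate.
gates = np.concatenate([x, h], axis=1) @ W + b
i, j, f, o = np.split(gates, 4, axis=1)

new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
new_h = np.tanh(new_c) * sigmoid(o)
print(new_c.shape, new_h.shape)  # (32, 128) (32, 128)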

Source: https://cuiqingcai.com/4925.html
