deep_learning 02. tf.nn.rnn_cell

2019-03-04  本文已影响0人  adowu

开始的话:
从基础做起,不断学习,坚持不懈,加油。
一位爱生活爱技术来自火星的程序汪

上一节讲到了最基础的BasicRNNCell,本章就简单介绍下BasicLSTMCell。如果有不对的地方还请指正,谢谢!

话不多说,先上图:


basicLSTMCell.png

这张图大家肯定看到过很多次,是一个展开的LSTM Cell的内部结构。接下来还是和上一节一样,从tensorflow代码层面分析下。

代码和上一节的BasicRNNCell 都差不多,只是调用的rnn_cell变了。

def basic_lstm_demo():
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=4)
    zero_state = cell.zero_state(batch_size=2, dtype=tf.float32)
    a = tf.random_normal([2, 3, 4])
    out, state = tf.nn.dynamic_rnn(
        cell=cell,
        initial_state=zero_state,
        inputs=a
    )

tensorflow中的主要逻辑代码如下:

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

接下来就结合着这段源码简要分析下:

        结合着LSTM的图示来理解代码更清楚。

        #   输入的inputs [2,3,4],经过unstack则为 list([2,4]).size为3,所以输入到LSTM中的input为[2,4]
        #   初始化的 c 和 h 都是zero_state 也就是都为[2,4]的zero,这是参数state_is_tuple的情况下,
        #   如果这个参数为 False,则 c,h = [2,2]
        c, h = state
        #   初始化权重参数为:在此处就是 [4 + 4, 4 * 4] = [8, 16],为什么乘以4后面就可以知道原因
        kernel_ = [input_dims +num_units, 4 * num_units]

        #   concat[inputs, h] = [2, 8] kernel_ = [8, 16], bias=zero of [4 * num_units]
        #   所以gate_inputs = [2, 16]
        gate_inputs = bias_add(matmul(concat([inputs, h], axis=1), kernel_), bias)

        #   i 表示input_gate
        #   j 表示new_input
        #   f 表示forget_gate
        #   o 表示output_gate
        #   为了保持维度正确,所以前面要在num_units上乘以4的原因
        i, j, f, o = array_ops.split(value=gate_inputs, num_or_size_splits=4, axis=1)

        forget_bias = 1.0

        #   计算这个cell中的new_c 和 new_h
        #   forget_gate_output =  sigmoid(add(f, forget_bias_tensor))
        #   input_gate_output = multiply(sigmoid(i), tanh(j))
        #   update_c = add(multiply(c, forget_gate_output), input_gate_output)
        #   output_gate_output = multiply(tanh(new_c), sigmoid(o))

        new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),multiply(sigmoid(i), tanh(j)))
        new_h = multiply(tanh(new_c), sigmoid(o))

最后的output输出为:包含了shape为[2,3,4]的每个时间步的输出,以及最后一个cell的输出,这个又包含了c 和 h,shape分别为[2,4]

每一个时间步的输出:shape 为 [2,3,4]
    tf.Tensor(
        [[[ 0.29594404 -0.06257749  0.00272913  0.38393494]
          [ 0.12317018 -0.10669467  0.21305212 -0.0534559 ]
          [ 0.11735746 -0.03012969  0.08865868 -0.10764799]]
        
         [[-0.07051807  0.02736617  0.07237878 -0.19151129]
          [-0.07522646  0.00569247 -0.01109379 -0.00774325]
          [ 0.05763769 -0.00310471  0.21375947 -0.16625713]]], shape=(2, 3, 4), dtype=float32)
  
    最后一个时间步的输出,包括c 和 h shape 都为 [2,4]
    LSTMStateTuple(
            c=<tf.Tensor: id=309, shape=(2, 4), dtype=float32, numpy=
            array([[ 0.26399267, -0.09096628,  0.1642536 , -0.30149382],
                   [ 0.2447102 , -0.00411555,  0.38746575, -0.21990177]],
                  dtype=float32)>, 
            h=<tf.Tensor: id=312, shape=(2, 4), dtype=float32, numpy=
            array([[ 0.11735746, -0.03012969,  0.08865868, -0.10764799],
                    [ 0.05763769, -0.00310471,  0.21375947, -0.16625713]],
                    dtype=float32)>
            )

这里额外介绍两种LSTM的变体:

Peephole Connection:也就是让每一个门中都加入细胞状态c


peephole_connection_lstm.png

Coupled: forget_gate 和 input_gate 的sigmoid值是相关的,


coupled_lstm.png

BasicLSTMCell 是已经要deprecated的接口,更多的变体在接口tf.nn.rnn_cell.LSTMCell()中可以见到。

我们了解了最基础的BasicLSTMCell,其他变体无非就是在计算方式上做了一些改变,学习起来就很简单了。

更多代码请移步我的个人github,会不定期更新各种框架。
本章代码见code
欢迎关注

上一篇下一篇

猜你喜欢

热点阅读