TensorFlow源码学习 | tf.contrib.rnn.
2018-07-06 本文已影响21人
简书已注销
tf.contrib.rnn.BasicLSTMCell(经典结构)
__init__(
num_units,
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None
)
【参数信息】:
- num_units: int类型,LSTM cell中的单元数;
- forget_bias: float类型,遗忘门中的添加的偏置值。在从CudnnLSTM-trained检查点恢复必须手动置0.0;
- state_is_tuple: 如果是True,返回的是一个二元组,包含两个状态c_state和m_state。如果是False,沿着列方向将c_state和m_state拼接成一个向量。但是不赞成将state_is_tuple置为False;
- activation: 内部状态的激活函数。 默认值:tanh;
- reuse: (可选参数)Python的布尔值描述是否在一个已知域中去重新使用变量。如果是False,那么已知的域就会重新生成变量,会出bug;
- 如果从CudnnLSTM-trained的检查点中恢复模型,必须使用CudnnCompatibleLSTMCell。
【BasicLSTMCell(经典结构)】
【源码】:
将上图的结构和源码对照着看
if self._linear is None:
self._linear = _Linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)
new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o) # 新的输出
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h) # 返回两个变量
else:
new_state = array_ops.concat([new_c, new_h], 1) #按列拼接
return new_h, new_state
tf.contrib.rnn.LSTMCell
__init__(
num_units,
use_peepholes=False,
cell_clip=None,
initializer=None,
num_proj=None,
proj_clip=None,
num_unit_shards=None,
num_proj_shards=None,
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None
)
参数信息:
- num_units:int类型,LSTM cell中的单元数;
- use_peepholes: bool类型, 设置为True 的时候增加了窥视孔(diagonal/peephole connections);
- cell_clip: (可选) 一个float浮点型值, if provided the cell state is clipped by this value prior to the cell output activation.
- initializer: (可选) 为了权重矩阵和投影(projection)矩阵的初始化;
- num_proj: (可选) int类型,投影(projection)矩阵的输出维度。如果设置成None, 没有执行的投影矩阵;
- proj_clip: (可选) 一个浮点型值, 如果num_proj > 0 并且给出proj_clip的值, 那么投影矩阵元素的值是在[-proj_clip, proj_clip]这个区间中;
- num_unit_shards:很遗憾,这个变量已经在 2017年1月就撤销了!
- num_proj_shards: 很遗憾,这个变量已经在 2017年1月就撤销了!
- forget_bias: float类型,遗忘门中的添加的偏置值。在从CudnnLSTM-trained检查点恢复必须手动置0.0;
- state_is_tuple: 如果是True,返回的是一个二元组,包含两个状态c_state和m_state。如果是False,沿着列方向将c_state和m_state拼接成一个向量。但是不赞成将state_is_tuple置为False;
- activation: 内部状态的激活函数。 默认值:tanh;
- reuse: (可选参数)Python的布尔值描述是否在一个已知域中去重新使用变量。如果是False,那么已知的域就会重新生成变量,会出bug;
- 如果从CudnnLSTM-trained的检查点中恢复模型,必须使用CudnnCompatibleLSTMCell。
【带有窥视孔的LSTM结构】
【源码】:
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = self._linear1([inputs, m_prev])
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
# Diagonal connections
if self._use_peepholes and not self._w_f_diag:
scope = vs.get_variable_scope()
with vs.variable_scope(
scope, initializer=self._initializer) as unit_scope:
with vs.variable_scope(unit_scope):
self._w_f_diag = vs.get_variable(
"w_f_diag", shape=[self._num_units], dtype=dtype)
self._w_i_diag = vs.get_variable(
"w_i_diag", shape=[self._num_units], dtype=dtype)
self._w_o_diag = vs.get_variable(
"w_o_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes: # 使用窥视孔
c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
self._activation(j))
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
else:
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
if self._linear2 is None:
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer):
with vs.variable_scope("projection") as proj_scope:
if self._num_proj_shards is not None:
proj_scope.set_partitioner(
partitioned_variables.fixed_size_partitioner(
self._num_proj_shards))
self._linear2 = _Linear(m, self._num_proj, False)
m = self._linear2(m) # 在输出之前增加了一层Projection layer,增加一层线性变换
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
array_ops.concat([c, m], 1))
return m, new_state