[tf]进行变量和层的重用
2019-01-13 本文已影响0人
VanJordan
class LSTM(object):
    """LSTM layer using dynamic_rnn.

    Exposes variables in `trainable_weights` property.

    The first call creates the variables inside a variable scope named
    `name`. Every later call re-enters that scope with ``reuse=True``, so
    calling the same instance several times shares one set of weights.
    """

    def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
        """Store the layer configuration; no TensorFlow ops are built here.

        Args:
          cell_size: size handed to every `BasicLSTMCell`.
          num_layers: how many `BasicLSTMCell`s are stacked in the
            `MultiRNNCell`.
          keep_prob: forwarded to `tf.nn.dropout` on the outputs; dropout
            is only applied when this is below 1.
          name: name of the variable scope the layer builds under.
        """
        self.cell_size = cell_size
        self.num_layers = num_layers
        self.keep_prob = keep_prob
        # None until the first call; flipped to True afterwards so that
        # subsequent calls reuse the variables instead of creating new ones.
        self.reuse = None
        # Filled in on the first call with the variables of the scope.
        self.trainable_weights = None
        self.name = name
        print("You got one LSTM")

    def __call__(self, x, initial_state, seq_length):
        """Run the stacked LSTM over `x` with `tf.nn.dynamic_rnn`.

        Args:
          x: input tensor handed to `dynamic_rnn`.
          initial_state: forwarded as `initial_state` to `dynamic_rnn`.
          seq_length: forwarded as `sequence_length` to `dynamic_rnn`.

        Returns:
          A `(lstm_out, next_state)` pair as produced by `dynamic_rnn`,
          with dropout applied to `lstm_out` when `keep_prob < 1`.
        """
        with tf.variable_scope(self.name, reuse=self.reuse) as vs:
            # `range` (not the Python 2-only `xrange`) keeps this runnable
            # on both Python 2 and Python 3.
            cell = tf.contrib.rnn.MultiRNNCell([
                tf.contrib.rnn.BasicLSTMCell(
                    self.cell_size,
                    forget_bias=0.0,
                    reuse=tf.get_variable_scope().reuse)
                for _ in range(self.num_layers)
            ])

            # shape(x) = (batch_size, num_timesteps, embedding_dim)
            lstm_out, next_state = tf.nn.dynamic_rnn(
                cell, x, initial_state=initial_state,
                sequence_length=seq_length)

            # shape(lstm_out) = (batch_size, timesteps, cell_size)
            if self.keep_prob < 1.:
                lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)

            # Only the first call creates variables, so capture them once.
            if self.reuse is None:
                self.trainable_weights = vs.global_variables()

        # From now on, re-enter the scope in reuse mode.
        self.reuse = True

        return lstm_out, next_state