CRNN+Attention
2018-07-18
翻开日记
"""Current"""
import tensorflow as tf
import numpy as np
import mobilenet_v1
from config import hparams_at as hparams
from EasyModel import CNN, lstmRNN
slim = tf.contrib.slim
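# Note: CNN and lstmRNN come from the author's EasyModel module, which is not shown
# in this post. From how they are used below, CNN is assumed to turn the
# [batch, 32, width, 3] images into a [batch, time, features] feature sequence, and
# lstmRNN(inputs, hidden, seq_lens, keep_prob) is assumed to return a
# [batch, time, hidden] tensor.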
class Model:
    def __init__(self, hparams, is_training=True):
        self.hparams = hparams
        self.batch_size = self.hparams.batch_size
        self.hidden = self.hparams.hidden
        self.num_class = self.hparams.num_class
        initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.1)
        regularizer = slim.l2_regularizer(scale=0.00004)

        with tf.name_scope('PlaceHolder'):
            # inputs: NHWC image batch, fixed height 32, variable width, 3 channels.
            self.inputs = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, 32, None, 3], name="sentences")
            # gtlabels: sparse ground-truth label indices, as required by tf.nn.ctc_loss.
            self.gtlabels = tf.sparse_placeholder(tf.int32, name='gtlabel')
            # seq_lens: per-example number of time steps in the feature sequence.
            self.seq_lens = tf.placeholder(dtype=tf.int32, shape=[self.batch_size], name='seq_lens')
        batch_norm_params = {
            'is_training': is_training,
            'center': True,
            'scale': True,
            'decay': 0.9,
            'epsilon': 0.001,
        }
        with slim.arg_scope([slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
                            weights_initializer=initializer,
                            weights_regularizer=regularizer,
                            activation_fn=tf.nn.relu6,
                            normalizer_fn=slim.batch_norm):
            with slim.arg_scope([tf.contrib.layers.batch_norm], **batch_norm_params):
                with slim.arg_scope([slim.dropout], is_training=is_training):
                    with tf.variable_scope(None, 'cnn_model'):
                        sequence = CNN(self.inputs)
                        print("sequence:", sequence)
with tf.name_scope("Attention"):
rnn_output1 = lstmRNN(sequence, self.hidden, self.seq_lens, .5)
attention = tf.layers.dense(sequence, self.hidden, use_bias=False)
rnn2_inputs = tf.nn.tanh(rnn_output1 + attention)
with tf.name_scope("RNN"):
rnn_output2 = lstmRNN(rnn2_inputs, self.hidden, self.seq_lens, .5)
rnn_output2 = tf.reshape(rnn_output2, [-1, self.hidden])
        with tf.name_scope('SoftmaxResult'):
            softmax_w = tf.Variable(tf.truncated_normal([self.hidden, self.num_class],
                                                        stddev=0.1, dtype=tf.float32), dtype=tf.float32)
            softmax_b = tf.Variable(tf.constant(0.0, shape=[self.num_class],
                                                dtype=tf.float32), dtype=tf.float32)
            softmax_input = tf.reshape(rnn_output2, [-1, self.hidden])
            rnn_logit = tf.add(tf.matmul(softmax_input, softmax_w), softmax_b)
            self.scores = tf.reshape(rnn_logit, [self.batch_size, -1, self.num_class], name='scores')

        print("Model Finished")
        print("Total Parameters: ", np.sum([np.prod(v.get_shape().as_list()) for v in tf.global_variables()]))
with tf.name_scope("others"):
self.global_step = tf.Variable(0, name='global_step', trainable=False)
with tf.name_scope("opt"):
logits = tf.transpose(self.scores, [1, 0, 2])
ctc_loss = tf.nn.ctc_loss(self.gtlabels, logits, self.seq_lens)
ctc_loss = tf.reduce_mean(ctc_loss)
self.loss = ctc_loss
self.learning_rate = tf.train.exponential_decay(learning_rate=self.hparams.init_lr,
global_step=self.global_step,
decay_steps=self.hparams.decay_steps,
decay_rate=self.hparams.decay_rate,
staircase=False)
self.learning_rate = tf.maximum(self.learning_rate, 1e-6, name='clip')
opt = tf.train.RMSPropOptimizer(self.learning_rate)
grads_and_vars = opt.compute_gradients(self.loss)
clipped_gvs = []
for grad, var in grads_and_vars:
if grad is not None:
clipped_gvs.append((tf.clip_by_value(grad, -self.hparams.grad_clip,
self.hparams.grad_clip), var))
else:
clipped_gvs.append((grad, var))
train_step = opt.apply_gradients(clipped_gvs, global_step=self.global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops.append(train_step)
self.update_op = tf.group(*update_ops)
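            # Running update_op runs the gradient step together with everything in the
            # UPDATE_OPS collection, which holds the batch-norm moving-average updates
            # created by slim.batch_norm.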
with tf.name_scope("EditDistance"):
scores = tf.transpose(self.scores, [1, 0, 2])
self.decoded, _ = tf.nn.ctc_beam_search_decoder(scores, self.seq_lens, merge_repeated=False)
self.edit_distance = tf.reduce_mean(tf.edit_distance(tf.cast(self.decoded[0], tf.int32), self.gtlabels))
self.dense_decoded = tf.sparse_tensor_to_dense(self.decoded[0], default_value=self.hparams.num_class)
text_tensor = tf.constant(self.hparams.wanted_words)
self.decoded_text = tf.gather(text_tensor, self.dense_decoded-1)
if __name__ == '__main__':
    model = Model(hparams)
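For reference, here is a minimal sketch of how a training step could feed this graph. The dummy image width (128), the toy label sequences, the seq_lens value, and the make_sparse helper are all illustrative assumptions, not part of the original project; in particular, seq_lens must match the number of time steps the (not shown) EasyModel CNN actually produces for the chosen image width.

import numpy as np
import tensorflow as tf

def make_sparse(label_seqs):
    # Pack a list of integer label sequences into the (indices, values, dense_shape)
    # triple that tf.sparse_placeholder accepts through feed_dict.
    indices, values = [], []
    for row, seq in enumerate(label_seqs):
        for col, label in enumerate(seq):
            indices.append([row, col])
            values.append(label)
    dense_shape = [len(label_seqs), max(len(seq) for seq in label_seqs)]
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32),
            np.array(dense_shape, dtype=np.int64))

model = Model(hparams)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Dummy batch: all-zero images of width 128 and a fixed toy label sequence.
    images = np.zeros([hparams.batch_size, 32, 128, 3], dtype=np.float32)
    labels = [[1, 2, 3]] * hparams.batch_size
    # Assumed number of CNN output time steps for width-128 inputs; adjust to your CNN.
    lengths = np.full([hparams.batch_size], 32, dtype=np.int32)
    feed = {model.inputs: images,
            model.gtlabels: make_sparse(labels),
            model.seq_lens: lengths}
    _, loss, dist = sess.run([model.update_op, model.loss, model.edit_distance],
                             feed_dict=feed)
    print("loss:", loss, "edit distance:", dist)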