TensorFlow操作

CRNN+Attention

2018-07-18  本文已影响0人  翻开日记
"""Current"""
import tensorflow as tf
import numpy as np
import mobilenet_v1
from config import hparams_at as hparams
from EasyModel import CNN, lstmRNN

# Alias for TF-Slim from tf.contrib (TensorFlow 1.x only); used below for
# arg_scope, l2_regularizer and batch_norm.
slim = tf.contrib.slim

class Model:
    """CRNN-style sequence recognizer trained with CTC (TensorFlow 1.x graph mode).

    The whole graph is built in the constructor: input placeholders, a CNN
    feature extractor, two LSTM layers, a per-step softmax projection, a CTC
    training op and a CTC beam-search decoder.
    """

    def __init__(self, hparams, is_training=True):
        """Build the model graph.

        Args:
            hparams: hyper-parameter object. This constructor reads
                ``batch_size``, ``hidden``, ``num_class``, ``init_lr``,
                ``decay_steps``, ``decay_rate``, ``grad_clip`` and
                ``wanted_words`` from it.
            is_training: forwarded to batch normalization and ``slim.dropout``
                inside the CNN feature extractor.

        Graph handles set on the instance:
            inputs, gtlabels, seq_lens: placeholders to feed.
            scores: per-step logits reshaped to [batch_size, -1, num_class].
            global_step, loss, learning_rate, update_op: training graph.
            decoded, edit_distance, dense_decoded, decoded_text: decoding graph.
        """
        self.hparams = hparams
        self.batch_size = self.hparams.batch_size
        self.hidden = self.hparams.hidden
        self.num_class = self.hparams.num_class
        initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.1)
        regularizer = slim.l2_regularizer(scale=0.00004)
        with tf.name_scope('PlaceHolder'):
            # Fixed second dimension (32) and variable third dimension; presumably
            # NHWC images of height 32 and arbitrary width -- TODO confirm.
            self.inputs = tf.placeholder(dtype=tf.float32, shape=[self.batch_size, 32, None, 3], name="sentences")
            self.gtlabels = tf.sparse_placeholder(tf.int32, name='gtlabel')
            # Per-example sequence lengths, used by the LSTMs, ctc_loss and the decoder.
            self.seq_lens = tf.placeholder(dtype=tf.int32, shape=[self.batch_size], name='seq_lens')

        batch_norm_params = {
            'is_training': is_training,
            'center': True,
            'scale': True,
            'decay': 0.9,
            'epsilon': 0.001,
        }
        with slim.arg_scope([slim.conv2d, slim.separable_conv2d, slim.conv2d_transpose],
                            weights_initializer=initializer,
                            weights_regularizer=regularizer,
                            activation_fn=tf.nn.relu6,
                            normalizer_fn=slim.batch_norm):
            # Batch-norm / dropout defaults apply only to the CNN feature extractor.
            with slim.arg_scope([tf.contrib.layers.batch_norm], **batch_norm_params):
                with slim.arg_scope([slim.dropout], is_training=is_training):
                    with tf.variable_scope(None, 'cnn_model'):
                        # NOTE(review): assumed to return a [batch, time, features]
                        # sequence -- TODO confirm against EasyModel.CNN.
                        sequence = CNN(self.inputs)
                        print("sequence:", sequence)
            with tf.name_scope("Attention"):
                # NOTE(review): the constant .5 is passed regardless of is_training;
                # if it is a dropout keep-probability, dropout is also active at
                # inference -- confirm against EasyModel.lstmRNN.
                rnn_output1 = lstmRNN(sequence, self.hidden, self.seq_lens, .5)
                # Projection of the CNN features added to the first LSTM's output.
                attention = tf.layers.dense(sequence, self.hidden, use_bias=False)
                rnn2_inputs = tf.nn.tanh(rnn_output1 + attention)

            with tf.name_scope("RNN"):
                rnn_output2 = lstmRNN(rnn2_inputs, self.hidden, self.seq_lens, .5)
                # Flatten all time steps so one matmul projects every step.
                rnn_output2 = tf.reshape(rnn_output2, [-1, self.hidden])

            # The misspelled scope name is kept on purpose: it prefixes the
            # tf.Variable names below, so renaming it would break checkpoints.
            with tf.name_scope('SotfmaxResult'):
                softmax_w = tf.Variable(tf.truncated_normal([self.hidden, self.num_class],
                                                            stddev=0.1, dtype=tf.float32), dtype=tf.float32)
                softmax_b = tf.Variable(tf.constant(0.0, shape=[self.num_class],
                                                    dtype=tf.float32), dtype=tf.float32)
                # rnn_output2 is already [-1, hidden]; no second reshape needed.
                rnn_logit = tf.add(tf.matmul(rnn_output2, softmax_w), softmax_b)
                self.scores = tf.reshape(rnn_logit, [self.batch_size, -1, self.num_class], name='scores')

            print("Model Finished")
            # Builtin sum: np.sum(<generator>) is deprecated in NumPy. int() keeps the
            # total integral even for scalar variables (np.prod([]) == 1.0).
            # Counts every global variable created so far, not only trainable weights.
            total_params = sum(int(np.prod(v.get_shape().as_list())) for v in tf.global_variables())
            print("Total Parameters: ", total_params)

            with tf.name_scope("others"):
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
            with tf.name_scope("opt"):
                # tf.nn.ctc_loss defaults to time-major inputs: [time, batch, num_class].
                logits = tf.transpose(self.scores, [1, 0, 2])
                ctc_loss = tf.nn.ctc_loss(self.gtlabels, logits, self.seq_lens)
                ctc_loss = tf.reduce_mean(ctc_loss)
                self.loss = ctc_loss
                self.learning_rate = tf.train.exponential_decay(learning_rate=self.hparams.init_lr,
                                                                global_step=self.global_step,
                                                                decay_steps=self.hparams.decay_steps,
                                                                decay_rate=self.hparams.decay_rate,
                                                                staircase=False)
                # Floor the decayed learning rate at 1e-6.
                self.learning_rate = tf.maximum(self.learning_rate, 1e-6, name='clip')
                opt = tf.train.RMSPropOptimizer(self.learning_rate)
                grads_and_vars = opt.compute_gradients(self.loss)
                clip = self.hparams.grad_clip
                # Element-wise clipping; variables without a gradient (None) pass through unchanged.
                clipped_gvs = [
                    (grad if grad is None else tf.clip_by_value(grad, -clip, clip), var)
                    for grad, var in grads_and_vars
                ]
                train_step = opt.apply_gradients(clipped_gvs, global_step=self.global_step)
                # Run batch-norm statistic updates together with the optimizer step.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                update_ops.append(train_step)
                self.update_op = tf.group(*update_ops)
        with tf.name_scope("EditDistance"):
            scores = tf.transpose(self.scores, [1, 0, 2])
            self.decoded, _ = tf.nn.ctc_beam_search_decoder(scores, self.seq_lens, merge_repeated=False)
            self.edit_distance = tf.reduce_mean(tf.edit_distance(tf.cast(self.decoded[0], tf.int32), self.gtlabels))
            # Padding positions are filled with num_class.
            self.dense_decoded = tf.sparse_tensor_to_dense(self.decoded[0], default_value=self.hparams.num_class)
            text_tensor = tf.constant(self.hparams.wanted_words)
            # NOTE(review): assumes label i maps to wanted_words[i - 1]; a decoded
            # label 0 would give index -1 -- confirm the label encoding.
            self.decoded_text = tf.gather(text_tensor, self.dense_decoded-1)


if __name__ == "__main__":
    # Smoke run: build the graph once with the default hyper-parameters from config.
    model = Model(hparams)
上一篇 下一篇

猜你喜欢

热点阅读