GENERATING LONG AND DIVERSE RESP

2017-09-14  本文已影响21人  yingtaomj

和gnmt的区别在于:

train环节
            for segNum in  range(int(max_decoder_input_len / 10)):
                tmp_max_encoder_input_len = segNum * 10 + max_encoder_input_len#当前最大值
                #  choose a bucket(depends on the encoder length)
                bucket_id = choose_bucket(tmp_max_encoder_input_len)
                encoder_inputs, decoder_inputs, target_weights = model.pad_pair(encoder_inputs_ori,
                                                                                decoder_inputs_ori,
                                                                                segNum, FLAGS.segment_length,
                                                                                bucket_id, segNum == 0)
                _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, False)
                step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
                loss += step_loss / FLAGS.steps_per_checkpoint

即:每次根据最长长度取不同的bucket,然后调用pad_pair函数来生成encoder_inputs, decoder_inputs,step步骤和正常一样。

decode环节

main contribution在这里。
generate_response是直接负责的函数,它返回最终生成的回答句子。具体过程是:

从beam num个候选中选择分数最高的
上一篇 下一篇

猜你喜欢

热点阅读