[fairseq] generate.py

2019-04-28  本文已影响0人  VanJordan

[TOC]

generate.py

# Set up the task (e.g. translation) and load the evaluation split.
import ast

task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

# Load ensemble
print('| loading model(s) from {}'.format(args.path))
# NOTE: the original used eval() on a command-line string, which executes
# arbitrary code.  ast.literal_eval accepts only Python literals (the
# override value is a dict literal, default '{}'), so it is the safe,
# behavior-compatible replacement.
models, _model_args = utils.load_ensemble_for_inference(
    args.path.split(':'), task, model_arg_overrides=ast.literal_eval(args.model_overrides),
)

# Optimize ensemble for generation
for model in models:
    model.make_generation_fast_(
        beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        need_attn=args.print_alignment,  # attention is only needed to print alignments
    )
    if args.fp16:
        model.half()  # cast weights to fp16 before (optionally) moving to GPU
    if use_cuda:
        model.cuda()

replace_unk 的原理 (how unknown-word replacement via `--replace-unk` works)

# Build the word-alignment dictionary used for <unk> replacement:
# None (disabled), a mapping loaded from a file, or {} (copy source word).
align_dict = utils.load_align_dict(args.replace_unk)
def load_align_dict(replace_unk):
    """Build the dictionary used for unknown-word replacement.

    Returns None when replacement is disabled (replace_unk is None),
    a {source_word: target_word} mapping read from a two-column text
    file when a path string is given, and an empty dict otherwise
    (meaning: replace <unk> by copying the aligned source word).
    """
    if replace_unk is None:
        return None
    if not isinstance(replace_unk, str):
        # Replacement requested without a dictionary file: an empty
        # mapping tells the caller to copy the original source word.
        return {}
    # Each line of the file maps one source word to its translation.
    mapping = {}
    with open(replace_unk, 'r') as fin:
        for raw_line in fin:
            fields = raw_line.split()
            mapping[fields[0]] = fields[1]
    return mapping
# Raw, untokenized text taken directly from the dataset
# (only available when the dataset was loaded with --raw-text).
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
# Otherwise: detokenize the tensor of source token ids back into a string,
# optionally stripping BPE continuation markers.
# NOTE(review): this overwrites src_str from above — in the full
# generate.py the two forms live in an if/else branch; confirm upstream.
src_str = src_dict.string(src_tokens, args.remove_bpe)
if not args.quiet:
    if src_dict is not None:
        # S-<id>: the source sentence.
        print('S-{}\t{}'.format(sample_id, src_str))
    if has_target:
        # T-<id>: the reference target sentence.
        print('T-{}\t{}'.format(sample_id, target_str))
# Optionally force-decode the first --prefix-size reference tokens as the
# beginning of every generated hypothesis.
prefix_tokens = None
if args.prefix_size > 0:
    prefix_tokens = sample['target'][:, :args.prefix_size]

# Time the actual decoding step (beam search over the model ensemble).
gen_timer.start()
hypos = task.inference_step(generator, models, sample, prefix_tokens)
if not args.quiet:
    # NOTE(review): `hypo` and `hypo_str` are defined by the
    # per-hypothesis loop in the full generate.py; this excerpt shows
    # the print statements out of order.
    # H-<id>: hypothesis score and hypothesis string.
    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
    # P-<id>: per-token (positional) scores, 4 decimal places each.
    print('P-{}\t{}'.format(
        sample_id,
        ' '.join(map(
            lambda x: '{:.4f}'.format(x),
            hypo['positional_scores'].tolist(),
        ))
    ))
# Process top predictions: keep at most --nbest hypotheses per sentence.
# FIX: the excerpt duplicated this `for` header on two consecutive lines;
# a `for` with no indented body is a syntax error, so the stray first
# copy is removed.
# NOTE(review): `hypos[i]` presumably indexes the enclosing per-sample
# loop's counter in the full script — confirm against upstream, since
# reusing `i` as the inner loop variable shadows it.
for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]):
    # Map model output back to text: convert token ids to a string,
    # optionally replace <unk> via align_dict, and strip BPE markers.
    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
        hypo_tokens=hypo['tokens'].int().cpu(),
        src_str=src_str,
        alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
        align_dict=align_dict,
        tgt_dict=tgt_dict,
        remove_bpe=args.remove_bpe,
    )
# Final throughput summary: sentence count, token count, wall time, rates.
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
    num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
    # Overall scorer result (e.g. BLEU) against the reference translations.
    print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if args.print_alignment:
    # A-<id>: attention-based alignment indices.
    # NOTE(review): `sample_id` and `alignment` belong to the per-sample
    # loop in the full script; this excerpt places the print out of scope.
    print('A-{}\t{}'.format(
        sample_id,
        ' '.join(map(lambda x: str(utils.item(x)), alignment))
    ))

其他部分 (other parts of the script)

上一篇下一篇

猜你喜欢

热点阅读