9.2.1 PTB数据的预处理

2018-08-12  本文已影响0人  醉乡梦浮生
import codecs
import collections
from operator import  itemgetter
import sys


def generate_dic(RAW_DATA, VOCAB_OUTPUT):
    """Write a frequency-ordered vocabulary file for a tokenised corpus.

    Args:
        RAW_DATA: Path of the UTF-8 text file to read; tokens are separated
            by whitespace.
        VOCAB_OUTPUT: Path of the UTF-8 vocabulary file to write, one word
            per line.

    The first line written is always "<eos>". The remaining lines are the
    distinct tokens of RAW_DATA, most frequent first; tokens with equal
    counts keep the order in which they were first seen.
    """
    frequencies = collections.Counter()
    with codecs.open(RAW_DATA, "r", "utf-8") as corpus:
        for text_line in corpus:
            frequencies.update(text_line.strip().split())

    # "<eos>" marks the end of a sentence and is inserted at every line
    # break later on, so it is reserved at the head of the vocabulary here.
    vocabulary = ["<eos>"]
    # most_common() yields (word, count) pairs in descending count order.
    vocabulary.extend(token for token, _count in frequencies.most_common())

    with codecs.open(VOCAB_OUTPUT, 'w', 'utf-8') as file_output:
        file_output.writelines(token + "\n" for token in vocabulary)


def generate_file(RAW_DATA, VOCAB, OUTPUT_DATA):
    """Rewrite a tokenised corpus as lines of space-separated word ids.

    Args:
        RAW_DATA: Path of the UTF-8 text file to convert; tokens are
            separated by whitespace.
        VOCAB: Path of the UTF-8 vocabulary file, one word per line. A
            word's id is its zero-based line index.
        OUTPUT_DATA: Path of the UTF-8 file to write. Each input line
            becomes one output line of ids, terminated by the id of "<eos>".

    Raises:
        KeyError: If a word (including "<eos>") is absent from the
            vocabulary and the vocabulary has no "<unk>" entry to fall
            back on.
    """
    # Load the vocabulary and build the word -> id mapping.
    with codecs.open(VOCAB, "r", "utf-8") as f_vocab:
        vocab = [w.strip() for w in f_vocab.readlines()]
    word_to_id = {word: index for index, word in enumerate(vocab)}
    # None when the vocabulary carries no "<unk>" entry; only an error if an
    # out-of-vocabulary word is actually encountered.
    unk_id = word_to_id.get("<unk>")

    # Words missing from the vocabulary are mapped to the id of "<unk>".
    def get_id(word):
        word_id = word_to_id.get(word, unk_id)
        if word_id is None:
            raise KeyError(
                "word %r is not in the vocabulary and the vocabulary has "
                "no '<unk>' entry" % (word,))
        return word_id

    # Context managers close (and flush) both files even if a line fails.
    with codecs.open(RAW_DATA, "r", "utf-8") as fin, \
            codecs.open(OUTPUT_DATA, 'w', 'utf-8') as fout:
        for line in fin:
            # Tokenise the line and append the sentence-end marker.
            words = line.strip().split() + ["<eos>"]
            # Replace every word with its vocabulary id.
            fout.write(' '.join(str(get_id(w)) for w in words) + '\n')


if __name__ == '__main__':
    # Parallel lists: position k in each list describes one PTB split
    # (raw text, vocabulary file to create, id file to create).
    RAW_DATA = [
        "PTB_data/ptb.train.txt",
        "PTB_data/ptb.test.txt",
        "PTB_data/ptb.valid.txt",
    ]
    VOCAB_OUTPUT = [
        "ptb_train.vocab",
        "ptb_test.vocab",
        "ptb_valid.vocab",
    ]
    OUTPUT_DATA = [
        "ptb.train",
        "ptb.test",
        "ptb.valid",
    ]

    # Each split gets its own vocabulary, which is then used to encode it.
    for raw_path, vocab_path, ids_path in zip(
            RAW_DATA, VOCAB_OUTPUT, OUTPUT_DATA):
        generate_dic(raw_path, vocab_path)
        generate_file(raw_path, vocab_path, ids_path)
上一篇下一篇

猜你喜欢

热点阅读