9.2.1 PTB数据的预处理
2018-08-12 本文已影响0人
醉乡梦浮生
import codecs
import collections
from operator import itemgetter
import sys
def generate_dic(RAW_DATA, VOCAB_OUTPUT):
    """Write a frequency-ordered vocabulary file for a whitespace-tokenized corpus.

    The output holds one word per line. The first line is always the
    sentence-end marker "<eos>"; the remaining lines are the distinct words of
    the corpus ordered by descending count, with ties kept in first-seen order.

    Args:
        RAW_DATA: Path of the UTF-8 text file to read.
        VOCAB_OUTPUT: Path of the UTF-8 vocabulary file to write.
    """
    counter = collections.Counter()
    with codecs.open(RAW_DATA, "r", "utf-8") as f:
        for line in f:
            counter.update(line.strip().split())

    # "<eos>" is appended at every line break later on, so it is reserved as
    # the first vocabulary entry. Drop any literal "<eos>" counted from the
    # corpus so the marker is not listed twice.
    counter.pop("<eos>", None)

    # Order the words by frequency, most frequent first.
    sorted_word_to_cnt = sorted(
        counter.items(), key=itemgetter(1), reverse=True)
    sorted_words = ["<eos>"] + [word for word, _ in sorted_word_to_cnt]

    with codecs.open(VOCAB_OUTPUT, "w", "utf-8") as file_output:
        file_output.writelines(word + "\n" for word in sorted_words)
def generate_file(RAW_DATA, VOCAB, OUTPUT_DATA):
    """Rewrite a whitespace-tokenized text file as lines of vocabulary ids.

    Every input line becomes one output line: each word is replaced by its
    id, the id of "<eos>" is appended, and the ids are joined with single
    spaces.

    Args:
        RAW_DATA: Path of the UTF-8 text file to convert.
        VOCAB: Path of a UTF-8 vocabulary file with one word per line; a
            word's id is its zero-based line index.
        OUTPUT_DATA: Path of the UTF-8 file that receives the id sequences.

    Raises:
        KeyError: If a word (or "<eos>") is missing from the vocabulary and
            the vocabulary has no "<unk>" entry to fall back on.
    """
    # Load the vocabulary and map each word to its line index.
    with codecs.open(VOCAB, "r", "utf-8") as f_vocab:
        vocab = [w.strip() for w in f_vocab.readlines()]
    word_to_id = {word: idx for idx, word in enumerate(vocab)}

    # Words that are not in the vocabulary are replaced by "<unk>".
    def get_id(word):
        if word in word_to_id:
            return word_to_id[word]
        return word_to_id["<unk>"]

    # Context managers make sure both files are closed even when a lookup
    # or a write fails part-way through.
    with codecs.open(RAW_DATA, "r", "utf-8") as fin, \
            codecs.open(OUTPUT_DATA, "w", "utf-8") as fout:
        for line in fin:
            # Split the line into words and add the sentence-end marker.
            words = line.strip().split() + ["<eos>"]
            # Replace every word by its vocabulary id.
            fout.write(" ".join(str(get_id(w)) for w in words) + "\n")
if __name__ == '__main__':
    # Input corpora, the vocabulary file built for each one, and the file
    # that receives its id-encoded text; the three lists are parallel.
    RAW_DATA = [
        "PTB_data/ptb.train.txt",
        "PTB_data/ptb.test.txt",
        "PTB_data/ptb.valid.txt",
    ]
    VOCAB_OUTPUT = ["ptb_train.vocab", "ptb_test.vocab", "ptb_valid.vocab"]
    OUTPUT_DATA = ["ptb.train", "ptb.test", "ptb.valid"]

    # NOTE(review): every split builds and uses its own vocabulary, so the
    # same word can receive different ids in the three output files --
    # confirm this is intended before using the files together.
    for raw_path, vocab_path, ids_path in zip(
            RAW_DATA, VOCAB_OUTPUT, OUTPUT_DATA):
        generate_dic(raw_path, vocab_path)
        generate_file(raw_path, vocab_path, ids_path)