[PyTorch] Preprocessed data can be stored with torch.save

2019-06-21  VanJordan
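The idea: torch.save serializes arbitrary Python objects (including lists and dicts of token ids), so once a dataset has been tokenized and encoded you can write the result to disk and reload it with torch.load instead of redoing the preprocessing on every run. Below is a minimal sketch of that caching pattern; the helper name load_or_build and the cache file name are only illustrative, not from the original script.

import os
import torch

def load_or_build(cache_path, build_fn):
    """Reload a cached object if it exists, otherwise build it and save it."""
    if cache_path and os.path.isfile(cache_path):
        return torch.load(cache_path)
    data = build_fn()  # expensive preprocessing (tokenize, encode, ...)
    if cache_path:
        torch.save(data, cache_path)
    return data

# Example: the second call returns immediately from the cache file.
encoded = load_or_build('dataset.cache', lambda: {'train': [[101, 1996, 3899, 102]]})

The function below applies the same pattern to a language-modeling dataset: it downloads the text files, tokenizes and encodes them, and caches the encoded result with torch.save.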
import os
import logging

import torch
from tqdm import tqdm
from pytorch_pretrained_bert.file_utils import cached_path

# DATASETS_URL, DATASETS_LABELS_URL and DATASETS_LABELS_CONVERSION are module-level
# dicts (dataset name -> file URLs / label-to-int mapping) defined elsewhere in the
# original script and assumed to be in scope here.
logger = logging.getLogger(__name__)


def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cache=None, with_labels=False):
    """ Retrieve, tokenize, encode and cache a dataset with optional labels """
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load encoded dataset from cache at %s", dataset_cache)
        encoded_dataset = torch.load(dataset_cache)
    else:
        # If the dataset is in our list of DATASETS_URL, use this url, otherwise, look for 'train.txt' and 'valid.txt' files
        if dataset_dir in DATASETS_URL:
            dataset_map = DATASETS_URL[dataset_dir]
        else:
            dataset_map = {'train': os.path.join(dataset_dir, 'train.txt'),
                           'valid': os.path.join(dataset_dir, 'valid.txt')}

        logger.info("Get dataset from %s", dataset_dir)
        # Download and read the dataset, replacing a few tokens for compatibility with the Bert tokenizer we are using
        dataset = {}
        for split_name in dataset_map.keys():
            dataset_file = cached_path(dataset_map[split_name])
            with open(dataset_file, "r", encoding="utf-8") as f:
                all_lines = f.readlines()
                dataset[split_name] = [
                        line.strip(' ').replace('<unk>', '[UNK]').replace('\n', '[SEP]' if not with_labels else '')
                        for line in tqdm(all_lines)]

        # If we have labels, download them and convert the labels to integers
        labels = {}
        if with_labels:
            label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
            for split_name in DATASETS_LABELS_URL[dataset_dir]:
                # Label files are listed in DATASETS_LABELS_URL, not in dataset_map
                dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
                with open(dataset_file, "r", encoding="utf-8") as f:
                    all_lines = f.readlines()
                    labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]

        # Tokenize and encode the dataset
        logger.info("Tokenize and encode the dataset")
        logging.getLogger("pytorch_pretrained_bert.tokenization").setLevel(logging.ERROR)  # No warning on sample size
        def encode(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, encode(o)) for n, o in obj.items())
            return list(encode(o) for o in tqdm(obj))
        encoded_dataset = encode(dataset)

        # Add the labels if needed; for language modeling, flatten each split into one list and keep the word count to compute word-level ppl
        for split_name in ['train', 'valid']:
            if with_labels:
                encoded_dataset[split_name + '_labels'] = labels[split_name]
            else:
                encoded_dataset[split_name] = [ind for line in encoded_dataset[split_name] for ind in line]
                encoded_dataset[split_name + '_num_words'] = sum(len(line.split(' ')) for line in dataset[split_name])

        # Save to cache
        if dataset_cache:
            logger.info("Save encoded dataset to cache at %s", dataset_cache)
            torch.save(encoded_dataset, dataset_cache)

    return encoded_dataset
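A hypothetical call of the function above, assuming the surrounding script defines DATASETS_URL, DATASETS_LABELS_URL and DATASETS_LABELS_CONVERSION and uses a BERT tokenizer from pytorch_pretrained_bert; the cache file name is only an example. On the first run the dataset is downloaded, tokenized, encoded and saved with torch.save; every later run with the same cache file goes straight through the torch.load branch.

from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# First call: download + tokenize + encode, then cache with torch.save.
# Later calls with the same dataset_cache only hit torch.load.
dataset = get_and_tokenize_dataset(tokenizer,
                                   dataset_dir='wikitext-103',
                                   dataset_cache='wikitext-103.bert-base-cased.cache')
print(len(dataset['train']), dataset['train_num_words'])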