深度学习

AI 专属数据库的定制

2018-06-28  本文已影响9人  水之心

我们知道诸如 KerasMXNetTensorflow 各大平台都封装了自己的基础数据集,如 MNISTcifar 等。如果我们要在不同平台使用这些数据集,还需要了解它们是如何组织这些数据集的,需要花费一些不必要的时间学习它们的 API。为此,我们为何不创建属于自己的数据集呢?下面我仅仅使用了 Numpy 来实现数据集 MNISTFashion MNISTCifa 10Cifar 100 的封装。

API 介绍

环境搭建

我使用了 Anaconda 这个十分好用的包管理工具。我们需要载入一些必须包:

import struct
import numpy as np
import gzip, tarfile
import os
import pickle
import time

一个十分好用的 python 结构,具体可参考:好用的 Bunch

class Bunch(dict):
    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        self.__dict__ = self

下载数据集

链接:

MNIST

class MNIST:

    def __init__(self, root, namespace, train=True, transform=None):
        """
        (MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist)
        (A dataset of Zalando's article images consisting of fashion products,
        a drop-in replacement of the original MNIST dataset from https://github.com/zalandoresearch/fashion-mnist)

        Each sample is an image (in 3D NDArray) with shape (28, 28, 1).

        Parameters
        ----------
        root : 数据根目录,如 'E:/Data/Zip/'
        namespace : 'mnist' or 'fashion_mnist'
        train : bool, default True
            Whether to load the training or testing set.
        transform : function, default None
            A user defined callback that transforms each sample. For example:
        ::

            transform=lambda data, label: (data.astype(np.float32)/255, label)
        """
        self._train = train
        self.namespace = namespace
        root = root + namespace
        self._train_data = f'{root}/train-images-idx3-ubyte.gz'
        self._train_label = f'{root}/train-labels-idx1-ubyte.gz'
        self._test_data = f'{root}/t10k-images-idx3-ubyte.gz'
        self._test_label = f'{root}/t10k-labels-idx1-ubyte.gz'
        self._get_data()

    def _get_data(self):
        '''
        官方网站的数据是以 `[offset][type][value][description]` 的格式封装的,因而 `struct.unpack` 时需要注意
        '''
        if self._train:
            data, label = self._train_data, self._train_label
        else:
            data, label = self._test_data, self._test_label

        with gzip.open(label, 'rb') as fin:
            struct.unpack(">II", fin.read(8))
            self.label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)

        with gzip.open(data, 'rb') as fin:
            Y = struct.unpack(">IIII", fin.read(16))
            data = np.frombuffer(fin.read(), dtype=np.uint8)
            self.data = data.reshape(Y[1:])

Cifar

上一篇 下一篇

猜你喜欢

热点阅读