AI 专属数据库的定制
2018-06-28 本文已影响9人
水之心
我们知道诸如 Keras
、MXNet
、Tensorflow
各大平台都封装了自己的基础数据集,如 MNIST
、cifar
等。如果我们要在不同平台使用这些数据集,还需要了解它们是如何组织这些数据集的,需要花费一些不必要的时间学习它们的 API。为此,我们为何不创建属于自己的数据集呢?下面我仅仅使用了 Numpy
来实现数据集 MNIST
、Fashion MNIST
、Cifa 10
、Cifar 100
的封装。
API 介绍
环境搭建
我使用了 Anaconda
这个十分好用的包管理工具。我们需要载入一些必须包:
import struct
import numpy as np
import gzip, tarfile
import os
import pickle
import time
一个十分好用的 python 结构,具体可参考:好用的 Bunch
class Bunch(dict):
def __init__(self, *args, **kwds):
super().__init__(*args, **kwds)
self.__dict__ = self
下载数据集
链接:
MNIST
class MNIST:
def __init__(self, root, namespace, train=True, transform=None):
"""
(MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist)
(A dataset of Zalando's article images consisting of fashion products,
a drop-in replacement of the original MNIST dataset from https://github.com/zalandoresearch/fashion-mnist)
Each sample is an image (in 3D NDArray) with shape (28, 28, 1).
Parameters
----------
root : 数据根目录,如 'E:/Data/Zip/'
namespace : 'mnist' or 'fashion_mnist'
train : bool, default True
Whether to load the training or testing set.
transform : function, default None
A user defined callback that transforms each sample. For example:
::
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
self._train = train
self.namespace = namespace
root = root + namespace
self._train_data = f'{root}/train-images-idx3-ubyte.gz'
self._train_label = f'{root}/train-labels-idx1-ubyte.gz'
self._test_data = f'{root}/t10k-images-idx3-ubyte.gz'
self._test_label = f'{root}/t10k-labels-idx1-ubyte.gz'
self._get_data()
def _get_data(self):
'''
官方网站的数据是以 `[offset][type][value][description]` 的格式封装的,因而 `struct.unpack` 时需要注意
'''
if self._train:
data, label = self._train_data, self._train_label
else:
data, label = self._test_data, self._test_label
with gzip.open(label, 'rb') as fin:
struct.unpack(">II", fin.read(8))
self.label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
with gzip.open(data, 'rb') as fin:
Y = struct.unpack(">IIII", fin.read(16))
data = np.frombuffer(fin.read(), dtype=np.uint8)
self.data = data.reshape(Y[1:])