# -*- coding: utf-8 -*-
import gzip
import os
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r""":class:`~.Dataset` for MNIST meta data."""

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    Url prefix for downloading raw file.
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    Raw file names of both training set and test set (10k).
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    Md5 for checking raw files.
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: path for mnist dataset downloading or loading, if ``None``,
            set ``root`` to the ``_default_root``.
        :param train: if ``True``, load the training set, else load the test set.
        :param download: if the raw files do not exist and ``download`` is
            ``True``, download the raw files and process them, otherwise raise
            ValueError. Default is True.
        :param timeout: stored on the instance as ``self.timeout``; the methods
            of this class do not read it.
        :raises ValueError: if ``root`` does not exist or does not contain the
            raw files while ``download`` is ``False``.
        """
        super().__init__(root, order=("image", "image_category"))

        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                # exist_ok: another process may create the directory between
                # the check above and this call.
                os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    # exist_ok: tolerate a concurrent creator (see above).
                    os.makedirs(self.root, exist_ok=True)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        """Return the ``index``-th entry of every array in ``self.arrays``."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the length of the first array in ``self.arrays``."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default dataset directory: the default dataset root joined with the class name."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Header information recorded by :meth:`process` for the loaded files."""
        return self._meta_data

    def _check_raw_files(self):
        """Return ``True`` if every file in ``raw_file_name`` exists under ``self.root``.

        Only existence is checked here; file contents are not inspected.
        """
        return all(
            os.path.exists(os.path.join(self.root, path))
            for path in self.raw_file_name
        )

    def download(self):
        """Fetch every raw file into ``self.root`` via ``load_raw_data_from_url``."""
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)

    def process(self, train):
        """Parse the raw files of the selected split into ``self.arrays``.

        :param train: if true, read the first two entries of ``raw_file_name``
            (training images/labels), otherwise the last two (test images/labels).

        Sets ``self._meta_data`` to the parsed headers and ``self.arrays`` to
        ``(images, labels)`` with the labels cast to ``np.int32``.
        """
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")

        # raw_file_name is ordered [train images, train labels, test images, test labels]
        first = 0 if train else 2
        meta_data_images, images = parse_idx3(
            os.path.join(self.root, self.raw_file_name[first])
        )
        meta_data_labels, labels = parse_idx1(
            os.path.join(self.root, self.raw_file_name[first + 1])
        )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))


def parse_idx3(idx3_file):
    """Parse a gzipped idx3 file into its header and a list of images.

    :param idx3_file: path of the file; must end with ``.gz``.
    :return: ``(meta_data, images)`` where ``meta_data`` holds the four
        big-endian int32 header fields (``magic``, ``imgs``, ``height``,
        ``width``) and ``images`` is a list of ``uint8`` arrays of shape
        ``(height, width, 1)``, one per record found after the header.
    """
    # parse idx3 file to meta data and data in numpy array (images)
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    # memoryview slices without copying the whole decompressed payload
    payload = memoryview(bin_data)[offset:]
    images = []
    # context manager closes the progress bar even if unpacking raises
    with tqdm(total=meta_data["imgs"], ncols=80) as bar:
        for image in struct.iter_unpack(fmt_image, payload):
            images.append(
                np.array(image, dtype=np.uint8).reshape((height, width, 1))
            )
            bar.update()
    return meta_data, images


def parse_idx1(idx1_file):
    """Parse a gzipped idx1 file into its header and an array of labels.

    :param idx1_file: path of the file; must end with ``.gz``.
    :return: ``(meta_data, labels)`` where ``meta_data`` holds the two
        big-endian int32 header fields (``magic``, ``imgs``) and ``labels`` is
        a 1-D integer array with ``imgs`` entries.
    :raises ValueError: if the number of label bytes after the header differs
        from the count declared in the header.
    """
    # parse idx1 file to meta data and data in numpy array (labels)
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels
    offset += struct.calcsize(fmt_header)
    fmt_image = ">B"
    # memoryview slices without copying the whole decompressed payload
    payload = memoryview(bin_data)[offset:]
    # The output buffer is allocated uninitialised from the header count, so a
    # short payload would leave garbage labels and a long one would overrun it.
    if len(payload) != imgs:
        raise ValueError(
            "idx1 file %s declares %d labels but contains %d"
            % (idx1_file, imgs, len(payload))
        )
    labels = np.empty(imgs, dtype=int)
    # context manager closes the progress bar even if unpacking raises
    with tqdm(total=meta_data["imgs"], ncols=80) as bar:
        for i, label in enumerate(struct.iter_unpack(fmt_image, payload)):
            labels[i] = label[0]
            bar.update()
    return meta_data, labels