
mnist.py

# -*- coding: utf-8 -*-
import gzip
import os
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r""":class:`~.Dataset` for MNIST meta data."""

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    URL prefix for downloading the raw files.
    """
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    Raw file names of both the training set and the test set (10k).
    """
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    MD5 checksums for verifying the raw files.
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        :param root: path for downloading or loading the MNIST dataset; if ``None``,
            set ``root`` to the ``_default_root``.
        :param train: if ``True``, load the training set, otherwise the test set.
        :param download: if the raw files do not exist and ``download`` is ``True``,
            download and process them; otherwise raise ValueError. Default: ``True``.
        :param timeout: timeout for downloading raw files; default is 500.
        """
        super().__init__(root, order=("image", "image_category"))
        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        return self._meta_data

    def _check_raw_files(self):
        return all(
            [
                os.path.exists(os.path.join(self.root, path))
                for path in self.raw_file_name
            ]
        )

    def download(self):
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)

    def process(self, train):
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")
        if train:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[0])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[1])
            )
        else:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[2])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[3])
            )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))


def parse_idx3(idx3_file):
    # parse an idx3 file into meta data and a list of numpy arrays (images)
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = []
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for image in struct.iter_unpack(fmt_image, bin_data[offset:]):
        images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1)))
        bar.update()
    bar.close()
    return meta_data, images
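

# Illustrative sketch, not part of the original file: the per-image loop in
# ``parse_idx3`` can also be written as a single vectorized read, assuming the
# same 16-byte big-endian IDX3 header parsed above. ``_parse_idx3_vectorized``
# is a hypothetical helper name used only for this example.
def _parse_idx3_vectorized(idx3_file):
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()
    header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(header, bin_data, 0)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}
    # read all pixel bytes at once and reshape to (imgs, height, width, 1)
    pixels = np.frombuffer(bin_data, dtype=np.uint8, offset=struct.calcsize(header))
    images = pixels.reshape(imgs, height, width, 1).copy()
    return meta_data, list(images)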


def parse_idx1(idx1_file):
    # parse idx1 file to meta data and data in numpy array (labels)
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels
    offset += struct.calcsize(fmt_header)
    fmt_image = ">B"
    labels = np.empty(imgs, dtype=int)
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for i, label in enumerate(struct.iter_unpack(fmt_image, bin_data[offset:])):
        labels[i] = label[0]
        bar.update()
    bar.close()
    return meta_data, labels
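

# Illustrative usage sketch, not part of the original file: a throwaway helper
# showing how user code typically constructs and indexes the dataset described
# by the class docstring above. It is defined but never called here; invoking it
# would download the raw MNIST files into the default root on first use.
def _example_usage():
    train_set = MNIST(root=None, train=True, download=True)
    print("training samples:", len(train_set))   # 60000 in the MNIST training set
    image, label = train_set[0]                  # uint8 image of shape (28, 28, 1), int32 label
    print("image shape:", image.shape, "label:", int(label))
    print("image meta data:", train_set.meta["images"])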