
mnist.py 8.2 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gzip
import os
import pickle
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r"""``Dataset`` for MNIST meta data."""

    url_path = "http://yann.lecun.com/exdb/mnist/"
    """
    URL prefix for downloading raw files.
    """

    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    """
    Raw file names of both the training set and the test set (10k).
    """

    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]
    """
    MD5 checksums for verifying the raw files.
    """

    train_file = "train.pkl"
    """
    Default pickle file name of the training set and its meta data.
    """

    test_file = "test.pkl"
    """
    Default pickle file name of the test set and its meta data.
    """

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        r"""
        Initialization:

        1. check the root path and choose the target file (train or test)
        2. check whether the target file exists:

           * if it exists, load the pickle file as the meta data and data of
             the MNIST dataset
           * else, if ``download`` is set:

             a. download all raw data (both training and test sets) by URL
             b. process the raw data (idx3/idx1 -> dict (meta data), numpy.ndarray (data))
             c. save the meta data and data as pickle files
             d. load the pickle file as the meta data and data of the MNIST dataset

        :param root: path for downloading or loading the MNIST dataset; if
            ``None``, ``root`` is set to ``_default_root``
        :param train: if ``True``, load the training set, else load the test set
        :param download: if the target files do not exist and ``download`` is
            ``True``, download and process the raw files, then load them;
            otherwise raise ``ValueError``. Default: ``True``
        :param timeout: timeout in seconds for downloading each raw file.
            Default: 500
        """
        super().__init__(root, order=("image", "image_category"))
        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                raise ValueError("dir %s does not exist" % self.root)

        # choose the target pickle file
        if train:
            self.target_file = os.path.join(self.root, self.train_file)
        else:
            self.target_file = os.path.join(self.root, self.test_file)

        # if the target pickle file exists, load it regardless of what
        # ``download`` is set to
        if os.path.exists(self.target_file):
            self._meta_data, self.arrays = self._load_file(self.target_file)
        elif self._check_raw_files():
            self.process()
            self._meta_data, self.arrays = self._load_file(self.target_file)
        else:
            if download:
                self.download()
                self._meta_data, self.arrays = self._load_file(self.target_file)
            else:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % self.target_file
                )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        return self._meta_data

    def _load_file(self, target_file):
        with open(target_file, "rb") as f:
            return pickle.load(f)

    def _check_raw_files(self):
        return all(
            os.path.exists(os.path.join(self.root, path))
            for path in self.raw_file_name
        )

    def download(self):
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
        self.process()

    def process(self):
        # load raw files and transform them into meta data and
        # datasets (tuples of np.ndarray)
        logger.info("process raw data ...")
        meta_data_images_train, images_train = parse_idx3(
            os.path.join(self.root, self.raw_file_name[0])
        )
        meta_data_labels_train, labels_train = parse_idx1(
            os.path.join(self.root, self.raw_file_name[1])
        )
        meta_data_images_test, images_test = parse_idx3(
            os.path.join(self.root, self.raw_file_name[2])
        )
        meta_data_labels_test, labels_test = parse_idx1(
            os.path.join(self.root, self.raw_file_name[3])
        )

        meta_data_train = {
            "images": meta_data_images_train,
            "labels": meta_data_labels_train,
        }
        meta_data_test = {
            "images": meta_data_images_test,
            "labels": meta_data_labels_test,
        }
        dataset_train = (images_train, labels_train)
        dataset_test = (images_test, labels_test)

        # save both the training set and the test set as pickle files
        with open(os.path.join(self.root, self.train_file), "wb") as f:
            pickle.dump((meta_data_train, dataset_train), f, pickle.HIGHEST_PROTOCOL)
        with open(os.path.join(self.root, self.test_file), "wb") as f:
            pickle.dump((meta_data_test, dataset_test), f, pickle.HIGHEST_PROTOCOL)


def parse_idx3(idx3_file):
    # parse an idx3 file into meta data and image data as numpy arrays
    logger.debug("parse idx3 file %s ..." % idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = []
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for image in struct.iter_unpack(fmt_image, bin_data[offset:]):
        images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1)))
        bar.update()
    bar.close()
    return meta_data, images


def parse_idx1(idx1_file):
    # parse an idx1 file into meta data and label data as a numpy array
    logger.debug("parse idx1 file %s ..." % idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels
    offset += struct.calcsize(fmt_header)
    fmt_label = ">B"
    labels = np.empty(imgs, dtype=int)
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for i, label in enumerate(struct.iter_unpack(fmt_label, bin_data[offset:])):
        labels[i] = label[0]
        bar.update()
    bar.close()
    return meta_data, labels
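For reference, `parse_idx3` and `parse_idx1` rely on the IDX layout: a big-endian integer header followed by raw unsigned bytes. A minimal self-contained sketch of the header decoding (using synthetic in-memory bytes rather than a real MNIST file):

import struct

# An idx3 header is four big-endian int32 values: magic, image count, rows, cols.
# 2051 is the magic number of idx3 image files; the other values match MNIST.
header = struct.pack(">iiii", 2051, 60000, 28, 28)
magic, imgs, height, width = struct.unpack_from(">iiii", header, 0)
assert (magic, imgs, height, width) == (2051, 60000, 28, 28)
# Each image then follows as height * width unsigned bytes, which is why
# parse_idx3 unpacks with ">784B" and reshapes each record to (28, 28, 1).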
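And a minimal usage sketch, assuming this module ships as `megengine.data.dataset.MNIST` (its location in the MegEngine package) and that the download URLs above are reachable:

from megengine.data.dataset import MNIST

# Downloads and processes the raw files on first use, then loads train.pkl.
train_set = MNIST(root=None, train=True, download=True)
print(len(train_set))        # 60000
image, label = train_set[0]  # (28, 28, 1) uint8 image and its integer label
print(image.shape, label)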

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU installed along with working drivers. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
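A quick way to verify that the bundled CUDA environment can actually see a GPU is sketched below, assuming `megengine.is_cuda_available()`, which current MegEngine releases expose:

import megengine

# True only if MegEngine can initialize a CUDA device on this machine.
print(megengine.is_cuda_available())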
