You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cifar100.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Cifar100 reader class.
  17. """
  18. import builtins
  19. import io
  20. import pickle
  21. import os
  22. import numpy as np
  23. from ..shardutils import check_filename
  24. __all__ = ['Cifar100']
  # Allow-list of names from the ``builtins`` module that a pickle stream may
  # reference; find_class in RestrictedUnpickler checks membership in this set.
  safe_builtins = {'range', 'complex', 'set', 'frozenset', 'slice'}
  class RestrictedUnpickler(pickle.Unpickler):
      """
      Unpickler that resolves only an allow-list of globals.

      Allowed globals are: the names listed in ``safe_builtins`` from the
      ``builtins`` module, numpy's ``_reconstruct`` helper, and attributes of
      the top-level ``numpy`` namespace. Everything else is rejected.

      Raises:
          pickle.UnpicklingError: If the pickle references a forbidden global.
      """

      # Module paths that may provide ``_reconstruct``. NumPy 2.x writes
      # ``numpy._core.multiarray`` into array pickles, older releases write
      # ``numpy.core.multiarray``; both spellings are accepted.
      _multiarray_modules = ("numpy.core.multiarray", "numpy._core.multiarray")

      def find_class(self, module, name):
          """
          Resolve the global ``module.name`` requested by the pickle stream.

          Args:
              module (str): Module name recorded in the pickle.
              name (str): Attribute name recorded in the pickle.

          Returns:
              object, the resolved class or function.

          Raises:
              pickle.UnpicklingError: If the global is not on the allow-list or
                  the allowed numpy helper can not be imported.
          """
          # Only allow safe classes from builtins and numpy.
          if module == "builtins" and name in safe_builtins:
              return getattr(builtins, name)
          if module in self._multiarray_modules and name == "_reconstruct":
              # Import by the recorded path instead of going through the
              # deprecated ``np.core`` attribute.
              import importlib
              try:
                  return getattr(importlib.import_module(module), name)
              except (ImportError, AttributeError) as err:
                  raise pickle.UnpicklingError(
                      "global '%s.%s' can not be resolved" % (module, name)) from err
          if module == "numpy":
              # NOTE(review): this accepts ANY attribute of the top-level numpy
              # namespace, not just ndarray/dtype. Consider narrowing the
              # allow-list if the input files are not trusted.
              return getattr(np, name)
          # Forbid everything else.
          raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))
  def restricted_loads(s):
      """
      Deserialize pickled bytes with RestrictedUnpickler, like pickle.loads().

      Args:
          s (bytes): Pickled data.

      Returns:
          object, the unpickled value.

      Raises:
          TypeError: If s is a str instead of bytes.
      """
      if not isinstance(s, str):
          stream = io.BytesIO(s)
          # encoding='bytes' is forwarded unchanged to the Unpickler.
          unpickler = RestrictedUnpickler(stream, encoding='bytes')
          return unpickler.load()
      raise TypeError("can not load pickle from unicode string")
  class Cifar100:
      """
      Class to convert cifar100 to MindRecord.

      Cifar100 contains train & test data. There are 500 training images and
      100 testing images per class. The 100 classes in the CIFAR-100 are grouped
      into 20 superclasses.

      Args:
          path (str): cifar100 directory which contains the "train" and "test"
              binary data files.
          one_hot (bool): one_hot flag. Default: True.

      Raises:
          ValueError: If one_hot is not a bool.
      """

      class Test:
          """Holder for the test split: images, fine_labels, coarse_labels."""

      def __init__(self, path, one_hot=True):
          check_filename(path)
          self.path = path
          if not isinstance(one_hot, bool):
              raise ValueError("The parameter one_hot must be bool")
          self.one_hot = one_hot
          # Train split; filled in by load_data().
          self.images = None
          self.fine_labels = None
          self.coarse_labels = None

      def load_data(self):
          """
          Load the train and test splits found in ``self.path``.

          The train split is also stored on the instance (images, fine_labels,
          coarse_labels) and the test split on ``self.Test``. Labels are
          one-hot encoded (100 fine / 20 coarse columns) when ``self.one_hot``
          is True, otherwise they are column vectors of shape (N, 1).

          Returns:
              list, [train images, train fine labels, train coarse labels,
              test images, test fine labels, test coarse labels].

          Raises:
              FileNotFoundError: If the "train" or "test" file is missing from
                  the directory.
          """
          entries = os.listdir(self.path)
          missing = [name for name in ("train", "test") if name not in entries]
          if missing:
              # Fail early with a clear message instead of hitting an unbound
              # local variable further down.
              raise FileNotFoundError(
                  "cifar100 directory '%s' is missing required file(s): %s"
                  % (self.path, ", ".join(missing)))

          train = self._load_split("train")
          test = self._load_split("test")

          if self.one_hot:
              train = (train[0], self._one_hot(train[1], 100), self._one_hot(train[2], 20))
              test = (test[0], self._one_hot(test[1], 100), self._one_hot(test[2], 20))

          self.images, self.fine_labels, self.coarse_labels = train
          # NOTE: Test is a class attribute, so these values are shared by every
          # Cifar100 instance; kept as is for backward compatibility.
          self.Test.images, self.Test.fine_labels, self.Test.coarse_labels = test

          return [*train, *test]

      def _load_split(self, name):
          """
          Read one pickled split file from ``self.path``.

          Args:
              name (str): File name of the split, "train" or "test".

          Returns:
              tuple, (images, fine labels, coarse labels). Images are reshaped to
              (N, 3, 32, 32) and then transposed to (N, 32, 32, 3); both label
              arrays have shape (N, 1).
          """
          with open(os.path.join(self.path, name), 'rb') as f:
              raw = restricted_loads(f.read())
          # Keys are bytes because the pickle is loaded with encoding='bytes'.
          images = np.array(raw[b"data"].reshape([-1, 3, 32, 32])).transpose(0, 2, 3, 1)
          fine_labels = np.array(raw[b"fine_labels"]).reshape([-1, 1])
          coarse_labels = np.array(raw[b"coarse_labels"]).reshape([-1, 1])
          return images, fine_labels, coarse_labels

      def _one_hot(self, labels, num):
          """
          Convert integer labels to a one-hot matrix.

          Args:
              labels (numpy.ndarray): Integer labels, one row per sample.
              num (int): Number of classes, i.e. columns of the result.

          Returns:
              numpy.ndarray, float array of shape (labels.shape[0], num) with a 1
              at each sample's label column and 0 elsewhere.
          """
          size = labels.shape[0]
          label_one_hot = np.zeros([size, num])
          if size:
              # One row of column indices per sample; set them all in a single
              # vectorized assignment instead of looping over the samples.
              columns = np.reshape(labels, (size, -1))
              label_one_hot[np.arange(size)[:, None], columns] = 1
          return label_one_hot