You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

iterators.py 9.4 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in iterators.
  16. """
  17. from abc import abstractmethod
  18. import copy
  19. import weakref
  20. from mindspore._c_dataengine import DEPipeline
  21. from mindspore._c_dataengine import OpName
  22. from mindspore import log as logger
  23. from . import datasets as de
  24. ITERATORS_LIST = list()
  25. def _cleanup():
  26. """Release all the Iterator."""
  27. for itr_ref in ITERATORS_LIST:
  28. itr = itr_ref()
  29. if itr is not None:
  30. itr.release()
  31. def alter_tree(node):
  32. """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
  33. if not node.children:
  34. return _alter_node(node)
  35. converted_children = []
  36. for input_op in node.children:
  37. converted_children.append(alter_tree(input_op))
  38. node.children = converted_children
  39. return _alter_node(node)
  40. def _alter_node(node):
  41. """DEPRECATED"""
  42. # Please check ccsrc/dataset/engine/opt for tree transformation.
  43. if isinstance(node, de.MapDataset):
  44. if node.python_multiprocessing:
  45. # Bootstrap can only be performed on a copy of the original dataset node.
  46. # Bootstrap on original dataset node will make all iterators share the same process pool
  47. node.iterator_bootstrap()
  48. return node
  49. class Iterator:
  50. """
  51. General Iterator over a dataset.
  52. Attributes:
  53. dataset: Dataset to be iterated over
  54. """
  55. def __init__(self, dataset):
  56. ITERATORS_LIST.append(weakref.ref(self))
  57. # create a copy of tree and work on it.
  58. self.dataset = copy.deepcopy(dataset)
  59. self.dataset = alter_tree(self.dataset)
  60. if not self.__is_tree():
  61. raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
  62. self.depipeline = DEPipeline()
  63. # for manifest temporary use
  64. self.__batch_node(self.dataset, 0)
  65. root = self.__convert_node_postorder(self.dataset)
  66. self.depipeline.AssignRootNode(root)
  67. self.depipeline.LaunchTreeExec()
  68. self._index = 0
  69. def __is_tree_node(self, node):
  70. """Check if a node is tree node."""
  71. if not node.children:
  72. if len(node.parent) > 1:
  73. return False
  74. if len(node.parent) > 1:
  75. return False
  76. for input_node in node.children:
  77. cls = self.__is_tree_node(input_node)
  78. if not cls:
  79. return False
  80. return True
  81. def __is_tree(self):
  82. return self.__is_tree_node(self.dataset)
  83. @staticmethod
  84. def __get_dataset_type(dataset):
  85. """Get the dataset type."""
  86. op_type = None
  87. if isinstance(dataset, de.ShuffleDataset):
  88. op_type = OpName.SHUFFLE
  89. elif isinstance(dataset, de.MindDataset):
  90. op_type = OpName.MINDRECORD
  91. elif isinstance(dataset, de.BatchDataset):
  92. op_type = OpName.BATCH
  93. elif isinstance(dataset, de.BucketBatchByLengthDataset):
  94. op_type = OpName.BUCKETBATCH
  95. elif isinstance(dataset, de.SyncWaitDataset):
  96. op_type = OpName.BARRIER
  97. elif isinstance(dataset, de.ZipDataset):
  98. op_type = OpName.ZIP
  99. elif isinstance(dataset, de.ConcatDataset):
  100. op_type = OpName.CONCAT
  101. elif isinstance(dataset, de.MapDataset):
  102. op_type = OpName.MAP
  103. elif isinstance(dataset, de.FilterDataset):
  104. op_type = OpName.FILTER
  105. elif isinstance(dataset, de.RepeatDataset):
  106. op_type = OpName.REPEAT
  107. elif isinstance(dataset, de.SkipDataset):
  108. op_type = OpName.SKIP
  109. elif isinstance(dataset, de.TakeDataset):
  110. op_type = OpName.TAKE
  111. elif isinstance(dataset, de.ImageFolderDatasetV2):
  112. op_type = OpName.IMAGEFOLDER
  113. elif isinstance(dataset, de.GeneratorDataset):
  114. op_type = OpName.GENERATOR
  115. elif isinstance(dataset, de.TransferDataset):
  116. op_type = OpName.DEVICEQUEUE
  117. elif isinstance(dataset, de.RenameDataset):
  118. op_type = OpName.RENAME
  119. elif isinstance(dataset, de.TFRecordDataset):
  120. op_type = OpName.TFREADER
  121. elif isinstance(dataset, de.ProjectDataset):
  122. op_type = OpName.PROJECT
  123. elif isinstance(dataset, de.MnistDataset):
  124. op_type = OpName.MNIST
  125. elif isinstance(dataset, de.ManifestDataset):
  126. op_type = OpName.MANIFEST
  127. elif isinstance(dataset, de.VOCDataset):
  128. op_type = OpName.VOC
  129. elif isinstance(dataset, de.CocoDataset):
  130. op_type = OpName.COCO
  131. elif isinstance(dataset, de.Cifar10Dataset):
  132. op_type = OpName.CIFAR10
  133. elif isinstance(dataset, de.Cifar100Dataset):
  134. op_type = OpName.CIFAR100
  135. elif isinstance(dataset, de.CelebADataset):
  136. op_type = OpName.CELEBA
  137. elif isinstance(dataset, de.RandomDataset):
  138. op_type = OpName.RANDOMDATA
  139. elif isinstance(dataset, de.TextFileDataset):
  140. op_type = OpName.TEXTFILE
  141. elif isinstance(dataset, de.BuildVocabDataset):
  142. op_type = OpName.BUILDVOCAB
  143. elif isinstance(dataset, de.CLUEDataset):
  144. op_type = OpName.CLUE
  145. else:
  146. raise ValueError("Unsupported DatasetOp")
  147. return op_type
  148. # Convert python node into C node and add to C layer execution tree in postorder traversal.
  149. def __convert_node_postorder(self, node):
  150. op_type = self.__get_dataset_type(node)
  151. c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
  152. for py_child in node.children:
  153. c_child = self.__convert_node_postorder(py_child)
  154. self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
  155. return c_nodes["top"]
  156. def __batch_node(self, dataset, level):
  157. """Recursively get batch node in the dataset tree."""
  158. if isinstance(dataset, de.BatchDataset):
  159. return
  160. for input_op in dataset.children:
  161. self.__batch_node(input_op, level + 1)
  162. @staticmethod
  163. def __print_local(dataset, level):
  164. """Recursively print the name and address of nodes in the dataset tree."""
  165. name = dataset.__class__.__name__
  166. ptr = hex(id(dataset))
  167. for _ in range(level):
  168. logger.info("\t", end='')
  169. if not dataset.children:
  170. logger.info("-%s (%s)", name, ptr)
  171. else:
  172. logger.info("+%s (%s)", name, ptr)
  173. for input_op in dataset.children:
  174. Iterator.__print_local(input_op, level + 1)
  175. def print(self):
  176. """Print the dataset tree"""
  177. self.__print_local(self.dataset, 0)
  178. def release(self):
  179. if hasattr(self, 'depipeline') and self.depipeline:
  180. del self.depipeline
  181. @abstractmethod
  182. def get_next(self):
  183. pass
  184. def __next__(self):
  185. data = self.get_next()
  186. if not data:
  187. if self._index == 0:
  188. logger.warning("No records available.")
  189. raise StopIteration
  190. self._index += 1
  191. return data
  192. def get_output_shapes(self):
  193. return [t for t in self.depipeline.GetOutputShapes()]
  194. def get_output_types(self):
  195. return [t for t in self.depipeline.GetOutputTypes()]
  196. def get_dataset_size(self):
  197. return self.depipeline.GetDatasetSize()
  198. def get_batch_size(self):
  199. return self.depipeline.GetBatchSize()
  200. def get_repeat_count(self):
  201. return self.depipeline.GetRepeatCount()
  202. def num_classes(self):
  203. return self.depipeline.GetNumClasses()
  204. def __deepcopy__(self, memo):
  205. return self
  206. class DictIterator(Iterator):
  207. """
  208. The derived class of Iterator with dict type.
  209. """
  210. def __iter__(self):
  211. return self
  212. def get_next(self):
  213. """
  214. Returns the next record in the dataset as dictionary
  215. Returns:
  216. Dict, the next record in the dataset.
  217. """
  218. return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
  219. class TupleIterator(Iterator):
  220. """
  221. The derived class of Iterator with list type.
  222. """
  223. def __init__(self, dataset, columns=None):
  224. if columns is not None:
  225. if not isinstance(columns, list):
  226. columns = [columns]
  227. dataset = dataset.project(columns)
  228. super().__init__(dataset)
  229. def __iter__(self):
  230. return self
  231. def get_next(self):
  232. """
  233. Returns the next record in the dataset as a list
  234. Returns:
  235. List, the next record in the dataset.
  236. """
  237. return [t.as_array() for t in self.depipeline.GetNextAsList()]