You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

iterators.py 9.8 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in iterators.
  16. """
  17. from abc import abstractmethod
  18. import copy
  19. import weakref
  20. from mindspore._c_dataengine import DEPipeline
  21. from mindspore._c_dataengine import OpName
  22. from mindspore import log as logger
  23. from . import datasets as de
  24. ITERATORS_LIST = list()
  25. def _cleanup():
  26. for itr_ref in ITERATORS_LIST:
  27. itr = itr_ref()
  28. if itr is not None:
  29. itr.release()
  30. def alter_tree(node):
  31. """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
  32. if not node.input:
  33. return _alter_node(node)
  34. converted_children = []
  35. for input_op in node.input:
  36. converted_children.append(alter_tree(input_op))
  37. node.input = converted_children
  38. return _alter_node(node)
  39. def _alter_node(node):
  40. """Performing some alteration to a dataset node. A common alteration is to insert a node."""
  41. if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
  42. # Remove the connection between the parent's node to the current node because we are inserting a node.
  43. if node.output:
  44. node.output.pop()
  45. # Perform a fast scan for average rows per file
  46. if isinstance(node, de.TFRecordDataset):
  47. avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files)
  48. else:
  49. avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files)
  50. # Shuffle between 4 files with a minimum size of 10000 rows
  51. new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000))
  52. return new_shuffle
  53. if isinstance(node, de.MapDataset):
  54. if node.python_multiprocessing:
  55. # Bootstrap can only be performed on a copy of the original dataset node.
  56. # Bootstrap on original dataset node will make all iterators share the same process pool
  57. node.iterator_bootstrap()
  58. if node.columns_order is not None:
  59. # Remove the connection between the parent's node to the current node because we are inserting a node.
  60. if node.output:
  61. node.output.pop()
  62. return node.project(node.columns_order)
  63. return node
  64. class Iterator:
  65. """
  66. General Iterator over a dataset.
  67. Attributes:
  68. dataset: Dataset to be iterated over
  69. """
  70. def __init__(self, dataset):
  71. ITERATORS_LIST.append(weakref.ref(self))
  72. # create a copy of tree and work on it.
  73. self.dataset = copy.deepcopy(dataset)
  74. self.dataset = alter_tree(self.dataset)
  75. if not self.__is_tree():
  76. raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
  77. self.depipeline = DEPipeline()
  78. # for manifest temporary use
  79. self.__batch_node(self.dataset, 0)
  80. root = self.__convert_node_postorder(self.dataset)
  81. self.depipeline.AssignRootNode(root)
  82. self.depipeline.LaunchTreeExec()
  83. self._index = 0
  84. def __is_tree_node(self, node):
  85. """Check if a node is tree node."""
  86. if not node.input:
  87. if len(node.output) > 1:
  88. return False
  89. if len(node.output) > 1:
  90. return False
  91. for input_node in node.input:
  92. cls = self.__is_tree_node(input_node)
  93. if not cls:
  94. return False
  95. return True
  96. def __is_tree(self):
  97. return self.__is_tree_node(self.dataset)
  98. @staticmethod
  99. def __get_dataset_type(dataset):
  100. """Get the dataset type."""
  101. op_type = None
  102. if isinstance(dataset, de.ShuffleDataset):
  103. op_type = OpName.SHUFFLE
  104. elif isinstance(dataset, de.MindDataset):
  105. op_type = OpName.MINDRECORD
  106. elif isinstance(dataset, de.BatchDataset):
  107. op_type = OpName.BATCH
  108. elif isinstance(dataset, de.SyncWaitDataset):
  109. op_type = OpName.BARRIER
  110. elif isinstance(dataset, de.ZipDataset):
  111. op_type = OpName.ZIP
  112. elif isinstance(dataset, de.MapDataset):
  113. op_type = OpName.MAP
  114. elif isinstance(dataset, de.FilterDataset):
  115. op_type = OpName.FILTER
  116. elif isinstance(dataset, de.RepeatDataset):
  117. op_type = OpName.REPEAT
  118. elif isinstance(dataset, de.SkipDataset):
  119. op_type = OpName.SKIP
  120. elif isinstance(dataset, de.TakeDataset):
  121. op_type = OpName.TAKE
  122. elif isinstance(dataset, de.StorageDataset):
  123. op_type = OpName.STORAGE
  124. elif isinstance(dataset, de.ImageFolderDatasetV2):
  125. op_type = OpName.IMAGEFOLDER
  126. elif isinstance(dataset, de.GeneratorDataset):
  127. op_type = OpName.GENERATOR
  128. elif isinstance(dataset, de.TransferDataset):
  129. op_type = OpName.DEVICEQUEUE
  130. elif isinstance(dataset, de.RenameDataset):
  131. op_type = OpName.RENAME
  132. elif isinstance(dataset, de.TFRecordDataset):
  133. op_type = OpName.TFREADER
  134. elif isinstance(dataset, de.ProjectDataset):
  135. op_type = OpName.PROJECT
  136. elif isinstance(dataset, de.MnistDataset):
  137. op_type = OpName.MNIST
  138. elif isinstance(dataset, de.ManifestDataset):
  139. op_type = OpName.MANIFEST
  140. elif isinstance(dataset, de.VOCDataset):
  141. op_type = OpName.VOC
  142. elif isinstance(dataset, de.Cifar10Dataset):
  143. op_type = OpName.CIFAR10
  144. elif isinstance(dataset, de.Cifar100Dataset):
  145. op_type = OpName.CIFAR100
  146. elif isinstance(dataset, de.CelebADataset):
  147. op_type = OpName.CELEBA
  148. elif isinstance(dataset, de.TextFileDataset):
  149. op_type = OpName.TEXTFILE
  150. else:
  151. raise ValueError("Unsupported DatasetOp")
  152. return op_type
  153. # Convert python node into C node and add to C layer execution tree in postorder traversal.
  154. def __convert_node_postorder(self, node):
  155. op_type = self.__get_dataset_type(node)
  156. c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
  157. for py_child in node.input:
  158. c_child = self.__convert_node_postorder(py_child)
  159. self.depipeline.AddChildToParentNode(c_child, c_node)
  160. return c_node
  161. def __batch_node(self, dataset, level):
  162. """Recursively get batch node in the dataset tree."""
  163. if isinstance(dataset, de.BatchDataset):
  164. return
  165. for input_op in dataset.input:
  166. self.__batch_node(input_op, level + 1)
  167. @staticmethod
  168. def __print_local(dataset, level):
  169. """Recursively print the name and address of nodes in the dataset tree."""
  170. name = dataset.__class__.__name__
  171. ptr = hex(id(dataset))
  172. for _ in range(level):
  173. logger.info("\t", end='')
  174. if not dataset.input:
  175. logger.info("-%s (%s)", name, ptr)
  176. else:
  177. logger.info("+%s (%s)", name, ptr)
  178. for input_op in dataset.input:
  179. Iterator.__print_local(input_op, level + 1)
  180. def print(self):
  181. """Print the dataset tree"""
  182. self.__print_local(self.dataset, 0)
  183. def release(self):
  184. if hasattr(self, 'depipeline') and self.depipeline:
  185. del self.depipeline
  186. @abstractmethod
  187. def get_next(self):
  188. pass
  189. def __next__(self):
  190. data = self.get_next()
  191. if not data:
  192. if self._index == 0:
  193. logger.warning("No records available.")
  194. raise StopIteration
  195. self._index += 1
  196. return data
  197. def get_output_shapes(self):
  198. return [t for t in self.depipeline.GetOutputShapes()]
  199. def get_output_types(self):
  200. return [t for t in self.depipeline.GetOutputTypes()]
  201. def get_dataset_size(self):
  202. return self.depipeline.GetDatasetSize()
  203. def get_batch_size(self):
  204. return self.depipeline.GetBatchSize()
  205. def get_repeat_count(self):
  206. return self.depipeline.GetRepeatCount()
  207. def num_classes(self):
  208. return self.depipeline.GetNumClasses()
  209. def __deepcopy__(self, memo):
  210. return self
  211. class DictIterator(Iterator):
  212. """
  213. The derived class of Iterator with dict type.
  214. """
  215. def __iter__(self):
  216. return self
  217. def get_next(self):
  218. """
  219. Returns the next record in the dataset as dictionary
  220. Returns:
  221. Dict, the next record in the dataset.
  222. """
  223. return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
  224. class TupleIterator(Iterator):
  225. """
  226. The derived class of Iterator with list type.
  227. """
  228. def __init__(self, dataset, columns=None):
  229. if columns is not None:
  230. if not isinstance(columns, list):
  231. columns = [columns]
  232. dataset = dataset.project(columns)
  233. super().__init__(dataset)
  234. def __iter__(self):
  235. return self
  236. def get_next(self):
  237. """
  238. Returns the next record in the dataset as a list
  239. Returns:
  240. List, the next record in the dataset.
  241. """
  242. return [t.as_array() for t in self.depipeline.GetNextAsList()]