You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

iterators.py 14 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in iterators.
  16. """
  17. from abc import abstractmethod
  18. import copy
  19. import weakref
  20. import numpy as np
  21. from mindspore.common.tensor import Tensor
  22. from mindspore._c_dataengine import DEPipeline
  23. from mindspore._c_dataengine import OpName
  24. from mindspore import log as logger
  25. from . import datasets as de
  26. _ITERATOR_CLEANUP = False
  27. def _set_iterator_cleanup():
  28. global _ITERATOR_CLEANUP
  29. _ITERATOR_CLEANUP = True
  30. def _unset_iterator_cleanup():
  31. global _ITERATOR_CLEANUP
  32. _ITERATOR_CLEANUP = False
  33. def check_iterator_cleanup():
  34. global _ITERATOR_CLEANUP
  35. return _ITERATOR_CLEANUP
  36. ITERATORS_LIST = list()
  37. def _cleanup():
  38. """Release all the Iterator."""
  39. _set_iterator_cleanup()
  40. for itr_ref in ITERATORS_LIST:
  41. itr = itr_ref()
  42. if itr is not None:
  43. itr.release()
  44. def alter_tree(node):
  45. """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes."""
  46. if not node.children:
  47. return _alter_node(node)
  48. converted_children = []
  49. for input_op in node.children:
  50. converted_children.append(alter_tree(input_op))
  51. node.children = converted_children
  52. return _alter_node(node)
  53. def _alter_node(node):
  54. """DEPRECATED"""
  55. # Please check ccsrc/dataset/engine/opt for tree transformation.
  56. if isinstance(node, de.MapDataset):
  57. if node.python_multiprocessing:
  58. # Bootstrap can only be performed on a copy of the original dataset node.
  59. # Bootstrap on original dataset node will make all iterators share the same process pool
  60. node.iterator_bootstrap()
  61. return node
  62. class Iterator:
  63. """
  64. General Iterator over a dataset.
  65. Attributes:
  66. dataset: Dataset to be iterated over
  67. """
  68. def __init__(self, dataset, num_epochs=-1, output_numpy=False):
  69. self.num_epochs = num_epochs
  70. self.output_numpy = output_numpy
  71. ITERATORS_LIST.append(weakref.ref(self))
  72. _unset_iterator_cleanup()
  73. # create a copy of tree and work on it.
  74. self.dataset = copy.deepcopy(dataset)
  75. self.ori_dataset = dataset
  76. self.parent_subtree = []
  77. # The dataset passed into the iterator is not the root of the tree.
  78. # Trim the tree by saving the parent subtree into self.parent_subtree and
  79. # restore it after launching our c++ pipeline.
  80. if self.dataset.parent:
  81. logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
  82. self.parent_subtree = self.dataset.parent
  83. self.dataset.parent = []
  84. self.dataset = alter_tree(self.dataset)
  85. if not self.__is_tree():
  86. raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
  87. self.depipeline = DEPipeline()
  88. # for manifest temporary use
  89. self.__batch_node(self.dataset, 0)
  90. root = self.__convert_node_postorder(self.dataset)
  91. self.depipeline.AssignRootNode(root)
  92. self.depipeline.PrepareTree(self.num_epochs)
  93. self._index = 0
  94. def stop(self):
  95. """
  96. Manually terminate Python iterator instead of relying on out of scope destruction.
  97. """
  98. logger.info("terminating Python iterator. This will also terminate c++ pipeline.")
  99. if hasattr(self, 'depipeline') and self.depipeline:
  100. del self.depipeline
  101. def __is_tree_node(self, node):
  102. """Check if a node is tree node."""
  103. if not node.children:
  104. if len(node.parent) > 1:
  105. return False
  106. if len(node.parent) > 1:
  107. return False
  108. for input_node in node.children:
  109. cls = self.__is_tree_node(input_node)
  110. if not cls:
  111. return False
  112. return True
  113. def __is_tree(self):
  114. return self.__is_tree_node(self.dataset)
  115. @staticmethod
  116. def __get_dataset_type(dataset):
  117. """Get the dataset type."""
  118. op_type = None
  119. if isinstance(dataset, de.ShuffleDataset):
  120. op_type = OpName.SHUFFLE
  121. elif isinstance(dataset, de.MindDataset):
  122. op_type = OpName.MINDRECORD
  123. elif isinstance(dataset, de.BatchDataset):
  124. op_type = OpName.BATCH
  125. elif isinstance(dataset, de.BucketBatchByLengthDataset):
  126. op_type = OpName.BUCKETBATCH
  127. elif isinstance(dataset, de.SyncWaitDataset):
  128. op_type = OpName.BARRIER
  129. elif isinstance(dataset, de.ZipDataset):
  130. op_type = OpName.ZIP
  131. elif isinstance(dataset, de.ConcatDataset):
  132. op_type = OpName.CONCAT
  133. elif isinstance(dataset, de.MapDataset):
  134. op_type = OpName.MAP
  135. elif isinstance(dataset, de.FilterDataset):
  136. op_type = OpName.FILTER
  137. elif isinstance(dataset, de.RepeatDataset):
  138. op_type = OpName.REPEAT
  139. elif isinstance(dataset, de.SkipDataset):
  140. op_type = OpName.SKIP
  141. elif isinstance(dataset, de.TakeDataset):
  142. op_type = OpName.TAKE
  143. elif isinstance(dataset, de.ImageFolderDataset):
  144. op_type = OpName.IMAGEFOLDER
  145. elif isinstance(dataset, de.GeneratorDataset):
  146. op_type = OpName.GENERATOR
  147. elif isinstance(dataset, de.TransferDataset):
  148. op_type = OpName.DEVICEQUEUE
  149. elif isinstance(dataset, de.RenameDataset):
  150. op_type = OpName.RENAME
  151. elif isinstance(dataset, de.TFRecordDataset):
  152. op_type = OpName.TFREADER
  153. elif isinstance(dataset, de.ProjectDataset):
  154. op_type = OpName.PROJECT
  155. elif isinstance(dataset, de.MnistDataset):
  156. op_type = OpName.MNIST
  157. elif isinstance(dataset, de.ManifestDataset):
  158. op_type = OpName.MANIFEST
  159. elif isinstance(dataset, de.VOCDataset):
  160. op_type = OpName.VOC
  161. elif isinstance(dataset, de.CocoDataset):
  162. op_type = OpName.COCO
  163. elif isinstance(dataset, de.Cifar10Dataset):
  164. op_type = OpName.CIFAR10
  165. elif isinstance(dataset, de.Cifar100Dataset):
  166. op_type = OpName.CIFAR100
  167. elif isinstance(dataset, de.CelebADataset):
  168. op_type = OpName.CELEBA
  169. elif isinstance(dataset, de.RandomDataset):
  170. op_type = OpName.RANDOMDATA
  171. elif isinstance(dataset, de.TextFileDataset):
  172. op_type = OpName.TEXTFILE
  173. elif isinstance(dataset, de.BuildVocabDataset):
  174. op_type = OpName.BUILDVOCAB
  175. elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
  176. op_type = OpName.SENTENCEPIECEVOCAB
  177. elif isinstance(dataset, de.CLUEDataset):
  178. op_type = OpName.CLUE
  179. elif isinstance(dataset, de.CSVDataset):
  180. op_type = OpName.CSV
  181. else:
  182. raise ValueError("Unsupported DatasetOp")
  183. return op_type
  184. # Convert Python node into C node and add to C layer execution tree in postorder traversal.
  185. def __convert_node_postorder(self, node):
  186. self.check_node_type(node)
  187. op_type = self.__get_dataset_type(node)
  188. c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
  189. for py_child in node.children:
  190. c_child = self.__convert_node_postorder(py_child)
  191. self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
  192. return c_nodes["top"]
  193. def __batch_node(self, dataset, level):
  194. """Recursively get batch node in the dataset tree."""
  195. if isinstance(dataset, de.BatchDataset):
  196. return
  197. for input_op in dataset.children:
  198. self.__batch_node(input_op, level + 1)
  199. @staticmethod
  200. def __print_local(dataset, level):
  201. """Recursively print the name and address of nodes in the dataset tree."""
  202. name = dataset.__class__.__name__
  203. ptr = hex(id(dataset))
  204. for _ in range(level):
  205. logger.info("\t", end='')
  206. if not dataset.children:
  207. logger.info("-%s (%s)", name, ptr)
  208. else:
  209. logger.info("+%s (%s)", name, ptr)
  210. for input_op in dataset.children:
  211. Iterator.__print_local(input_op, level + 1)
  212. def print(self):
  213. """Print the dataset tree"""
  214. self.__print_local(self.dataset, 0)
  215. def release(self):
  216. if hasattr(self, 'depipeline') and self.depipeline:
  217. del self.depipeline
  218. @abstractmethod
  219. def get_next(self):
  220. raise RuntimeError("Calling base class Iterator's get_next is invalid.")
  221. def __next__(self):
  222. if not self.depipeline:
  223. logger.warning("Iterator does not have a running c++ pipeline." +
  224. "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
  225. raise RuntimeError("Iterator does not have a running c++ pipeline.")
  226. data = self.get_next()
  227. if not data:
  228. if self._index == 0:
  229. logger.warning("No records available.")
  230. if self.ori_dataset.dataset_size is None:
  231. self.ori_dataset.dataset_size = self._index
  232. raise StopIteration
  233. self._index += 1
  234. return data
  235. @abstractmethod
  236. def check_node_type(self, node):
  237. pass
  238. def get_output_shapes(self):
  239. return [t for t in self.depipeline.GetOutputShapes()]
  240. def get_output_types(self):
  241. return [t for t in self.depipeline.GetOutputTypes()]
  242. def get_dataset_size(self):
  243. return self.depipeline.GetDatasetSize()
  244. def get_batch_size(self):
  245. return self.depipeline.GetBatchSize()
  246. def get_repeat_count(self):
  247. return self.depipeline.GetRepeatCount()
  248. def num_classes(self):
  249. return self.depipeline.GetNumClasses()
  250. def get_col_names(self):
  251. return self.depipeline.GetColumnNames()
  252. def __deepcopy__(self, memo):
  253. return self
  254. class SaveOp(Iterator):
  255. """
  256. The derived class of Iterator with dict type.
  257. """
  258. def __init__(self, dataset, num_epochs=-1):
  259. super().__init__(dataset, num_epochs)
  260. self.depipeline.LaunchTreeExec()
  261. def get_next(self):
  262. pass
  263. def check_node_type(self, node):
  264. if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
  265. logger.warning("Used shuffle, repeat, batch before save operator.")
  266. def save(self, file_names, file_type):
  267. return self.depipeline.SaveDataset(file_names, file_type)
  268. class DictIterator(Iterator):
  269. """
  270. The derived class of Iterator with dict type.
  271. """
  272. def __init__(self, dataset, num_epochs=-1, output_numpy=False):
  273. super().__init__(dataset, num_epochs, output_numpy)
  274. self.depipeline.LaunchTreeExec()
  275. def check_node_type(self, node):
  276. pass
  277. def __iter__(self):
  278. return self
  279. def get_next(self):
  280. """
  281. Returns the next record in the dataset as dictionary
  282. Returns:
  283. Dict, the next record in the dataset.
  284. """
  285. if self.output_numpy:
  286. return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
  287. return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()}
  288. class TupleIterator(Iterator):
  289. """
  290. The derived class of Iterator with list type.
  291. """
  292. def check_node_type(self, node):
  293. pass
  294. def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False):
  295. if columns is not None:
  296. if not isinstance(columns, list):
  297. columns = [columns]
  298. dataset = dataset.project(columns)
  299. super().__init__(dataset, num_epochs, output_numpy)
  300. self.depipeline.LaunchTreeExec()
  301. def __iter__(self):
  302. return self
  303. def get_next(self):
  304. """
  305. Returns the next record in the dataset as a list
  306. Returns:
  307. List, the next record in the dataset.
  308. """
  309. if self.output_numpy:
  310. return [t.as_array() for t in self.depipeline.GetNextAsList()]
  311. return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()]
  312. class DummyIterator():
  313. """
  314. A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
  315. """
  316. def __init__(self, dataset, mode):
  317. self.mode = mode
  318. self.shapes = dataset.output_shapes()
  319. self.types = dataset.output_types()
  320. self.fetched_first = False
  321. def __get_tensor(self):
  322. tensor_row = []
  323. for np_shape, np_type in zip(self.shapes, self.types):
  324. input_np = np.zeros(np_shape, np_type)
  325. tensor = Tensor(input_np)
  326. tensor_row.append(tensor)
  327. return tensor_row
  328. def __iter__(self):
  329. return self
  330. def __next__(self):
  331. if self.mode == "tuple":
  332. if not self.fetched_first:
  333. self.fetched_first = True
  334. return self.__get_tensor()
  335. raise StopIteration()