You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

iterators.py 12 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in iterators.
  16. """
  17. from abc import abstractmethod
  18. import copy
  19. import weakref
  20. import numpy as np
  21. from mindspore.common.tensor import Tensor
  22. from mindspore._c_dataengine import DEPipeline
  23. from mindspore._c_dataengine import OpName
  24. from mindspore import log as logger
  25. from . import datasets as de
  26. ITERATORS_LIST = list()
  27. def _cleanup():
  28. """Release all the Iterator."""
  29. for itr_ref in ITERATORS_LIST:
  30. itr = itr_ref()
  31. if itr is not None:
  32. itr.release()
  33. def alter_tree(node):
  34. """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
  35. if not node.children:
  36. return _alter_node(node)
  37. converted_children = []
  38. for input_op in node.children:
  39. converted_children.append(alter_tree(input_op))
  40. node.children = converted_children
  41. return _alter_node(node)
  42. def _alter_node(node):
  43. """DEPRECATED"""
  44. # Please check ccsrc/dataset/engine/opt for tree transformation.
  45. if isinstance(node, de.MapDataset):
  46. if node.python_multiprocessing:
  47. # Bootstrap can only be performed on a copy of the original dataset node.
  48. # Bootstrap on original dataset node will make all iterators share the same process pool
  49. node.iterator_bootstrap()
  50. return node
  51. class Iterator:
  52. """
  53. General Iterator over a dataset.
  54. Attributes:
  55. dataset: Dataset to be iterated over
  56. """
  57. def __init__(self, dataset, num_epochs=-1):
  58. self.num_epochs = num_epochs
  59. ITERATORS_LIST.append(weakref.ref(self))
  60. # create a copy of tree and work on it.
  61. self.dataset = copy.deepcopy(dataset)
  62. self.ori_dataset = dataset
  63. self.parent_subtree = []
  64. # The dataset passed into the iterator is not the root of the tree.
  65. # Trim the tree by saving the parent subtree into self.parent_subtree and
  66. # restore it after launching our c++ pipeline.
  67. if self.dataset.parent:
  68. logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
  69. self.parent_subtree = self.dataset.parent
  70. self.dataset.parent = []
  71. self.dataset = alter_tree(self.dataset)
  72. if not self.__is_tree():
  73. raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
  74. self.depipeline = DEPipeline()
  75. # for manifest temporary use
  76. self.__batch_node(self.dataset, 0)
  77. root = self.__convert_node_postorder(self.dataset)
  78. self.depipeline.AssignRootNode(root)
  79. self.depipeline.LaunchTreeExec(self.num_epochs)
  80. self._index = 0
  81. def stop(self):
  82. """
  83. Manually terminate python iterator instead of relying on out of scope destruction.
  84. """
  85. logger.info("terminating python iterator. This will also terminate c++ pipeline.")
  86. if hasattr(self, 'depipeline') and self.depipeline:
  87. del self.depipeline
  88. def __is_tree_node(self, node):
  89. """Check if a node is tree node."""
  90. if not node.children:
  91. if len(node.parent) > 1:
  92. return False
  93. if len(node.parent) > 1:
  94. return False
  95. for input_node in node.children:
  96. cls = self.__is_tree_node(input_node)
  97. if not cls:
  98. return False
  99. return True
  100. def __is_tree(self):
  101. return self.__is_tree_node(self.dataset)
  102. @staticmethod
  103. def __get_dataset_type(dataset):
  104. """Get the dataset type."""
  105. op_type = None
  106. if isinstance(dataset, de.ShuffleDataset):
  107. op_type = OpName.SHUFFLE
  108. elif isinstance(dataset, de.MindDataset):
  109. op_type = OpName.MINDRECORD
  110. elif isinstance(dataset, de.BatchDataset):
  111. op_type = OpName.BATCH
  112. elif isinstance(dataset, de.BucketBatchByLengthDataset):
  113. op_type = OpName.BUCKETBATCH
  114. elif isinstance(dataset, de.SyncWaitDataset):
  115. op_type = OpName.BARRIER
  116. elif isinstance(dataset, de.ZipDataset):
  117. op_type = OpName.ZIP
  118. elif isinstance(dataset, de.ConcatDataset):
  119. op_type = OpName.CONCAT
  120. elif isinstance(dataset, de.MapDataset):
  121. op_type = OpName.MAP
  122. elif isinstance(dataset, de.FilterDataset):
  123. op_type = OpName.FILTER
  124. elif isinstance(dataset, de.RepeatDataset):
  125. op_type = OpName.REPEAT
  126. elif isinstance(dataset, de.SkipDataset):
  127. op_type = OpName.SKIP
  128. elif isinstance(dataset, de.TakeDataset):
  129. op_type = OpName.TAKE
  130. elif isinstance(dataset, de.ImageFolderDatasetV2):
  131. op_type = OpName.IMAGEFOLDER
  132. elif isinstance(dataset, de.GeneratorDataset):
  133. op_type = OpName.GENERATOR
  134. elif isinstance(dataset, de.TransferDataset):
  135. op_type = OpName.DEVICEQUEUE
  136. elif isinstance(dataset, de.RenameDataset):
  137. op_type = OpName.RENAME
  138. elif isinstance(dataset, de.TFRecordDataset):
  139. op_type = OpName.TFREADER
  140. elif isinstance(dataset, de.ProjectDataset):
  141. op_type = OpName.PROJECT
  142. elif isinstance(dataset, de.MnistDataset):
  143. op_type = OpName.MNIST
  144. elif isinstance(dataset, de.ManifestDataset):
  145. op_type = OpName.MANIFEST
  146. elif isinstance(dataset, de.VOCDataset):
  147. op_type = OpName.VOC
  148. elif isinstance(dataset, de.CocoDataset):
  149. op_type = OpName.COCO
  150. elif isinstance(dataset, de.Cifar10Dataset):
  151. op_type = OpName.CIFAR10
  152. elif isinstance(dataset, de.Cifar100Dataset):
  153. op_type = OpName.CIFAR100
  154. elif isinstance(dataset, de.CelebADataset):
  155. op_type = OpName.CELEBA
  156. elif isinstance(dataset, de.RandomDataset):
  157. op_type = OpName.RANDOMDATA
  158. elif isinstance(dataset, de.TextFileDataset):
  159. op_type = OpName.TEXTFILE
  160. elif isinstance(dataset, de.BuildVocabDataset):
  161. op_type = OpName.BUILDVOCAB
  162. elif isinstance(dataset, de.BuildSentencePieceVocabDataset):
  163. op_type = OpName.SENTENCEPIECEVOCAB
  164. elif isinstance(dataset, de.CLUEDataset):
  165. op_type = OpName.CLUE
  166. elif isinstance(dataset, de.CSVDataset):
  167. op_type = OpName.CSV
  168. else:
  169. raise ValueError("Unsupported DatasetOp")
  170. return op_type
  171. # Convert python node into C node and add to C layer execution tree in postorder traversal.
  172. def __convert_node_postorder(self, node):
  173. self.check_node_type(node)
  174. op_type = self.__get_dataset_type(node)
  175. c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args())
  176. for py_child in node.children:
  177. c_child = self.__convert_node_postorder(py_child)
  178. self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"])
  179. return c_nodes["top"]
  180. def __batch_node(self, dataset, level):
  181. """Recursively get batch node in the dataset tree."""
  182. if isinstance(dataset, de.BatchDataset):
  183. return
  184. for input_op in dataset.children:
  185. self.__batch_node(input_op, level + 1)
  186. @staticmethod
  187. def __print_local(dataset, level):
  188. """Recursively print the name and address of nodes in the dataset tree."""
  189. name = dataset.__class__.__name__
  190. ptr = hex(id(dataset))
  191. for _ in range(level):
  192. logger.info("\t", end='')
  193. if not dataset.children:
  194. logger.info("-%s (%s)", name, ptr)
  195. else:
  196. logger.info("+%s (%s)", name, ptr)
  197. for input_op in dataset.children:
  198. Iterator.__print_local(input_op, level + 1)
  199. def print(self):
  200. """Print the dataset tree"""
  201. self.__print_local(self.dataset, 0)
  202. def release(self):
  203. if hasattr(self, 'depipeline') and self.depipeline:
  204. del self.depipeline
  205. @abstractmethod
  206. def get_next(self):
  207. raise RuntimeError("Calling base class Iterator's get_next is invalid.")
  208. def __next__(self):
  209. if not self.depipeline:
  210. logger.warning("Iterator does not have a running c++ pipeline." +
  211. "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
  212. raise RuntimeError("Iterator does not have a running c++ pipeline.")
  213. data = self.get_next()
  214. if not data:
  215. if self._index == 0:
  216. logger.warning("No records available.")
  217. if self.ori_dataset.dataset_size is None:
  218. self.ori_dataset.dataset_size = self._index
  219. raise StopIteration
  220. self._index += 1
  221. return data
  222. @abstractmethod
  223. def check_node_type(self, node):
  224. pass
  225. def get_output_shapes(self):
  226. return [t for t in self.depipeline.GetOutputShapes()]
  227. def get_output_types(self):
  228. return [t for t in self.depipeline.GetOutputTypes()]
  229. def get_dataset_size(self):
  230. return self.depipeline.GetDatasetSize()
  231. def get_batch_size(self):
  232. return self.depipeline.GetBatchSize()
  233. def get_repeat_count(self):
  234. return self.depipeline.GetRepeatCount()
  235. def num_classes(self):
  236. return self.depipeline.GetNumClasses()
  237. def __deepcopy__(self, memo):
  238. return self
  239. class SaveOp(Iterator):
  240. """
  241. The derived class of Iterator with dict type.
  242. """
  243. def get_next(self):
  244. pass
  245. def check_node_type(self, node):
  246. if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)):
  247. logger.warning("Used shuffle, repeat, batch before save operator.")
  248. def save(self, file_names, file_type):
  249. return self.depipeline.SaveDataset(file_names, file_type)
  250. class DictIterator(Iterator):
  251. """
  252. The derived class of Iterator with dict type.
  253. """
  254. def check_node_type(self, node):
  255. pass
  256. def __iter__(self):
  257. return self
  258. def get_next(self):
  259. """
  260. Returns the next record in the dataset as dictionary
  261. Returns:
  262. Dict, the next record in the dataset.
  263. """
  264. return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()}
  265. class TupleIterator(Iterator):
  266. """
  267. The derived class of Iterator with list type.
  268. """
  269. def check_node_type(self, node):
  270. pass
  271. def __init__(self, dataset, columns=None, num_epochs=-1):
  272. if columns is not None:
  273. if not isinstance(columns, list):
  274. columns = [columns]
  275. dataset = dataset.project(columns)
  276. super().__init__(dataset, num_epochs)
  277. def __iter__(self):
  278. return self
  279. def get_next(self):
  280. """
  281. Returns the next record in the dataset as a list
  282. Returns:
  283. List, the next record in the dataset.
  284. """
  285. return [t.as_array() for t in self.depipeline.GetNextAsList()]
  286. class DummyIterator():
  287. """
  288. A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED"
  289. """
  290. def __init__(self, dataset, mode):
  291. self.mode = mode
  292. self.shapes = dataset.output_shapes()
  293. self.types = dataset.output_types()
  294. self.fetched_first = False
  295. def __get_tensor(self):
  296. tensor_row = []
  297. for np_shape, np_type in zip(self.shapes, self.types):
  298. input_np = np.zeros(np_shape, np_type)
  299. tensor = Tensor(input_np)
  300. tensor_row.append(tensor)
  301. return tensor_row
  302. def __iter__(self):
  303. return self
  304. def __next__(self):
  305. if self.mode == "tuple":
  306. if not self.fetched_first:
  307. self.fetched_first = True
  308. return self.__get_tensor()
  309. raise StopIteration()