You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/engine/opt/pass.h"
  17. #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
  18. #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h"
  19. #ifndef ENABLE_ANDROID
  20. #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h"
  21. #endif
  22. #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
  23. #ifndef ENABLE_ANDROID
  24. #include "minddata/dataset/engine/ir/datasetops/cache_node.h"
  25. #include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
  26. #include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
  27. #endif
  28. #include "minddata/dataset/engine/ir/datasetops/concat_node.h"
  29. #include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
  30. #include "minddata/dataset/engine/ir/datasetops/filter_node.h"
  31. #include "minddata/dataset/engine/ir/datasetops/map_node.h"
  32. #include "minddata/dataset/engine/ir/datasetops/project_node.h"
  33. #include "minddata/dataset/engine/ir/datasetops/rename_node.h"
  34. #include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
  35. #include "minddata/dataset/engine/ir/datasetops/root_node.h"
  36. #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
  37. #include "minddata/dataset/engine/ir/datasetops/skip_node.h"
  38. #ifndef ENABLE_ANDROID
  39. #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
  40. #endif
  41. #ifdef ENABLE_PYTHON
  42. #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
  43. #endif
  44. #include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
  45. #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
  46. #ifdef ENABLE_PYTHON
  47. #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
  48. #endif
  49. #include "minddata/dataset/engine/ir/datasetops/take_node.h"
  50. #include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
  51. #include "minddata/dataset/engine/ir/datasetops/zip_node.h"
  52. namespace mindspore {
  53. namespace dataset {
  54. // Driver method for TreePass
  55. Status IRTreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
  56. if (root_ir == nullptr || modified == nullptr) {
  57. return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to TreePass");
  58. }
  59. // Initialize modified flag
  60. *modified = false;
  61. return this->RunOnTree(root_ir, modified);
  62. }
  63. // Driver method for NodePass
  64. Status IRNodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
  65. if (root_ir == nullptr || modified == nullptr) {
  66. return Status(StatusCode::kMDUnexpectedError, "Null pointer passed to NodePass");
  67. }
  68. // Initialize modified flag
  69. *modified = false;
  70. if (traversalOrder_ == Order::DFS) {
  71. // DFS
  72. return DFSNodeVisit(root_ir, modified);
  73. } else if (traversalOrder_ == Order::BFS) {
  74. // BFS
  75. return BFSNodeVisit(root_ir, modified);
  76. }
  77. return Status::OK();
  78. }
  79. // Helper function to perform DFS visit
  80. Status IRNodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
  81. bool m = false;
  82. RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
  83. *modified = *modified || m;
  84. for (const auto &c : node_ir->Children()) {
  85. RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
  86. *modified = *modified || m;
  87. }
  88. RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
  89. *modified = *modified || m;
  90. return Status::OK();
  91. }
  92. // Helper function to perform BFS visit
  93. Status IRNodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified) {
  94. bool m = false;
  95. // Initialize bfs queue with root
  96. std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
  97. bfsQueue.push(node_ir);
  98. // BFS loop
  99. while (!bfsQueue.empty()) {
  100. // Pop the front of the bfs queue
  101. auto curNode = bfsQueue.front();
  102. bfsQueue.pop();
  103. // Run node pass
  104. RETURN_IF_NOT_OK(curNode->Accept(this, &m));
  105. *modified = *modified || m;
  106. // Push children into bfs queue
  107. for (const auto &c : curNode->Children()) {
  108. bfsQueue.push(c);
  109. }
  110. }
  111. return Status::OK();
  112. }
  113. // For non-leaf IR node
  114. Status IRNodePass::Visit(std::shared_ptr<BatchNode> node, bool *const modified) {
  115. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  116. }
  117. Status IRNodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified) {
  118. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  119. }
  120. Status IRNodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
  121. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  122. }
  123. Status IRNodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified) {
  124. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  125. }
  126. Status IRNodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
  127. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  128. }
  129. Status IRNodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
  130. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  131. }
  132. Status IRNodePass::Visit(std::shared_ptr<ConcatNode> node, bool *const modified) {
  133. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  134. }
  135. Status IRNodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified) {
  136. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  137. }
  138. #ifndef ENABLE_ANDROID
  139. Status IRNodePass::Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
  140. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  141. }
  142. Status IRNodePass::VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified) {
  143. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  144. }
  145. Status IRNodePass::Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
  146. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  147. }
  148. Status IRNodePass::VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified) {
  149. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  150. }
  151. Status IRNodePass::Visit(std::shared_ptr<CacheNode> node, bool *const modified) {
  152. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  153. }
  154. Status IRNodePass::VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified) {
  155. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  156. }
  157. #endif
  158. Status IRNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
  159. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  160. }
  161. Status IRNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
  162. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  163. }
  164. Status IRNodePass::Visit(std::shared_ptr<FilterNode> node, bool *const modified) {
  165. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  166. }
  167. Status IRNodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified) {
  168. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  169. }
  #ifdef ENABLE_PYTHON
  // GeneratorNode is routed through the MappableSourceNode category rather
  // than straight to DatasetNode.
  Status IRNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
    return Visit(std::shared_ptr<MappableSourceNode>(node), modified);
  }
  // NOTE(review): this file defines no VisitAfter for MappableSourceNode; confirm
  // in pass.h which VisitAfter overload this call actually resolves to.
  Status IRNodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified) {
    return VisitAfter(std::shared_ptr<MappableSourceNode>(node), modified);
  }
  #endif
  178. Status IRNodePass::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
  179. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  180. }
  181. Status IRNodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *const modified) {
  182. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  183. }
  184. #ifndef ENABLE_ANDROID
  185. Status IRNodePass::Visit(std::shared_ptr<MindDataNode> node, bool *const modified) {
  186. return Visit(std::static_pointer_cast<MappableSourceNode>(node), modified);
  187. }
  188. Status IRNodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified) {
  189. return VisitAfter(std::static_pointer_cast<MappableSourceNode>(node), modified);
  190. }
  191. #endif
  192. Status IRNodePass::Visit(std::shared_ptr<ProjectNode> node, bool *const modified) {
  193. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  194. }
  195. Status IRNodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified) {
  196. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  197. }
  198. Status IRNodePass::Visit(std::shared_ptr<RandomNode> node, bool *const modified) {
  199. return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
  200. }
  201. Status IRNodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified) {
  202. return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
  203. }
  204. Status IRNodePass::Visit(std::shared_ptr<RenameNode> node, bool *const modified) {
  205. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  206. }
  207. Status IRNodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified) {
  208. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  209. }
  210. Status IRNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
  211. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  212. }
  213. Status IRNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
  214. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  215. }
  216. Status IRNodePass::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
  217. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  218. }
  219. Status IRNodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *const modified) {
  220. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  221. }
  222. Status IRNodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) {
  223. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  224. }
  225. Status IRNodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified) {
  226. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  227. }
  228. Status IRNodePass::Visit(std::shared_ptr<SkipNode> node, bool *const modified) {
  229. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  230. }
  231. Status IRNodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified) {
  232. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  233. }
  234. Status IRNodePass::Visit(std::shared_ptr<TakeNode> node, bool *const modified) {
  235. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  236. }
  237. Status IRNodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified) {
  238. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  239. }
  240. Status IRNodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *const modified) {
  241. return Visit(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
  242. }
  243. Status IRNodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified) {
  244. return VisitAfter(std::static_pointer_cast<NonMappableSourceNode>(node), modified);
  245. }
  246. Status IRNodePass::Visit(std::shared_ptr<TransferNode> node, bool *const modified) {
  247. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  248. }
  249. Status IRNodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
  250. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  251. }
  252. Status IRNodePass::Visit(std::shared_ptr<ZipNode> node, bool *const modified) {
  253. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  254. }
  255. Status IRNodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified) {
  256. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  257. }
  #ifdef ENABLE_PYTHON
  // SyncWaitNode (Python builds only): generic DatasetNode handling.
  Status IRNodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
    return Visit(std::shared_ptr<DatasetNode>(node), modified);
  }
  Status IRNodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified) {
    return VisitAfter(std::shared_ptr<DatasetNode>(node), modified);
  }
  #endif
  266. #ifndef ENABLE_ANDROID
  267. Status IRNodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
  268. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  269. }
  270. Status IRNodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
  271. return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
  272. }
  273. #endif
  274. // leaf-IR Node
  275. Status IRNodePass::Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified) {
  276. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  277. }
  278. Status IRNodePass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified) {
  279. return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
  280. }
  281. } // namespace dataset
  282. } // namespace mindspore