You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_
  18. #include <memory>
  19. #include <queue>
  20. #include "minddata/dataset/engine/execution_tree.h"
  21. #include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
  22. #include "minddata/dataset/util/status.h"
  23. namespace mindspore {
  24. namespace dataset {
  // Forward declarations of the IR node types used by the pass interfaces in this header.
  // The node types are only referenced through std::shared_ptr here, so incomplete types are
  // sufficient and the node headers do not need to be included.
  //
  // Non-leaf IR node
  class BatchNode;
  class BucketBatchByLengthNode;
  class BuildVocabNode;
  #ifndef ENABLE_ANDROID
  class CacheLookupNode;
  class CacheMergeNode;
  class CacheNode;
  #endif
  class ConcatNode;
  class EpochCtrlNode;
  class FilterNode;
  class MapNode;
  class ProjectNode;
  class RenameNode;
  class RepeatNode;
  class RootNode;
  class ShuffleNode;
  class SkipNode;
  class TakeNode;
  class TransferNode;
  class ZipNode;
  #ifdef ENABLE_PYTHON
  class SyncWaitNode;
  #endif
  #ifndef ENABLE_ANDROID
  class BuildSentenceVocabNode;
  #endif
  // Leaf IR node
  class AlbumNode;
  class CelebANode;
  class Cifar100Node;
  class Cifar10Node;
  class CocoNode;
  class ImageFolderNode;
  class ManifestNode;
  class MnistNode;
  class RandomNode;
  class VOCNode;
  // TFRecordNode is declared on every build (deliberately NOT inside the ENABLE_ANDROID guard):
  // IRNodePass declares Visit()/VisitAfter() overloads for it without any preprocessor guard.
  // It used to be declared twice (once unguarded, once guarded); one declaration is enough.
  class TFRecordNode;
  #ifdef ENABLE_PYTHON
  class GeneratorNode;
  #endif
  #ifndef ENABLE_ANDROID
  class CLUENode;
  class CSVNode;
  class MindDataNode;
  class TextFileNode;
  #endif
  75. // The base class Pass is the basic unit of tree transformation.
  76. // The actual implementation of the passes will be derived from here.
  77. class IRPass : public std::enable_shared_from_this<IRPass> {
  78. public:
  79. // Run the transformation pass against the IR tree.
  80. // @param root_ir - Pointer to the IR tree to be transformed.
  81. // @param modified - Pointer to the modified flag,
  82. virtual Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) = 0;
  83. virtual ~IRPass() = default;
  84. };
  85. // IRTreePass is a basic Pass class which performs transformation on IR tree directly.
  86. class IRTreePass : public IRPass {
  87. public:
  88. /// \brief Run the transformation pass against the IR tree.
  89. /// \param[in,out] root_ir Pointer to the IR tree to be transformed.
  90. /// \param[in,out] modified Indicate if the tree was modified
  91. Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
  92. /// \brief Derived classes may implement the runOnTree function to implement tree transformation.
  93. /// "modified" flag needs to be set to true if tree is modified during the pass execution.
  94. /// \param[in,out] tree The tree to operate on.
  95. /// \param[in,out] Indicate if the tree was modified.
  96. /// \return Status The status code returned
  97. virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) { return Status::OK(); }
  98. };
  99. // IRNodePass is a base Pass class which performs transformation on node visiting.
  100. // IRNodePass implements Visitor design pattern.
  101. // The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
  102. // and the other when all the descending nodes are visited.
  103. // Actual transformation is done by implementing a new derived class of IRNodePass.
  104. // The derived class will implement the method Visit()/VisitAfter() passing specified node types
  105. // it wants to action on them, overriding the ones defined in IRNodePass.
  106. // If the derived class wants to perform the same action on all node types,
  107. // it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
  108. // This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
  109. // to call the Visit()/VisitAfter() in this parent IRNodePass class.
  110. class IRNodePass : public IRPass {
  111. public:
  112. // Tree traversal order
  113. enum Order { DFS, BFS };
  114. // Constructor
  115. // Default DFS traversal
  116. explicit IRNodePass(Order order = Order::DFS) { traversalOrder_ = order; }
  117. ~IRNodePass() = default;
  118. /// \brief Run the transformation pass against the IR tree
  119. /// \param[in,out] root_ir Pointer to the IR tree to be transformed
  120. /// \param[in,out] modified Indicator if the tree was changed
  121. Status Run(std::shared_ptr<DatasetNode> root_ir, bool *const modified) final;
  122. /// \brief Derived classes may implement the Visit function to implement any initial visit work on the way down
  123. /// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
  124. /// \param[in] node The node being visited
  125. /// \param[out] modified Indicator if the node was changed at all
  126. /// \return Status The status code returned
  127. virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); }
  128. /// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
  129. /// "modified" flag needs to be set to true if node is modified during the pass execution
  130. /// \param[in] node The node being visited
  131. /// \param[out] modified Indicator if the node was changed at all.
  132. /// \return Status The status code returned
  133. virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *const modified) { return Status::OK(); }
  134. // Visit()/VisitAfter() method to be overridden.
  135. // These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here.
  136. // Their implementation are in .cc file to avoid adding the include files of those derived classes.
  137. // The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of
  138. // the derived classes. With this technique, the transformation classes derived from NodePass needs only to
  139. // implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes
  140. // of DatasetNode in the same way.
  141. // Note that virtual template functions are not permitted in C++.
  142. //
  143. // Non-leaf IR node
  144. virtual Status Visit(std::shared_ptr<BatchNode> node, bool *const modified);
  145. virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *const modified);
  146. virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified);
  147. virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *const modified);
  148. #ifndef ENABLE_ANDROID
  149. virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified);
  150. virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified);
  151. #endif
  152. virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified);
  153. virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *const modified);
  154. virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *const modified);
  155. virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *const modified);
  156. #ifndef ENABLE_ANDROID
  157. virtual Status Visit(std::shared_ptr<CacheMergeNode> node, bool *const modified);
  158. virtual Status VisitAfter(std::shared_ptr<CacheMergeNode> node, bool *const modified);
  159. virtual Status Visit(std::shared_ptr<CacheLookupNode> node, bool *const modified);
  160. virtual Status VisitAfter(std::shared_ptr<CacheLookupNode> node, bool *const modified);
  161. virtual Status Visit(std::shared_ptr<CacheNode> node, bool *const modified);
  162. virtual Status VisitAfter(std::shared_ptr<CacheNode> node, bool *const modified);
  163. #endif
  164. virtual Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
  165. virtual Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified);
  166. virtual Status Visit(std::shared_ptr<FilterNode> node, bool *const modified);
  167. virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *const modified);
  168. #ifdef ENABLE_PYTHON
  169. virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified);
  170. virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *const modified);
  171. #endif
  172. virtual Status Visit(std::shared_ptr<MapNode> node, bool *const modified);
  173. virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *const modified);
  174. #ifndef ENABLE_ANDROID
  175. virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *const modified);
  176. virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *const modified);
  177. #endif
  178. virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *const modified);
  179. virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *const modified);
  180. virtual Status Visit(std::shared_ptr<RandomNode> node, bool *const modified);
  181. virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *const modified);
  182. virtual Status Visit(std::shared_ptr<RenameNode> node, bool *const modified);
  183. virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *const modified);
  184. virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified);
  185. virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified);
  186. virtual Status Visit(std::shared_ptr<RootNode> node, bool *const modified);
  187. virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *const modified);
  188. virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *const modified);
  189. virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *const modified);
  190. virtual Status Visit(std::shared_ptr<SkipNode> node, bool *const modified);
  191. virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *const modified);
  192. #ifdef ENABLE_PYTHON
  193. virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *const modified);
  194. virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *const modified);
  195. #endif
  196. virtual Status Visit(std::shared_ptr<TakeNode> node, bool *const modified);
  197. virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *const modified);
  198. virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *const modified);
  199. virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *const modified);
  200. virtual Status Visit(std::shared_ptr<TransferNode> node, bool *const modified);
  201. virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified);
  202. virtual Status Visit(std::shared_ptr<ZipNode> node, bool *const modified);
  203. virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *const modified);
  204. // leaf-IR Node
  205. virtual Status Visit(std::shared_ptr<MappableSourceNode> node, bool *const modified);
  206. virtual Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *const modified);
  207. private:
  208. // Helper function to perform DFS visit
  209. Status DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified);
  210. // Helper function to perform BFS visit
  211. Status BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *const modified);
  212. // Tree traversal order of the NodePass
  213. Order traversalOrder_;
  214. };
  215. } // namespace dataset
  216. } // namespace mindspore
  217. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_H_