You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

execution_tree.h 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_
  18. #include <functional>
  19. #include <memory>
  20. #include <stack>
  21. #include <string>
  22. #include <vector>
  23. #include "minddata/dataset/engine/datasetops/dataset_op.h"
  24. #include "minddata/dataset/util/status.h"
  25. #include "mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Forward declares
  29. class TaskGroup;
  30. class DatasetOp;
  31. class ExecutionTree {
  32. public:
  33. // Prepare flags used during tree prepare phase
  34. enum PrepareFlags {
  35. kDePrepNone = 0,
  36. kDePrepRepeat = 1, // Processing a repeat operation
  37. kDePrepCache = 2 // Processing a cache operation
  38. };
  39. // State flags for the lifecycle of the tree
  40. enum TreeState {
  41. kDeTStateInit = 0, // The freshly initialized state after construction
  42. kDeTStateBuilding, // The tree is being built, nodes are being added
  43. kDeTStatePrepare, // The tree has been assigned a root node and is pending prepare
  44. kDeTStateReady, // The tree has been prepared and is ready to be launched
  45. kDeTStateExecuting, // The tree has been launched and is executing
  46. kDeTStateEpochEnd, // The tree has been received end of epoch signal, just for profiling
  47. kDeTStateFinished // The tree has been drained, dataset iterator received EOF
  48. };
  49. class Iterator {
  50. public:
  51. // Constructor
  52. // @param root The root node to start iterating from
  53. explicit Iterator(const std::shared_ptr<DatasetOp> &root = nullptr);
  54. // Destructor
  55. ~Iterator() {}
  56. Iterator &operator++() {
  57. ++ind_;
  58. return *this;
  59. } // prefix ++ overload
  60. Iterator operator++(int) {
  61. Iterator it = *this;
  62. it.ind_ = ind_;
  63. ind_++;
  64. return it;
  65. } // post-fix ++ overload
  66. Iterator &operator--() {
  67. --ind_;
  68. return *this;
  69. } // prefix -- overload
  70. Iterator operator--(int) {
  71. Iterator it = *this;
  72. it.ind_ = ind_;
  73. ind_--;
  74. return it;
  75. } // post-fix -- overload
  76. DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator
  77. std::shared_ptr<DatasetOp> operator->() { return nodes_[ind_]; }
  78. // getter function
  79. // @return Shared pointer to the current operator
  80. std::shared_ptr<DatasetOp> get() { return nodes_[ind_]; }
  81. bool operator==(const Iterator &rhs) { return nodes_[ind_] == rhs.nodes_[rhs.ind_]; }
  82. bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; }
  83. int32_t NumNodes() { return nodes_.size(); }
  84. private:
  85. int32_t ind_; // the cur node our Iterator points to
  86. std::vector<std::shared_ptr<DatasetOp>> nodes_; // store the nodes in post order
  87. void PostOrderTraverse(const std::shared_ptr<DatasetOp> &);
  88. };
  89. // Constructor
  90. ExecutionTree();
  91. // Destructor
  92. ~ExecutionTree();
  93. // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and
  94. // provides it with a link to the tree. A node cannot form any relationships (parent/child) with
  95. // other nodes unless they are associated with the same tree.
  96. // @param op - The operator to associate
  97. // @return Status - The error code return
  98. Status AssociateNode(const std::shared_ptr<DatasetOp> &op);
  99. // Sets the root node of the tree
  100. // @param op - The operator to assign as root
  101. // @return Status - The error code return
  102. Status AssignRoot(const std::shared_ptr<DatasetOp> &op);
  103. // Start the execution of the tree
  104. // @return Status - The error code return
  105. Status Launch();
  106. /// A print method typically used for debugging
  107. /// \param out - The output stream to write output to
  108. void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const;
  109. // Returns an iterator positioned at the start
  110. // @return Iterator - The iterator
  111. ExecutionTree::Iterator begin(const std::shared_ptr<DatasetOp> &root = nullptr) const {
  112. return Iterator(root == nullptr ? root_ : root);
  113. }
  114. // Returns an iterator positioned at the end
  115. // @return Iterator - The iterator
  116. ExecutionTree::Iterator end() const { return Iterator(nullptr); }
  117. // << Stream output operator overload
  118. // @notes This allows you to write the debug print info using stream operators
  119. // @param out - reference to the output stream being overloaded
  120. // @param exe_tree - reference to the execution tree to display
  121. // @return - the output stream must be returned
  122. friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) {
  123. exe_tree.Print(out);
  124. return out;
  125. }
  126. // Given the number of workers, launches the worker entry function for each. Essentially a
  127. // wrapper for the TaskGroup handling that is stored inside the execution tree.
  128. // @param num_workers - The number of workers to launch
  129. // @param func - The function entry point that workers will execute
  130. // @return Status - The error code return
  131. Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
  132. // Getter method
  133. // @return shared_ptr to the root operator
  134. std::shared_ptr<DatasetOp> root() const { return root_; }
  135. // Getter method
  136. // @return the prepare flags
  137. uint32_t PrepareFlags() const { return prepare_flags_; }
  138. // The driver of the prepare phase of the execution tree.
  139. // Prepare phase consists of three sub phases
  140. //
  141. // 1. PrepareTreePreAction()
  142. // Compulsory transformation/action pre optimization.
  143. // For example, CacheOp Insertion
  144. //
  145. // 2. Optimize()
  146. // Optimization transformation/action, optional
  147. // For example, MapOp Fusion
  148. //
  149. // 3. PrepareTreePostAction()
  150. // Compulsory transformation/action post optimization.
  151. // For example, repeatOp inlining
  152. //
  153. // @return Status - The error code return
  154. Status Prepare(int num_epochs = -1);
  155. // Compulsory transformation/action pre optimization.
  156. // @return Status - The error code return
  157. Status PrepareTreePreAction();
  158. // Compulsory transformation/action post optimization.
  159. // @return Status - The error code return
  160. Status PrepareTreePostAction();
  161. // Optimization transformation/action, optional.
  162. // @return Status - The error code return
  163. Status Optimize();
  164. // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
  165. // walk the tree to perform modifications to the tree or specific nodes within the tree to get
  166. // it ready for execution.
  167. // @param Total number of epochs that will be run on this tree
  168. // @return Status - The error code return
  169. Status PrepareDeprecated();
  170. // Recursive function used during prepare phase to visit a node and drive any pre- and post-
  171. // node actions during a tree walk.
  172. // @param op - The dataset op to work on
  173. // @return Status - The error code return
  174. Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
  175. // Return the pointer to the TaskGroup
  176. // @return raw pointer to the TaskGroup
  177. TaskGroup *AllTasks() const { return tg_.get(); }
  178. // Return if the ExecutionTree is at end of epoch status
  179. // @return bool - true is ExecutionTree is end of epoch status
  180. bool IsEpochEnd() const { return tree_state_ == TreeState::kDeTStateEpochEnd; }
  181. // Set the ExecutionTree to EOE state
  182. void SetEpochEnd() { tree_state_ = TreeState::kDeTStateEpochEnd; }
  183. // Set the ExecutionTree to executing state
  184. void SetExecuting() { tree_state_ = TreeState::kDeTStateExecuting; }
  185. // Return if the ExecutionTree is finished (iterator receives EOF).
  186. // @return Bool - true is ExecutionTree is finished
  187. bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; }
  188. // Return if the ExecutionTree is ready.
  189. // @return Bool - true is ExecutionTree is ready
  190. bool isPrepared() const {
  191. return tree_state_ == TreeState::kDeTStateReady || tree_state_ == kDeTStateExecuting ||
  192. tree_state_ == kDeTStateFinished;
  193. }
  194. // Set the ExecutionTree to Finished state.
  195. void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; }
  196. // Getter for profiling manager, no ownership
  197. ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
  198. // Set optional optimization if tree has not been prepared yet
  199. Status SetOptimize(bool value) {
  200. if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
  201. std::string optimize = (optimize_ == true) ? "true" : "false";
  202. std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize;
  203. RETURN_STATUS_UNEXPECTED(msg);
  204. } else {
  205. optimize_ = value;
  206. return Status::OK();
  207. }
  208. }
  209. // Optional optimizations status
  210. bool OptimizationEnabled() const { return optimize_; }
  211. // Getter function to get the total number of epochs to be run on this tree.
  212. // @return total number of epochs
  213. int32_t num_epochs() { return num_epochs_; }
  214. private:
  215. // A helper functions for doing the recursive printing
  216. // @param dataset_op - The dataset op to print
  217. // @param indent - an indent string for aligning child levels in output
  218. // @param last - an indicator if it's the last child or not
  219. // @param detailed - should it display the detailed node output or the summary line
  220. void PrintNode(std::ostream &out, const std::shared_ptr<DatasetOp> &dataset_op, std::string indent, bool last,
  221. bool detailed) const;
  222. std::unique_ptr<TaskGroup> tg_; // Class for worker management
  223. std::shared_ptr<DatasetOp> root_; // The root node of the tree
  224. int32_t id_count_; // Counter for generating operator id's
  225. uint32_t prepare_flags_; // Flags used during tree prepare
  226. TreeState tree_state_; // Tracking the current tree state
  227. int32_t num_epochs_; // Total number of epochs to run for this tree
  228. std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
  229. bool optimize_; // Flag to enable optional optimizations
  230. };
  231. } // namespace dataset
  232. } // namespace mindspore
  233. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_