You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

execution_tree.h 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_EXECUTION_TREE_H_
  17. #define DATASET_ENGINE_EXECUTION_TREE_H_
  18. #include <functional>
  19. #include <memory>
  20. #include <stack>
  21. #include <vector>
  22. #include "dataset/engine/datasetops/dataset_op.h"
  23. #include "dataset/util/status.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // Forward declares
  27. class TaskGroup;
  28. class DatasetOp;
  29. class ExecutionTree {
  30. public:
  31. // Prepare flags used during tree prepare phase
  32. enum PrepareFlags {
  33. kDePrepNone = 0,
  34. kDePrepRepeat = 1 // Processing a repeat operation
  35. };
  36. // State flags for the lifecycle of the tree
  37. enum TreeState {
  38. kDeTStateInit = 0, // The freshly initialized state after construction
  39. kDeTStateBuilding, // The tree is being built, nodes are being added
  40. kDeTStatePrepare, // The tree has been assigned a root node and is pending prepare
  41. kDeTStateReady, // The tree has been prepared and is ready to be launched
  42. kDeTStateExecuting // The tree has been launched and is executing
  43. };
  44. class Iterator {
  45. public:
  46. // Constructor
  47. // @param root The root node to start iterating from
  48. explicit Iterator(const std::shared_ptr<DatasetOp> &root = nullptr);
  49. // Destructor
  50. ~Iterator() {}
  51. Iterator &operator++() {
  52. ++ind_;
  53. return *this;
  54. } // prefix ++ overload
  55. Iterator operator++(int) {
  56. Iterator it = *this;
  57. it.ind_ = ind_;
  58. ind_++;
  59. return it;
  60. } // post-fix ++ overload
  61. Iterator &operator--() {
  62. --ind_;
  63. return *this;
  64. } // prefix -- overload
  65. Iterator operator--(int) {
  66. Iterator it = *this;
  67. it.ind_ = ind_;
  68. ind_--;
  69. return it;
  70. } // post-fix -- overload
  71. DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator
  72. std::shared_ptr<DatasetOp> operator->() { return nodes_[ind_]; }
  73. // getter function
  74. // @return Shared pointer to the current operator
  75. std::shared_ptr<DatasetOp> get() { return nodes_[ind_]; }
  76. bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; }
  77. private:
  78. int ind_; // the cur node our Iterator points to
  79. std::vector<std::shared_ptr<DatasetOp>> nodes_; // store the nodes in post order
  80. void PostOrderTraverse(const std::shared_ptr<DatasetOp> &);
  81. };
  82. // Constructor
  83. ExecutionTree();
  84. // Destructor
  85. ~ExecutionTree();
  86. // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and
  87. // provides it with a link to the tree. A node cannot form any relationships (parent/child) with
  88. // other nodes unless they are associated with the same tree.
  89. // @param op - The operator to associate
  90. // @return Status - The error code return
  91. Status AssociateNode(const std::shared_ptr<DatasetOp> &op);
  92. // Sets the root node of the tree
  93. // @param op - The operator to assign as root
  94. // @return Status - The error code return
  95. Status AssignRoot(const std::shared_ptr<DatasetOp> &op);
  96. // Start the execution of the tree
  97. // @return Status - The error code return
  98. Status Launch();
  99. // A print method typically used for debugging
  100. // @param out - The output stream to write output to
  101. // @param show_all - A bool to control if you want to show all info or just a summary
  102. void Print(std::ostream &out, bool show_all) const;
  103. // Returns an iterator positioned at the start
  104. // @return Iterator - The iterator
  105. ExecutionTree::Iterator begin(const std::shared_ptr<DatasetOp> &root = nullptr) const {
  106. return Iterator((root == nullptr) ? root_ : root);
  107. }
  108. // Returns an iterator positioned at the end
  109. // @return Iterator - The iterator
  110. ExecutionTree::Iterator end() const { return Iterator(nullptr); }
  111. // << Stream output operator overload
  112. // @notes This allows you to write the debug print info using stream operators
  113. // @param out - reference to the output stream being overloaded
  114. // @param exe_tree - reference to the execution tree to display
  115. // @return - the output stream must be returned
  116. friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) {
  117. exe_tree.Print(out, false);
  118. return out;
  119. }
  120. // Given the number of workers, launches the worker entry function for each. Essentially a
  121. // wrapper for the TaskGroup handling that is stored inside the execution tree.
  122. // @param num_workers - The number of workers to launch
  123. // @param func - The function entry point that workers will execute
  124. // @return Status - The error code return
  125. Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
  126. // Getter method
  127. // @return shared_ptr to the root operator
  128. std::shared_ptr<DatasetOp> root() const { return root_; }
  129. // Getter method
  130. // @return the prepare flags
  131. uint32_t PrepareFlags() const { return prepare_flags_; }
  132. // The driver of the prepare phase of the execution tree. The prepare phase will recursively
  133. // walk the tree to perform modifications to the tree or specific nodes within the tree to get
  134. // it ready for execution.
  135. // @return Status - The error code return
  136. Status Prepare();
  137. // Recursive function used during prepare phase to visit a node and drive any pre- and post-
  138. // node actions during a tree walk.
  139. // @param op - The dataset op to work on
  140. // @return Status - The error code return
  141. Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
  142. // Adds an operator to the repeat stack during prepare phase.
  143. // @param op - The dataset op to work add to repeat stack
  144. // @return Status - The error code return
  145. void AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op);
  146. // Pops an operator from the repeat stack during prepare phase.
  147. // @return shared_ptr to the popped operator
  148. std::shared_ptr<DatasetOp> PopFromRepeatStack();
  149. // Return the pointer to the TaskGroup
  150. // @return raw pointer to the TaskGroup
  151. TaskGroup *AllTasks() const { return tg_.get(); }
  152. private:
  153. std::unique_ptr<TaskGroup> tg_; // Class for worker management
  154. std::shared_ptr<DatasetOp> root_; // The root node of the tree
  155. int32_t id_count_; // Counter for generating operator id's
  156. uint32_t prepare_flags_; // Flags used during tree prepare
  157. TreeState tree_state_; // Tracking the current tree state
  158. std::stack<std::shared_ptr<DatasetOp>> repeat_stack_; // A stack used during prepare phase
  159. };
  160. } // namespace dataset
  161. } // namespace mindspore
  162. #endif // DATASET_ENGINE_EXECUTION_TREE_H_