You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_op.h 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
  17. #define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
  18. #include <memory>
  19. #include <mutex>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <vector>
  23. #include "dataset/core/constants.h"
  24. #include "dataset/engine/db_connector.h"
  25. #include "dataset/util/status.h"
  26. namespace mindspore {
  27. namespace dataset {
  // Forward declarations of collaborators that this header only refers to by pointer/reference.
  class DataBuffer;
  class ExecutionTree;
  class NodePass;
  class Sampler;
  33. /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
  34. /// the actual implementation of the operators will be derived from here.
  35. class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
  36. // Allow execution tree to access internal members
  37. friend class ExecutionTree;
  38. public:
  39. static constexpr int32_t kInvalidOperatorId = -1;
  40. // Flags that control operator runtime behaviours
  41. enum OpControlFlags {
  42. kDeOpNone = 0,
  43. kDeOpRepeated = 1, // Operator is a leaf node in a repeat path
  44. kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
  45. };
  46. // Flags that control operator runtime behaviours
  47. enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
  48. /// Constructor
  49. /// \param op_connector_size - The size for the output connector of this operator.
  50. /// \param sampler - The sampler for the op
  51. explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
  52. /// Destructor
  53. virtual ~DatasetOp() { tree_ = nullptr; }
  54. /// Adds a operator to become our child.
  55. /// \param child - shared pointer to the child to add.
  56. Status AddChild(std::shared_ptr<DatasetOp> child);
  57. /// Remove a operator from our children.
  58. /// \param child - shared pointer to the child to remove.
  59. Status RemoveChild(std::shared_ptr<DatasetOp> child);
  60. /// \brief Removes this node from the tree and connects it's parent/child together.
  61. /// \return Status eerror code returned
  62. Status Remove();
  63. /// \brief Getter function to get a shared pointer to our child
  64. /// \param child_index - An operator can have n children. Indicates choose which child to return.
  65. std::shared_ptr<DatasetOp> child(int32_t child_index) const;
  66. /// \brief Inserts a operator as the parent current op.
  67. /// Inserted op will become the sole parent of the current op.
  68. /// The existing parent of the current op will be transferred to the inserted op.
  69. Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
  70. /// \brief Creates the connector within this operator
  71. /// \param num_producers - number of threads that write into this connector
  72. /// \param num_consumers - number of threads that read from this connector
  73. void CreateConnector(int32_t num_producers, int32_t num_consumers);
  74. /// \brief A print method typically used for debugging
  75. /// \param out - The output stream to write output to
  76. /// \param show_all - A bool to control if you want to show all info or just a summary
  77. virtual void Print(std::ostream &out, bool show_all) const;
  78. /// \brief << Stream output operator overload
  79. /// \notes This allows you to write the debug print info using stream operators
  80. /// \param out - reference to the output stream being overloaded
  81. /// \param dO - reference to the DatasetOp to display
  82. /// \return - the output stream must be returned
  83. friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) {
  84. dO.Print(out, false);
  85. return out;
  86. }
  87. /// \brief Class functor operator ().
  88. /// DatasetOps operate by launching a thread (see ExecutionTree).
  89. /// This pure virtual version makes the requirement that derived classes must provide a functor
  90. /// that will execute their main runtime loop code.
  91. /// \return Status - The error code return
  92. virtual Status operator()() = 0;
  93. /// \brief Gets the next buffer from the given child
  94. /// \notes See GetNextInput for similar function that has built-in message handling
  95. /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
  96. /// \param worker_id - The worker id
  97. /// \return Status - The error code return
  98. virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
  99. return GetNextBuffer(p_buffer, worker_id, false);
  100. }
  101. /// \brief Gets the next buffer from the given child
  102. /// \notes See GetNextInput for similar function that has built-in message handling
  103. /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
  104. /// \return Status - The error code return
  105. virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); }
  106. /// \brief Gets the next buffer from the given child
  107. /// \notes See GetNextInput for similar function that has built-in message handling
  108. /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
  109. /// \param worker_id - The worker id
  110. /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
  111. /// \return Status - The error code return
  112. virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe);
  113. /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof
  114. /// message handling so that child classes don't have to manually code pass-through logic when
  115. /// those messages are received.
  116. /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
  117. /// \param worker_id - The worker id
  118. /// \return Status - The error code return
  119. Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
  120. /// \brief Performs handling for when an eoe message is received.
  121. /// The base class implementation simply flows the eoe message to output. Derived classes
  122. /// may override if they need to perform special eoe handling.
  123. /// \param worker_id - The worker id
  124. /// \return Status - The error code return
  125. virtual Status EoeReceived(int32_t worker_id);
  126. /// \brief Performs handling for when an eof message is received.
  127. /// The base class implementation simply flows the eof message to output. Derived classes
  128. /// may override if they need to perform special eof handling.
  129. /// \param worker_id - The worker id
  130. /// \return Status - The error code return
  131. virtual Status EofReceived(int32_t worker_id);
  132. /// \brief Derived classes may implement the reset function if the operator is stateful and needs
  133. /// specific reset handling that is not contained in this common code version of the reset
  134. /// \return Status - The error code return
  135. virtual Status Reset();
  136. /// \brief This calls the reset function on this subtree in pre-order
  137. /// \return Status - The error code return
  138. virtual Status ResetSubtree() {
  139. RETURN_IF_NOT_OK(Reset());
  140. for (const auto &c : child_) {
  141. RETURN_IF_NOT_OK(c->ResetSubtree());
  142. }
  143. return Status::OK();
  144. }
  145. /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
  146. /// their role.
  147. /// \notes Derived versions of this function should always call it's superclass version first
  148. /// before providing their own implementations.
  149. virtual Status PrepareNodePreAction();
  150. /// \brief During tree prepare phase, operators may have specific post-operations to perform depending on
  151. /// their role.
  152. /// \notes Derived versions of this function should always call it's superclass version first
  153. /// before providing their own implementations.
  154. virtual Status PrepareNodePostAction();
  155. /// \brief Getter function
  156. /// \return The operator id
  157. int32_t id() const { return operator_id_; }
  158. /// \brief Getter function
  159. /// \return The prepare flags
  160. virtual uint32_t PrepareFlags() const;
  161. /// \brief Getter function
  162. /// \return The number of workers in this op
  163. virtual int32_t num_workers() const = 0;
  164. /// \brief Getter function
  165. /// \return The number of threads consuming from previous op.
  166. virtual int32_t num_consumers() const = 0;
  167. /// \brief Getter function
  168. /// \return The number of threads producing to the output connector.
  169. virtual int32_t num_producers() const = 0;
  170. /// \brief Getter function
  171. /// \return T/F if this is an inlined operator
  172. bool inlined() const { return (oc_queue_size_ == 0); }
  173. /// \brief Setter function
  174. /// \return Sets the control flags
  175. void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }
  176. /// \brief Setter function
  177. /// \return Sets the control flags
  178. void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
  179. /// \brief Register the internal worker connectors. No op unless it is a parallel op
  180. /// \return Status
  181. virtual Status RegisterWorkerConnectors() { return Status::OK(); }
  182. /// \brief Getter for the column name mapping
  183. /// \return The returned map
  184. std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; }
  185. /// \brief Checks if the column name map has been set up yet for this op
  186. /// \return - T/F if the operator has the map set up
  187. bool HasColumnNameMap() const { return (column_name_id_map_.empty()); }
  188. /// \brief gives a string output for the column map for handy debug printing
  189. /// \return - the column name map as a string
  190. std::string ColumnNameMapAsString() const;
  191. /// \brief Getter function
  192. /// \return connector size of current op
  193. int32_t ConnectorSize() const {
  194. if (!inlined()) {
  195. return out_connector_->size();
  196. }
  197. // Return child connector size for inlined op
  198. return ChildOpConnectorSize();
  199. }
  200. /// \brief Counting number of buffer sent out by a connector
  201. int64_t ConnectorOutBufferCount() const {
  202. return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_buffers_count());
  203. }
  204. /// \brief Getter function
  205. /// \return connector size of current op
  206. int32_t ConnectorCapacity() const {
  207. if (!inlined()) {
  208. return out_connector_->capacity();
  209. }
  210. // Return child connector capacity for inlined op
  211. return ChildOpConnectorCapacity();
  212. }
  213. /// \brief Getter function
  214. /// \return connector size of child op
  215. int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); }
  216. /// \brief Getter function
  217. /// \return connector capacity of child op
  218. int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); }
  219. /// \brief Children Getter
  220. /// \return Vector of Children
  221. std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
  222. /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up
  223. /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main
  224. /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it
  225. /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details.
  226. /// \param[in] p The node to visit
  227. /// \param[out] modified Indicator if the node was modified
  228. /// \return Status of the node visit
  229. virtual Status PreAccept(NodePass *p, bool *modified);
  230. /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access.
  231. /// Check "dataset/engine/opt/pass.h" for more details.
  232. /// \param[in] p The node to visit
  233. /// \param[out] modified Indicator if the node was modified
  234. /// \return Status of the node visit
  235. virtual Status Accept(NodePass *p, bool *modified);
  236. /// Op name getter
  237. /// \return Name of the current Op
  238. virtual std::string Name() const { return "DatasetOp"; }
  239. /// Execution Tree getter
  240. /// \return Pointer to the ExecutionTree the current op belongs to, no ownership
  241. ExecutionTree *Tree() { return tree_; }
  242. /// Getter for the sampler
  243. /// \return Shared pointer to the sampler (may return nullptr)
  244. std::shared_ptr<Sampler> sampler() { return sampler_; }
  245. /// Computes a CRC value for the operator
  246. static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
  247. /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
  248. /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
  249. /// \return A shared_ptr casted to the derived class
  250. template <typename Derived>
  251. std::shared_ptr<Derived> shared_from_base() {
  252. return std::static_pointer_cast<Derived>(shared_from_this());
  253. }
  254. protected:
  255. /// Adds a parent operator to this operator
  256. /// \notes External callers do not have access to this function.
  257. /// \param parent - The parent node to add
  258. void AddParent(DatasetOp *parent);
  259. /// Removes a parent operator from this operator
  260. /// \notes External callers do not have access to this function.
  261. /// \param parent - The parent node to remove
  262. void RemoveParent(const DatasetOp *parent);
  263. /// Compute the current op's column map using its child's column map.
  264. /// Get called during the tree post-prepare phase in PrepareNodePostAction.
  265. /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
  266. /// Operations changing the column map it inherits from the child must overwrite this function.
  267. /// \return - Status
  268. virtual Status ComputeColMap();
  269. /// A helper function with some common code that leaf nodes can use during
  270. /// pre/pare phase for checking if they need to assign a sampler to the cache.
  271. /// \param random_access_op - indicate if this is a mappable random access leaf or not
  272. /// \return - Status
  273. Status SaveSamplerForCache(bool random_access_op);
  274. std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
  275. std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
  276. std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
  277. int32_t oc_queue_size_; // Capacity for each out_connector_
  278. int32_t operator_id_; // Generated id for the node
  279. ExecutionTree *tree_; // Back pointer to our tree.
  280. OpState state_; // The state of the operator, Running, Idle, Terminated
  281. uint32_t op_ctrl_flags_; // Flags for the operator
  282. std::unique_ptr<DbConnector> out_connector_; // Output Connector
  283. std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
  284. std::mutex column_name_map_mutex_; // For protecting shared access to the column map
  285. private:
  286. /// Sets the operator id.
  287. /// \notes No public interface. Only the class itself, or it's friend the execution tree can set
  288. /// this
  289. /// \param op_id - the Id value to set into the operator
  290. void set_id(int32_t op_id) { operator_id_ = op_id; }
  291. /// Sets the tree into the op so that the operator has a back pointer to the tree.
  292. /// \param tree - the tree to assign to the op.
  293. void set_tree(ExecutionTree *tree) { tree_ = tree; }
  294. };
  295. } // namespace dataset
  296. } // namespace mindspore
  297. #endif // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_