You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset_op.cc 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/dataset_op.h"
  17. #include <iomanip>
  18. #include <iostream>
  19. #include <memory>
  20. #include <utility>
  21. #include <string>
  22. #include "dataset/engine/execution_tree.h"
  23. #include "dataset/engine/datasetops/device_queue_op.h"
  24. #include "dataset/engine/data_buffer.h"
  25. #include "dataset/engine/db_connector.h"
  26. #include "utils/log_adapter.h"
  27. namespace mindspore {
  28. namespace dataset {
  29. // Constructor
  30. DatasetOp::DatasetOp(int32_t op_connector_size)
  31. : oc_queue_size_(op_connector_size),
  32. operator_id_(kInvalidOperatorId),
  33. tree_(nullptr),
  34. state_(OpState::kDeOpIdle),
  35. op_ctrl_flags_(kDeOpNone) {
  36. // The operator starts out with an invalid operator id. The only way to
  37. // get it out of invalid state is to assign the operator to an execution tree.
  38. }
  39. // Adds a operator to become our child.
  40. Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
  41. if (std::dynamic_pointer_cast<DeviceQueueOp>(child) != nullptr) {
  42. std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node");
  43. RETURN_STATUS_UNEXPECTED(err_msg);
  44. }
  45. if (operator_id_ == kInvalidOperatorId) {
  46. std::string err_msg(
  47. "Cannot add child node. Tree node connections can only"
  48. "be made if the node belongs to a tree.");
  49. RETURN_STATUS_UNEXPECTED(err_msg);
  50. }
  51. // disallow relationships with other trees
  52. if (tree_ != child->tree_) {
  53. std::string err_msg(
  54. "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree.");
  55. RETURN_STATUS_UNEXPECTED(err_msg);
  56. }
  57. child_.push_back(child);
  58. child->AddParent(this);
  59. return Status::OK();
  60. }
  61. // Adds a parent operator to this operator
  62. void DatasetOp::AddParent(const DatasetOp *parent) { parent_.push_back(parent); }
  63. // Getter function to get a shared pointer to our childAdds a operator to become our child.
  64. std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
  65. DS_ASSERT(child_index < static_cast<int>(child_.size()));
  66. // Return a shared pointer
  67. return child_[child_index];
  68. }
  69. // Creates the connector within this operator
  70. void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
  71. MS_LOG(INFO) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
  72. << ". Consumer: " << num_consumers << ".";
  73. if (oc_queue_size_ > 0) {
  74. out_connector_ = std::make_unique<DbConnector>(num_producers, // The number of producers
  75. num_consumers, // Only one consumer (the training App)
  76. oc_queue_size_);
  77. } else {
  78. // Some op's may choose not to have an output connector
  79. MS_LOG(INFO) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
  80. out_connector_ = nullptr;
  81. }
  82. }
  83. // A print method typically used for debugging. showAll of true will recursively descend to child prints
  84. void DatasetOp::Print(std::ostream &out, bool show_all) const {
  85. if (show_all) {
  86. for (size_t i = 0; i < child_.size(); i++) {
  87. child_[i]->Print(out, show_all);
  88. }
  89. }
  90. out << "\n-------------------------"
  91. << "\nOperator # : " << operator_id_ << "\nNumber of children : " << child_.size()
  92. << "\nNumber of parents : " << parent_.size() << "\nConnector queue size : " << oc_queue_size_
  93. << "\nOperator control flags : 0x" << std::hex << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec
  94. << std::setfill(' ') << "\nHas parents:\n";
  95. for (size_t i = 0; i < parent_.size(); i++) {
  96. out << "Parent[" << i << "] id: " << parent_[i]->id() << "\n";
  97. }
  98. }
  99. // Gets the next buffer from the given child
  100. Status DatasetOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  101. #if defined(_WIN32) || defined(_WIN64)
  102. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), p_buffer, retry_if_eoe));
  103. #else
  104. std::unique_ptr<DataBuffer> next_buff;
  105. // pop is a blocked call and will throw an interruption if the whole group shuts down.
  106. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), &next_buff, retry_if_eoe));
  107. *p_buffer = std::move(next_buff);
  108. #endif
  109. return Status::OK();
  110. }
  111. // Gets the next buffer from the given child . This function also has built-in eoe and eof
  112. // message handling so that child classes don't have to manually code pass-through logic when
  113. // those messages are received.
  114. Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, int32_t child_index) {
  115. if (child_.size() == 0) {
  116. return this->GetNextBuffer(p_buffer, worker_id);
  117. }
  118. CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index));
  119. std::shared_ptr<DatasetOp> child = child_[child_index];
  120. std::unique_ptr<DataBuffer> buf;
  121. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  122. // Loop until non EOE is received
  123. while (buf->eoe()) {
  124. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  125. if (state_ == OpState::kDeOpIdle) {
  126. *p_buffer = std::move(buf);
  127. return Status::OK();
  128. }
  129. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  130. }
  131. // Check if the last buf is next eof
  132. if (buf->eof()) {
  133. RETURN_IF_NOT_OK(EofReceived(worker_id));
  134. }
  135. *p_buffer = std::move(buf);
  136. return Status::OK();
  137. }
  138. // Performs handling for when an eoe message is received.
  139. // The base class implementation simply flows the eoe message to output. Derived classes
  140. // may override if they need to perform special eoe handling.
  141. Status DatasetOp::EoeReceived(int32_t worker_id) {
  142. std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  143. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eoe_buffer)));
  144. }
  145. // Performs handling for when an eof message is received.
  146. // The base class implementation simply flows the eof message to output. Derived classes
  147. // may override if they need to perform special eof handling.
  148. Status DatasetOp::EofReceived(int32_t worker_id) {
  149. std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  150. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
  151. }
  152. // During tree prepare phase, operators may have specific pre-operations to perform depending on
  153. // their role.
  154. Status DatasetOp::PrepareNodePreAction() {
  155. if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
  156. return Status::OK();
  157. }
  158. // During tree prepare phase, operators may have specific post-operations to perform depending on
  159. // their role.
  160. Status DatasetOp::PrepareNodePostAction() {
  161. // If this op does not have any children and it is in a repeat path of the tree...
  162. if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
  163. // push ourselves onto the tree repeat stack. Later, the repeat operator
  164. // above us will consume them.
  165. tree_->AddToRepeatStack(shared_from_this());
  166. }
  167. // Creating Connector object for each op.
  168. // The consumer of the root node is assumed to be one thread.
  169. // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
  170. if (parent_.empty()) {
  171. this->CreateConnector(num_producers(), 1);
  172. } else {
  173. this->CreateConnector(num_producers(), parent_[0]->num_consumers());
  174. }
  175. if (out_connector_) {
  176. RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
  177. }
  178. RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
  179. return Status::OK();
  180. }
  181. // Getter function. Base class does not have any special flags setting.
  182. uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; }
  183. // Derived classes may implement the reset function if the operator is stateful and needs
  184. // specific reset handling that is not contained in this common code version of the reset.
  185. Status DatasetOp::Reset() {
  186. state_ = OpState::kDeOpRunning;
  187. return Status::OK();
  188. }
  189. } // namespace dataset
  190. } // namespace mindspore