You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_op.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/dataset_op.h"
  17. #include <iomanip>
  18. #include <iostream>
  19. #include <memory>
  20. #include <utility>
  21. #include <string>
  22. #include "dataset/engine/execution_tree.h"
  23. #include "dataset/engine/datasetops/device_queue_op.h"
  24. #include "dataset/engine/data_buffer.h"
  25. #include "dataset/engine/db_connector.h"
  26. #include "dataset/engine/opt/pass.h"
  27. #include "utils/log_adapter.h"
  28. namespace mindspore {
  29. namespace dataset {
  30. // Constructor
  31. DatasetOp::DatasetOp(int32_t op_connector_size)
  32. : oc_queue_size_(op_connector_size),
  33. operator_id_(kInvalidOperatorId),
  34. tree_(nullptr),
  35. state_(OpState::kDeOpIdle),
  36. op_ctrl_flags_(kDeOpNone),
  37. first_fetch_(true) {
  38. // The operator starts out with an invalid operator id. The only way to
  39. // get it out of invalid state is to assign the operator to an execution tree.
  40. }
  41. // Adds a operator to become our child.
  42. Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
  43. if (std::dynamic_pointer_cast<DeviceQueueOp>(child) != nullptr) {
  44. std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node");
  45. RETURN_STATUS_UNEXPECTED(err_msg);
  46. }
  47. if (operator_id_ == kInvalidOperatorId) {
  48. std::string err_msg(
  49. "Cannot add child node. Tree node connections can only"
  50. "be made if the node belongs to a tree.");
  51. RETURN_STATUS_UNEXPECTED(err_msg);
  52. }
  53. // disallow relationships with other trees
  54. if (tree_ != child->tree_) {
  55. std::string err_msg(
  56. "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree.");
  57. RETURN_STATUS_UNEXPECTED(err_msg);
  58. }
  59. child_.push_back(child);
  60. child->AddParent(this);
  61. return Status::OK();
  62. }
  63. // Adds a parent operator to this operator
  64. void DatasetOp::AddParent(const DatasetOp *parent) { parent_.push_back(parent); }
  65. // Getter function to get a shared pointer to our childAdds a operator to become our child.
  66. std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
  67. DS_ASSERT(child_index < static_cast<int>(child_.size()));
  68. // Return a shared pointer
  69. return child_[child_index];
  70. }
  71. // Creates the connector within this operator
  72. void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
  73. MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
  74. << ". Consumer: " << num_consumers << ".";
  75. if (oc_queue_size_ > 0) {
  76. out_connector_ = std::make_unique<DbConnector>(num_producers, // The number of producers
  77. num_consumers, // Only one consumer (the training App)
  78. oc_queue_size_);
  79. } else {
  80. // Some op's may choose not to have an output connector
  81. MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
  82. out_connector_ = nullptr;
  83. }
  84. }
  85. // A print method typically used for debugging. showAll of true will recursively descend to child prints
  86. void DatasetOp::Print(std::ostream &out, bool show_all) const {
  87. // When show_all is false, we display a 1 liner piece of text for the op.
  88. // When show_all is true, we display more detailed output for the op.
  89. // Derived printers should show their own header info, then call base class printer, followed by
  90. // derived-specific items.
  91. // For now, the base class doesn't have any summary info to show so it's a no-op in that case.
  92. if (show_all) {
  93. // The detailed display will show common base class info of the op. Allow the derived class to print
  94. // it's own id and name though as the first line.
  95. out << "\nNumber of children : " << child_.size();
  96. for (size_t i = 0; i < child_.size(); i++) {
  97. out << "\n Child[" << i << "] id: " << child_[i]->id();
  98. }
  99. out << "\nNumber of parents : " << parent_.size();
  100. for (size_t i = 0; i < parent_.size(); i++) {
  101. out << "\n Parent[" << i << "] id: " << parent_[i]->id();
  102. }
  103. out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
  104. << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
  105. }
  106. }
  107. // Gets the next buffer from the given child
  108. Status DatasetOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  109. #if defined(_WIN32) || defined(_WIN64)
  110. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), p_buffer, retry_if_eoe));
  111. #else
  112. std::unique_ptr<DataBuffer> next_buff;
  113. // pop is a blocked call and will throw an interruption if the whole group shuts down.
  114. RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), &next_buff, retry_if_eoe));
  115. *p_buffer = std::move(next_buff);
  116. #endif
  117. return Status::OK();
  118. }
  119. // Gets the next buffer from the given child . This function also has built-in eoe and eof
  120. // message handling so that child classes don't have to manually code pass-through logic when
  121. // those messages are received.
  122. Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, int32_t child_index) {
  123. if (child_.size() == 0) {
  124. return this->GetNextBuffer(p_buffer, worker_id);
  125. }
  126. CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index));
  127. std::shared_ptr<DatasetOp> child = child_[child_index];
  128. std::unique_ptr<DataBuffer> buf;
  129. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  130. // Loop until non EOE is received
  131. while (buf->eoe()) {
  132. RETURN_IF_NOT_OK(EoeReceived(worker_id));
  133. if (state_ == OpState::kDeOpIdle) {
  134. *p_buffer = std::move(buf);
  135. return Status::OK();
  136. }
  137. RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
  138. }
  139. // Check if the last buf is next eof
  140. if (buf->eof()) {
  141. RETURN_IF_NOT_OK(EofReceived(worker_id));
  142. }
  143. *p_buffer = std::move(buf);
  144. return Status::OK();
  145. }
  146. // Performs handling for when an eoe message is received.
  147. // The base class implementation simply flows the eoe message to output. Derived classes
  148. // may override if they need to perform special eoe handling.
  149. Status DatasetOp::EoeReceived(int32_t worker_id) {
  150. std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  151. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eoe_buffer)));
  152. }
  153. // Performs handling for when an eof message is received.
  154. // The base class implementation simply flows the eof message to output. Derived classes
  155. // may override if they need to perform special eof handling.
  156. Status DatasetOp::EofReceived(int32_t worker_id) {
  157. std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  158. return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
  159. }
  160. // During tree prepare phase, operators may have specific pre-operations to perform depending on
  161. // their role.
  162. Status DatasetOp::PrepareNodePreAction() {
  163. if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
  164. return Status::OK();
  165. }
  166. // During tree prepare phase, operators may have specific post-operations to perform depending on
  167. // their role.
  168. Status DatasetOp::PrepareNodePostAction() {
  169. // If this op does not have any children and it is in a repeat path of the tree...
  170. if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
  171. // push ourselves onto the tree repeat stack. Later, the repeat operator
  172. // above us will consume them.
  173. tree_->AddToRepeatStack(shared_from_this());
  174. }
  175. // Creating Connector object for each op.
  176. // The consumer of the root node is assumed to be one thread.
  177. // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
  178. if (parent_.empty()) {
  179. this->CreateConnector(num_producers(), 1);
  180. } else {
  181. this->CreateConnector(num_producers(), parent_[0]->num_consumers());
  182. }
  183. if (out_connector_) {
  184. RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
  185. }
  186. RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
  187. return Status::OK();
  188. }
  189. // Getter function. Base class does not have any special flags setting.
  190. uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; }
  191. // Derived classes may implement the reset function if the operator is stateful and needs
  192. // specific reset handling that is not contained in this common code version of the reset.
  193. Status DatasetOp::Reset() {
  194. state_ = OpState::kDeOpRunning;
  195. return Status::OK();
  196. }
  197. // gives a string output for the column map for handy debug printing
  198. std::string DatasetOp::ColumnNameMapAsString() const {
  199. std::string outStr = "Column name id map: ";
  200. for (auto &it : column_name_id_map_) {
  201. outStr += (" " + it.first + ":" + std::to_string(it.second));
  202. }
  203. return outStr;
  204. }
  205. // A helper function for providing assignment of the column name map.
  206. // This grabs the map from child 0 and assigns it into this op.
  207. // Can only be used if number of children is 1.
  208. Status DatasetOp::AssignColMapFromChild() {
  209. if (child_.size() > 1) {
  210. RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators.");
  211. }
  212. // Assign the correct column name map to this op by taking it from the input child.
  213. // This must be done AFTER the first fetch, but only needs to be done once by the first worker to
  214. // do the first fetch.
  215. if (first_fetch_) {
  216. // If there was a single worker, or this is being called from a master thread in a parallel op,
  217. // then the mutex is not really needed here, although it's harmless.
  218. std::unique_lock<std::mutex> lock(column_name_map_mutex_);
  219. // If the map has not been set up yet, then we are the first one in to set it up. The first_fetch_ (dirty read)
  220. // bool allows us to avoid acquiring the lock if the map has already been set.
  221. if (column_name_id_map_.empty()) {
  222. column_name_id_map_ = child_[0]->column_name_id_map();
  223. first_fetch_ = false;
  224. if (column_name_id_map_.empty()) {
  225. RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  226. }
  227. }
  228. MS_LOG(DEBUG) << "Setting column map after first fetch:\n" << DatasetOp::ColumnNameMapAsString();
  229. }
  230. return Status::OK();
  231. }
  232. Status DatasetOp::Accept(NodePass *p, bool *modified) {
  233. // DatasetOp is the base class of visitor target.
  234. // This method will only be called if its derived class does not implement one.
  235. return p->RunOnNode(shared_from_this(), modified);
  236. }
  237. } // namespace dataset
  238. } // namespace mindspore