You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

barrier_op.cc 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/barrier_op.h"
  17. #include <utility>
  18. #include "dataset/core/constants.h"
  19. #include "dataset/engine/data_buffer.h"
  20. #include "dataset/engine/db_connector.h"
  21. #include "dataset/core/config_manager.h"
  22. #include "dataset/core/global_context.h"
  23. #include "utils/log_adapter.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. BarrierOp::Builder::Builder() {
  27. // Some arguments to the BarrierOp constructor have a default argument that is taken
  28. // from the client config.
  29. // The user may choose to change these values for the construction of the BarrierOp by
  30. // using the various builder set methods.
  31. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  32. builder_rows_per_buffer_ = cfg->rows_per_buffer();
  33. builder_op_connector_size_ = cfg->op_connector_size();
  34. }
  35. Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }
  36. Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
  37. RETURN_IF_NOT_OK(SanityCheck());
  38. *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
  39. builder_condition_func_);
  40. return Status::OK();
  41. }
  42. // Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
  43. BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
  44. py::function condition_func)
  45. : PipelineOp(op_connector_size),
  46. rows_per_buffer_(rows_per_buffer),
  47. buffer_id_(0),
  48. clean_up_(false),
  49. eof_(false),
  50. condition_name_(condition_name),
  51. condition_function_(condition_func) {}
  52. // destructor
  53. BarrierOp::~BarrierOp() {}
  54. // Entry point for Barrier, called by launch()
  55. Status BarrierOp::operator()() {
  56. // The children_num_ parameter needs to be put here
  57. // Synchronize with TaskManager once the thread is created.
  58. TaskManager::FindMe()->Post();
  59. // create child iterator, right now this barrier is a pipeline operator
  60. const int32_t worker_id = 0;
  61. const int32_t child_idx = 0;
  62. child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
  63. // Loop until eof is true
  64. while (!eof_) {
  65. // Create new table to put the new tensor rows
  66. std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
  67. RETURN_IF_NOT_OK(prepare(curr_table.get()));
  68. // If an eof got picked up during the above prepare, then we're done
  69. if (eof_) {
  70. break;
  71. }
  72. // we have to output new buffer with possibly different buffer size, possibly one row
  73. while (!clean_up_) {
  74. // 1. If a previous loop iteration sent the current table out, then create a new one.
  75. if (curr_table == nullptr) {
  76. curr_table = std::make_unique<TensorQTable>();
  77. }
  78. // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished
  79. RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
  80. // 3 create and update buffer and send it to the out connector
  81. if (!curr_table->empty()) {
  82. std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
  83. curr_buffer->set_tensor_table(std::move(curr_table));
  84. curr_buffer->set_column_name_map(col_name_id_map_);
  85. MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
  86. << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
  87. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
  88. buffer_id_++;
  89. }
  90. }
  91. // 4 handle drain state.
  92. if (clean_up_) {
  93. MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
  94. // Send the eoe up.
  95. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
  96. }
  97. }
  98. // 5 handle eof
  99. // propagate eof here.
  100. MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
  101. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  102. return Status::OK();
  103. }
  104. // Handles preprocessing of the main loop, used when starting new epoch
  105. Status BarrierOp::prepare(TensorQTable *const table) {
  106. MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
  107. clean_up_ = false;
  108. buffer_id_ = 0;
  109. if (table == nullptr) {
  110. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
  111. }
  112. // fill initial row
  113. TensorRow new_row = {};
  114. // use iterator to get next row and invoke pyfunc wait
  115. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  116. // If the first row fetching resulted in eof, then we are done.
  117. if (eof_) {
  118. return Status::OK();
  119. }
  120. if (new_row.empty()) {
  121. // This epoch is empty
  122. return Status::OK();
  123. }
  124. // Pack this first row into our tensor table
  125. // first row we also have to check if we should block
  126. RETURN_IF_NOT_OK(blockCond());
  127. table->push_back(std::move(new_row));
  128. // At this point we have 1 row produced, we take the old column map id and use it in the new table
  129. // Initializing col_name_id_map_ from the first data buffer.
  130. col_name_id_map_ = child_iterator_->col_name_id_map();
  131. // the update code below shouldn't do anything bad if the column name already exists.
  132. return Status::OK();
  133. }
  134. // fillBuffer always expects a new table to fill
  135. Status BarrierOp::fillBuffer(TensorQTable *const table) {
  136. if (table == nullptr) {
  137. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
  138. }
  139. TensorRow new_row = {};
  140. while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
  141. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  142. // Early exit the loop if we got empty row from any of our child iterations
  143. if (new_row.empty()) {
  144. return Status::OK();
  145. }
  146. // else we got a row so pack it into the tensor table.
  147. RETURN_IF_NOT_OK(blockCond());
  148. table->push_back(std::move(new_row));
  149. }
  150. return Status::OK();
  151. }
  152. // function executes a py_func and blocks until condition becomes true.
  153. Status BarrierOp::blockCond() {
  154. {
  155. py::gil_scoped_acquire gil_acquire;
  156. if (Py_IsInitialized() == 0) {
  157. return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  158. }
  159. // we have condition name, however the flexibility is in python today
  160. try {
  161. // Invoke python function
  162. py::object ret_py_obj = condition_function_();
  163. // Process the return value
  164. if (!py::isinstance<py::bool_>(ret_py_obj)) {
  165. return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
  166. }
  167. } catch (const py::error_already_set &e) {
  168. return Status(StatusCode::kPyFuncException, e.what());
  169. }
  170. }
  171. return Status::OK();
  172. }
  173. // fetches next Barrier buffer row
  174. Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
  175. // iterate over all iterators and generate a row
  176. RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
  177. // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  178. if (new_row->empty()) {
  179. // If we did not get a row from any of the children, then it's the end of an epoch and we can move
  180. // to drain state.
  181. MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
  182. clean_up_ = true;
  183. // If we picked up an eof here, then we are completely done.
  184. if ((child_iterator_)->eof_handled()) {
  185. MS_LOG(INFO) << "Barrier operator iterator got EOF.";
  186. eof_ = true;
  187. }
  188. return Status::OK();
  189. }
  190. return Status::OK();
  191. }
  192. // A function that prints info about the Operator
  193. void BarrierOp::Print(std::ostream &out, bool show_all) const {
  194. // Call base class printer first
  195. PipelineOp::Print(out, show_all);
  196. out << "\nBarrierOp:\n"
  197. << "\nCondition " << condition_name_ << "\n\n";
  198. }
  199. // overwrite function and handle eof
  200. Status BarrierOp::EofReceived(int32_t) {
  201. MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
  202. return Status::OK();
  203. }
  204. // overwrite function and handle eoe
  205. Status BarrierOp::EoeReceived(int32_t) {
  206. state_ = OpState::kDeOpIdle;
  207. return Status::OK();
  208. }
  209. } // namespace dataset
  210. } // namespace mindspore