You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

zip_op.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/zip_op.h"
  17. #include <utility>
  18. #include <iomanip>
  19. #include "dataset/core/constants.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/db_connector.h"
  22. #include "dataset/engine/opt/pass.h"
  23. #include "dataset/core/config_manager.h"
  24. #include "dataset/core/global_context.h"
  25. #include "utils/log_adapter.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. ZipOp::Builder::Builder() {
  29. // Some arguments to the ZipOp constructor have a default argument that is taken
  30. // from the client config.
  31. // The user may choose to change these values for the construction of the ZipOp by
  32. // using the various builder set methods.
  33. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  34. builder_rows_per_buffer_ = cfg->rows_per_buffer();
  35. builder_op_connector_size_ = cfg->op_connector_size();
  36. }
  37. Status ZipOp::Builder::SanityCheck() const { return Status::OK(); }
  38. Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
  39. RETURN_IF_NOT_OK(SanityCheck());
  40. *ptr = std::make_shared<ZipOp>(builder_rows_per_buffer_, builder_op_connector_size_);
  41. return Status::OK();
  42. }
// Construct ZipOp here; local runtime state is initialized in the constructor
// (rather than derived from the tree) due to tree construction restrictions —
// the child list is not yet attached at this point.
// @param rows_per_buffer Number of rows to pack into each outgoing DataBuffer.
// @param op_connector_size Output connector queue depth (forwarded to PipelineOp).
ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size)
    : PipelineOp(op_connector_size),
      children_num_(0),          // set for real in operator() once children exist
      rows_per_buffer_(rows_per_buffer),
      buffer_id_(0),             // id assigned to the next buffer pushed downstream
      draining_(false),          // true while consuming end-of-epoch from children
      eof_(false) {}             // true once any child reports end-of-file
  51. // destructor
  52. ZipOp::~ZipOp() {}
  53. // Entry point for Zip, called by launch()
  54. Status ZipOp::operator()() {
  55. // The children_num_ parameter needs to be put here
  56. children_num_ = child_.size();
  57. // Synchronize with TaskManager once the thread is created.
  58. TaskManager::FindMe()->Post();
  59. // initialize the iterators
  60. for (int32_t i = 0; i < children_num_; ++i) {
  61. // magic number 0 since Zip is not a parallel Op
  62. child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  63. }
  64. // Loop until eof is true
  65. while (!eof_) {
  66. // Create tensor table and prepare it by fetching and packing the first zipped row into it.
  67. std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
  68. RETURN_IF_NOT_OK(prepare(curr_table.get()));
  69. // If an eof got picked up during the above prepare, then we're done
  70. if (eof_) {
  71. break;
  72. }
  73. while (!draining_) {
  74. // 1. If a previous loop iteration sent the current table out, then create a new one.
  75. if (curr_table == nullptr) {
  76. curr_table = std::make_unique<TensorQTable>();
  77. }
  78. // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done
  79. RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
  80. // 3 create and update buffer and send it to the out connector
  81. if (!curr_table->empty()) {
  82. std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
  83. curr_buffer->set_tensor_table(std::move(curr_table));
  84. MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
  85. << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << ".";
  86. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
  87. buffer_id_++;
  88. }
  89. }
  90. // 4 handle drain state.
  91. if (draining_) {
  92. MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
  93. RETURN_IF_NOT_OK(drainPipeline());
  94. // Now that we have drained child inputs, send the eoe up.
  95. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
  96. }
  97. }
  98. // 5 handle eof
  99. // propagate eof here.
  100. MS_LOG(DEBUG) << "Zip operator got EOF, propagating.";
  101. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  102. return Status::OK();
  103. }
  104. // Handles preprocessing of the main loop, used when starting new epoch
  105. Status ZipOp::prepare(TensorQTable *const table) {
  106. MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
  107. draining_ = false;
  108. buffer_id_ = 0;
  109. if (table == nullptr) {
  110. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table.");
  111. }
  112. // fill initial row
  113. TensorRow new_row;
  114. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  115. // If the first row fetching resulted in eof, then we are done.
  116. if (eof_) {
  117. return Status::OK();
  118. }
  119. if (new_row.empty()) {
  120. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
  121. }
  122. // Pack this first row into our tensor table
  123. table->push_back(std::move(new_row));
  124. // At this point we have at least 1 row produced, so all child iterators have their column names such that we
  125. // can produce our column name map now.
  126. column_name_id_map_ = {};
  127. for (int32_t i = 0; i < children_num_; ++i) {
  128. // Initializing col_name_id_map_ from the first data buffer.
  129. const std::unordered_map<std::string, int32_t> col_name_id_map = child_iterators_[i]->GetColumnNameMap();
  130. int32_t colsCurrent = column_name_id_map_.size();
  131. // the update code below shouldn't do anything bad if the column name already exists.
  132. for (const auto &pair : col_name_id_map) {
  133. std::string name = pair.first;
  134. int32_t old_id = pair.second;
  135. // check if name already exists in column name descriptor
  136. if (column_name_id_map_.count(name) == 1) {
  137. RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets");
  138. }
  139. column_name_id_map_[name] = old_id + colsCurrent;
  140. }
  141. }
  142. return Status::OK();
  143. }
  144. // fillBuffer always expects a new table to fill
  145. Status ZipOp::fillBuffer(TensorQTable *const table) {
  146. if (table == nullptr) {
  147. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer.");
  148. }
  149. TensorRow new_row;
  150. while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
  151. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  152. // Early exit the loop if we got empty row from any of our child iterations
  153. if (new_row.empty()) {
  154. return Status::OK();
  155. }
  156. // else we got a row so pack it into the tensor table.
  157. table->push_back(std::move(new_row));
  158. }
  159. return Status::OK();
  160. }
  161. // fetches next zip buffer row (merged row)
  162. Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
  163. // iterate over all iterators and generate a row
  164. for (int32_t i = 0; i < children_num_; ++i) {
  165. TensorRow new_row = {};
  166. RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
  167. // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  168. if (new_row.empty()) {
  169. // If we did not get a row from any of the children, then it's the end of an epoch and we can move
  170. // to drain state.
  171. MS_LOG(DEBUG) << "Zip operator child iterator produced empty row.";
  172. draining_ = true;
  173. new_zip_row->clear();
  174. // If we picked up an eof here, then we are completely done.
  175. if ((child_iterators_[i])->eof_handled()) {
  176. MS_LOG(DEBUG) << "Zip operator iterator got EOF.";
  177. eof_ = true;
  178. }
  179. return Status::OK();
  180. } else {
  181. MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
  182. // if row isn't empty then we can append the fetched row with new_zip_row
  183. new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
  184. }
  185. }
  186. MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
  187. return Status::OK();
  188. }
  189. // drain end of epoch messages from iterator for this epoch
  190. Status ZipOp::drainPipeline() {
  191. // we don't need to drain if we reached eof
  192. if (eof_) {
  193. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  194. "ZipOp draining should not be done if already at eof!");
  195. }
  196. for (int32_t con = 0; con < children_num_; ++con) {
  197. MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
  198. RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
  199. }
  200. // at this point all connectors don't contain end of epoch messages. next iteration should be clean
  201. return Status::OK();
  202. }
  203. // A function that prints info about the Operator
  204. void ZipOp::Print(std::ostream &out, // In: The output stream to print to
  205. bool show_all) const { // In: T/F if it should print everything
  206. // Always show the id and name as first line regardless if this is summary or detailed print
  207. out << "(" << std::setw(2) << operator_id_ << ") <ZipOp>:";
  208. if (!show_all) {
  209. // Call the super class for displaying any common 1-liner info
  210. PipelineOp::Print(out, show_all);
  211. // Then show any custom derived-internal 1-liner info for this op
  212. out << "\n";
  213. } else {
  214. // Call the super class for displaying any common detailed info
  215. PipelineOp::Print(out, show_all);
  216. // Then show any custom derived-internal stuff
  217. out << "\nDatasets: " << children_num_ << "\n\n";
  218. }
  219. }
  220. // overwrite function and handle eof
  221. Status ZipOp::EofReceived(int32_t) {
  222. MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now.";
  223. return Status::OK();
  224. }
// Override of the eoe handler: mark this operator idle at end of epoch.
Status ZipOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}
  230. // Visitor accept method for NodePass
  231. Status ZipOp::Accept(NodePass *p, bool *modified) {
  232. // Downcast shared pointer then call visitor
  233. return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
  234. }
  235. } // namespace dataset
  236. } // namespace mindspore