You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zip_op.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/zip_op.h"
  17. #include <utility>
  18. #include <iomanip>
  19. #include "dataset/core/constants.h"
  20. #include "dataset/engine/data_buffer.h"
  21. #include "dataset/engine/db_connector.h"
  22. #include "dataset/core/config_manager.h"
  23. #include "dataset/core/global_context.h"
  24. #include "utils/log_adapter.h"
  25. namespace mindspore {
  26. namespace dataset {
  27. ZipOp::Builder::Builder() {
  28. // Some arguments to the ZipOp constructor have a default argument that is taken
  29. // from the client config.
  30. // The user may choose to change these values for the construction of the ZipOp by
  31. // using the various builder set methods.
  32. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  33. builder_rows_per_buffer_ = cfg->rows_per_buffer();
  34. builder_op_connector_size_ = cfg->op_connector_size();
  35. }
  36. Status ZipOp::Builder::SanityCheck() const { return Status::OK(); }
  37. Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
  38. RETURN_IF_NOT_OK(SanityCheck());
  39. *ptr = std::make_shared<ZipOp>(builder_rows_per_buffer_, builder_op_connector_size_);
  40. return Status::OK();
  41. }
  42. // Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
  43. ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size)
  44. : PipelineOp(op_connector_size),
  45. children_num_(0),
  46. rows_per_buffer_(rows_per_buffer),
  47. buffer_id_(0),
  48. draining_(false),
  49. eof_(false) {}
  50. // destructor
  51. ZipOp::~ZipOp() {}
  52. // Entry point for Zip, called by launch()
  53. Status ZipOp::operator()() {
  54. // The children_num_ parameter needs to be put here
  55. children_num_ = child_.size();
  56. // Synchronize with TaskManager once the thread is created.
  57. TaskManager::FindMe()->Post();
  58. // initialize the iterators
  59. for (int32_t i = 0; i < children_num_; ++i) {
  60. // magic number 0 since Zip is not a parallel Op
  61. child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  62. }
  63. // Loop until eof is true
  64. while (!eof_) {
  65. // Create tensor table and prepare it by fetching and packing the first zipped row into it.
  66. std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
  67. RETURN_IF_NOT_OK(prepare(curr_table.get()));
  68. // If an eof got picked up during the above prepare, then we're done
  69. if (eof_) {
  70. break;
  71. }
  72. while (!draining_) {
  73. // 1. If a previous loop iteration sent the current table out, then create a new one.
  74. if (curr_table == nullptr) {
  75. curr_table = std::make_unique<TensorQTable>();
  76. }
  77. // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done
  78. RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
  79. // 3 create and update buffer and send it to the out connector
  80. if (!curr_table->empty()) {
  81. std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
  82. curr_buffer->set_tensor_table(std::move(curr_table));
  83. MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
  84. << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << ".";
  85. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
  86. buffer_id_++;
  87. }
  88. }
  89. // 4 handle drain state.
  90. if (draining_) {
  91. MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
  92. RETURN_IF_NOT_OK(drainPipeline());
  93. // Now that we have drained child inputs, send the eoe up.
  94. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
  95. }
  96. }
  97. // 5 handle eof
  98. // propagate eof here.
  99. MS_LOG(INFO) << "Zip operator got EOF, propagating.";
  100. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  101. return Status::OK();
  102. }
  103. // Handles preprocessing of the main loop, used when starting new epoch
  104. Status ZipOp::prepare(TensorQTable *const table) {
  105. MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
  106. draining_ = false;
  107. buffer_id_ = 0;
  108. if (table == nullptr) {
  109. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table.");
  110. }
  111. // fill initial row
  112. TensorRow new_row;
  113. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  114. // If the first row fetching resulted in eof, then we are done.
  115. if (eof_) {
  116. return Status::OK();
  117. }
  118. if (new_row.empty()) {
  119. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
  120. }
  121. // Pack this first row into our tensor table
  122. table->push_back(std::move(new_row));
  123. // At this point we have at least 1 row produced, so all child iterators have their column names such that we
  124. // can produce our column name map now.
  125. column_name_id_map_ = {};
  126. for (int32_t i = 0; i < children_num_; ++i) {
  127. // Initializing col_name_id_map_ from the first data buffer.
  128. const std::unordered_map<std::string, int32_t> col_name_id_map = child_iterators_[i]->GetColumnNameMap();
  129. int32_t colsCurrent = column_name_id_map_.size();
  130. // the update code below shouldn't do anything bad if the column name already exists.
  131. for (const auto &pair : col_name_id_map) {
  132. std::string name = pair.first;
  133. int32_t old_id = pair.second;
  134. // check if name already exists in column name descriptor
  135. if (column_name_id_map_.count(name) == 1) {
  136. RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets");
  137. }
  138. column_name_id_map_[name] = old_id + colsCurrent;
  139. }
  140. }
  141. return Status::OK();
  142. }
  143. // fillBuffer always expects a new table to fill
  144. Status ZipOp::fillBuffer(TensorQTable *const table) {
  145. if (table == nullptr) {
  146. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer.");
  147. }
  148. TensorRow new_row;
  149. while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
  150. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  151. // Early exit the loop if we got empty row from any of our child iterations
  152. if (new_row.empty()) {
  153. return Status::OK();
  154. }
  155. // else we got a row so pack it into the tensor table.
  156. table->push_back(std::move(new_row));
  157. }
  158. return Status::OK();
  159. }
  160. // fetches next zip buffer row (merged row)
  161. Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
  162. // iterate over all iterators and generate a row
  163. for (int32_t i = 0; i < children_num_; ++i) {
  164. TensorRow new_row = {};
  165. RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
  166. // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  167. if (new_row.empty()) {
  168. // If we did not get a row from any of the children, then it's the end of an epoch and we can move
  169. // to drain state.
  170. MS_LOG(INFO) << "Zip operator child iterator produced empty row.";
  171. draining_ = true;
  172. new_zip_row->clear();
  173. // If we picked up an eof here, then we are completely done.
  174. if ((child_iterators_[i])->eof_handled()) {
  175. MS_LOG(INFO) << "Zip operator iterator got EOF.";
  176. eof_ = true;
  177. }
  178. return Status::OK();
  179. } else {
  180. MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
  181. // if row isn't empty then we can append the fetched row with new_zip_row
  182. new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
  183. }
  184. }
  185. MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
  186. return Status::OK();
  187. }
  188. // drain end of epoch messages from iterator for this epoch
  189. Status ZipOp::drainPipeline() {
  190. // we don't need to drain if we reached eof
  191. if (eof_) {
  192. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  193. "ZipOp draining should not be done if already at eof!");
  194. }
  195. for (int32_t con = 0; con < children_num_; ++con) {
  196. MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
  197. RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
  198. }
  199. // at this point all connectors don't contain end of epoch messages. next iteration should be clean
  200. return Status::OK();
  201. }
  202. // A function that prints info about the Operator
  203. void ZipOp::Print(std::ostream &out, // In: The output stream to print to
  204. bool show_all) const { // In: T/F if it should print everything
  205. // Always show the id and name as first line regardless if this is summary or detailed print
  206. out << "(" << std::setw(2) << operator_id_ << ") <ZipOp>:";
  207. if (!show_all) {
  208. // Call the super class for displaying any common 1-liner info
  209. PipelineOp::Print(out, show_all);
  210. // Then show any custom derived-internal 1-liner info for this op
  211. out << "\n";
  212. } else {
  213. // Call the super class for displaying any common detailed info
  214. PipelineOp::Print(out, show_all);
  215. // Then show any custom derived-internal stuff
  216. out << "\nDatasets: " << children_num_ << "\n\n";
  217. }
  218. }
  219. // overwrite function and handle eof
  220. Status ZipOp::EofReceived(int32_t) {
  221. MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now.";
  222. return Status::OK();
  223. }
  224. // overwrite function and handle eoe
  225. Status ZipOp::EoeReceived(int32_t) {
  226. state_ = OpState::kDeOpIdle;
  227. return Status::OK();
  228. }
  229. } // namespace dataset
  230. } // namespace mindspore