You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

zip_op.cc 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/datasetops/zip_op.h"
  17. #include <utility>
  18. #include "dataset/core/constants.h"
  19. #include "dataset/engine/data_buffer.h"
  20. #include "dataset/engine/db_connector.h"
  21. #include "dataset/core/config_manager.h"
  22. #include "dataset/core/global_context.h"
  23. #include "utils/log_adapter.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. ZipOp::Builder::Builder() {
  27. // Some arguments to the ZipOp constructor have a default argument that is taken
  28. // from the client config.
  29. // The user may choose to change these values for the construction of the ZipOp by
  30. // using the various builder set methods.
  31. std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  32. builder_rows_per_buffer_ = cfg->rows_per_buffer();
  33. builder_op_connector_size_ = cfg->op_connector_size();
  34. }
  35. Status ZipOp::Builder::SanityCheck() const { return Status::OK(); }
  36. Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
  37. RETURN_IF_NOT_OK(SanityCheck());
  38. *ptr = std::make_shared<ZipOp>(builder_rows_per_buffer_, builder_op_connector_size_);
  39. return Status::OK();
  40. }
  41. // Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
  42. ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size)
  43. : PipelineOp(op_connector_size),
  44. children_num_(0),
  45. rows_per_buffer_(rows_per_buffer),
  46. buffer_id_(0),
  47. draining_(false),
  48. eof_(false) {}
  49. // destructor
  50. ZipOp::~ZipOp() {}
  51. // Entry point for Zip, called by launch()
  52. Status ZipOp::operator()() {
  53. // The children_num_ parameter needs to be put here
  54. children_num_ = child_.size();
  55. // Synchronize with TaskManager once the thread is created.
  56. TaskManager::FindMe()->Post();
  57. // initialize the iterators
  58. for (int32_t i = 0; i < children_num_; ++i) {
  59. // magic number 0 since Zip is not a parallel Op
  60. child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  61. }
  62. // Loop until eof is true
  63. while (!eof_) {
  64. // Create tensor table and prepare it by fetching and packing the first zipped row into it.
  65. std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
  66. RETURN_IF_NOT_OK(prepare(curr_table.get()));
  67. // If an eof got picked up during the above prepare, then we're done
  68. if (eof_) {
  69. break;
  70. }
  71. while (!draining_) {
  72. // 1. If a previous loop iteration sent the current table out, then create a new one.
  73. if (curr_table == nullptr) {
  74. curr_table = std::make_unique<TensorQTable>();
  75. }
  76. // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done
  77. RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
  78. // 3 create and update buffer and send it to the out connector
  79. if (!curr_table->empty()) {
  80. std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
  81. curr_buffer->set_tensor_table(std::move(curr_table));
  82. curr_buffer->set_column_name_map(col_name_id_map_);
  83. MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
  84. << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
  85. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
  86. buffer_id_++;
  87. }
  88. }
  89. // 4 handle drain state.
  90. if (draining_) {
  91. MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
  92. RETURN_IF_NOT_OK(drainPipeline());
  93. // Now that we have drained child inputs, send the eoe up.
  94. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
  95. }
  96. }
  97. // 5 handle eof
  98. // propagate eof here.
  99. MS_LOG(INFO) << "Zip operator got EOF, propagating.";
  100. RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
  101. return Status::OK();
  102. }
  103. // Handles preprocessing of the main loop, used when starting new epoch
  104. Status ZipOp::prepare(TensorQTable *const table) {
  105. MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
  106. draining_ = false;
  107. buffer_id_ = 0;
  108. if (table == nullptr) {
  109. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table.");
  110. }
  111. // fill initial row
  112. TensorRow new_row;
  113. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  114. // If the first row fetching resulted in eof, then we are done.
  115. if (eof_) {
  116. return Status::OK();
  117. }
  118. if (new_row.empty()) {
  119. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
  120. }
  121. // Pack this first row into our tensor table
  122. table->push_back(std::move(new_row));
  123. // At this point we have at least 1 row produced, so all child iterators have their column names such that we
  124. // can produce our column name map now.
  125. col_name_id_map_ = {};
  126. for (int32_t i = 0; i < children_num_; ++i) {
  127. // Initializing col_name_id_map_ from the first data buffer.
  128. const std::unordered_map<std::string, int32_t> col_name_id_map = child_iterators_[i]->col_name_id_map();
  129. int32_t colsCurrent = col_name_id_map_.size();
  130. // the update code below shouldn't do anything bad if the column name already exists.
  131. for (const auto &pair : col_name_id_map) {
  132. std::string name = pair.first;
  133. int32_t old_id = pair.second;
  134. // check if name already exists in column name descriptor
  135. if (col_name_id_map_.count(name) == 1) {
  136. RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets");
  137. }
  138. col_name_id_map_[name] = old_id + colsCurrent;
  139. }
  140. }
  141. return Status::OK();
  142. }
  143. // fillBuffer always expects a new table to fill
  144. Status ZipOp::fillBuffer(TensorQTable *const table) {
  145. if (table == nullptr) {
  146. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer.");
  147. }
  148. TensorRow new_row;
  149. while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
  150. RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
  151. // Early exit the loop if we got empty row from any of our child iterations
  152. if (new_row.empty()) {
  153. return Status::OK();
  154. }
  155. // else we got a row so pack it into the tensor table.
  156. table->push_back(std::move(new_row));
  157. }
  158. return Status::OK();
  159. }
  160. // fetches next zip buffer row (merged row)
  161. Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
  162. // iterate over all iterators and generate a row
  163. for (int32_t i = 0; i < children_num_; ++i) {
  164. TensorRow new_row = {};
  165. RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
  166. // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
  167. if (new_row.empty()) {
  168. // If we did not get a row from any of the children, then it's the end of an epoch and we can move
  169. // to drain state.
  170. MS_LOG(INFO) << "Zip operator child iterator produced empty row.";
  171. draining_ = true;
  172. new_zip_row->clear();
  173. // If we picked up an eof here, then we are completely done.
  174. if ((child_iterators_[i])->eof_handled()) {
  175. MS_LOG(INFO) << "Zip operator iterator got EOF.";
  176. eof_ = true;
  177. }
  178. return Status::OK();
  179. } else {
  180. MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
  181. // if row isn't empty then we can append the fetched row with new_zip_row
  182. new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
  183. }
  184. }
  185. MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
  186. return Status::OK();
  187. }
  188. // drain end of epoch messages from iterator for this epoch
  189. Status ZipOp::drainPipeline() {
  190. // we don't need to drain if we reached eof
  191. if (eof_) {
  192. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
  193. "ZipOp draining should not be done if already at eof!");
  194. }
  195. for (int32_t con = 0; con < children_num_; ++con) {
  196. MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
  197. RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
  198. }
  199. // at this point all connectors don't contain end of epoch messages. next iteration should be clean
  200. return Status::OK();
  201. }
  202. // A function that prints info about the Operator
  203. void ZipOp::Print(std::ostream &out, // In: The output stream to print to
  204. bool show_all) const { // In: T/F if it should print everything
  205. // Call base class printer first
  206. PipelineOp::Print(out, show_all);
  207. out << "\nZipOp:\n"
  208. << "\nDatasets: " << children_num_ << "\n\n";
  209. }
  210. // overwrite function and handle eof
  211. Status ZipOp::EofReceived(int32_t) {
  212. MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now.";
  213. return Status::OK();
  214. }
  215. // overwrite function and handle eoe
  216. Status ZipOp::EoeReceived(int32_t) {
  217. state_ = OpState::kDeOpIdle;
  218. return Status::OK();
  219. }
  220. } // namespace dataset
  221. } // namespace mindspore