You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_iterator.cc 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "dataset/engine/dataset_iterator.h"

  #include <memory>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  #include "dataset/core/data_type.h"
  #include "dataset/core/tensor.h"
  #include "dataset/core/tensor_shape.h"
  #include "dataset/engine/data_buffer.h"
  #include "dataset/engine/datasetops/dataset_op.h"
  #include "dataset/engine/execution_tree.h"
  #include "dataset/util/status.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Constructor of the IteratorBase
  29. IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false), first_row_(true) {}
  30. IteratorBase::~IteratorBase() = default;
  31. // Fetches one row of data from the iterator as a column map.
  32. Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
  33. if (out_map == nullptr) {
  34. RETURN_STATUS_UNEXPECTED("Null output map in iterator!");
  35. }
  36. out_map->clear();
  37. TensorRow curr_row;
  38. RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
  39. // Return empty map if there's no data
  40. if (curr_row.empty()) {
  41. return Status::OK();
  42. }
  43. // The column name mapping is needed to be able to produce the tensor map output.
  44. // The column name mapping comes from the source operator that is producing the data into the iterator.
  45. // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping
  46. // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping.
  47. // Note: This can only be done after the first row has been produced, as this guarantees the the child has
  48. // it's column mapping set up.
  49. if (first_row_) {
  50. // Determine the column name map by calling the derived class method to retrieve the column
  51. // name map
  52. col_name_id_map_ = this->GetColumnNameMap();
  53. first_row_ = false;
  54. }
  55. // Populate the out map from the row and return it
  56. for (auto colMap : col_name_id_map_) {
  57. (*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
  58. }
  59. return Status::OK();
  60. }
  61. // Fetches one row of data from the iterator.
  62. // The base class version simply performs error handling and returns empty row. Actual
  63. // functionality exists in the derived versions of this function.
  64. Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
  65. if (out_row == nullptr) {
  66. RETURN_STATUS_UNEXPECTED("Null output row in iterator!");
  67. }
  68. // clear the old tensor row
  69. out_row->clear();
  70. return Status::OK();
  71. }
  72. // Constructor of the DatasetIterator
  73. DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree)
  74. : IteratorBase(), root_(exe_tree->root()), tracing_(nullptr), cur_batch_num_(0), cur_connector_size_(0) {
  75. std::shared_ptr<Tracing> node;
  76. Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
  77. if (s.IsOk()) {
  78. tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
  79. }
  80. }
  81. DatasetIterator::~DatasetIterator() = default;
  82. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  83. // from the tree root node directly.
  84. Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
  85. // Common code init and error checking in the base class.
  86. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  87. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  88. // want to iterate again.
  89. if (eof_handled_) {
  90. return Status::OK();
  91. }
  92. // Check if we need to get a new DataBuffer to iterate.
  93. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  94. if (tracing_ != nullptr) {
  95. cur_connector_size_ = root_->ConnectorSize();
  96. cur_connector_capacity_ = root_->ConnectorCapacity();
  97. }
  98. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  99. // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
  100. // handle eoe and eof messages here.
  101. //
  102. // An eoe buffer means we have iterated fully to the end of the tree.
  103. // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
  104. // all operators.
  105. if (curr_buffer_->eoe()) {
  106. MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
  107. // Before returning the last empty vector, fetch the eof buffer which should be the last
  108. // buffer, and then free it.
  109. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  110. if (!curr_buffer_->eof()) {
  111. RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
  112. }
  113. eof_handled_ = true;
  114. curr_buffer_.reset(); // explicitly free the eof buffer
  115. // Set tree to Finished state
  116. root_->Tree()->SetFinished();
  117. return Status::OK();
  118. }
  119. if (curr_buffer_->eof()) {
  120. // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
  121. // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
  122. // flow of an eof up the pipeline by itself.
  123. eof_handled_ = true;
  124. curr_buffer_.reset(); // explicitly free the eof buffer
  125. // Set tree to Finished state
  126. root_->Tree()->SetFinished();
  127. return Status::OK();
  128. }
  129. }
  130. // If we got this far, now it's time to pop that next row for return to caller
  131. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  132. if (tracing_ != nullptr) {
  133. cur_batch_num_++;
  134. tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
  135. }
  136. return Status::OK();
  137. }
  138. Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
  139. if (out_shapes == nullptr) {
  140. RETURN_STATUS_UNEXPECTED("Null output shape argument");
  141. }
  142. if (device_queue_row_.empty()) {
  143. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  144. }
  145. for (auto ts : device_queue_row_) {
  146. out_shapes->push_back(ts->shape());
  147. }
  148. return Status::OK();
  149. }
  150. Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
  151. if (out_types == nullptr) {
  152. RETURN_STATUS_UNEXPECTED("Null output type argument");
  153. }
  154. if (device_queue_row_.empty()) {
  155. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  156. }
  157. for (auto ts : device_queue_row_) {
  158. out_types->push_back(ts->type());
  159. }
  160. return Status::OK();
  161. }
  162. // Getter
  163. std::unordered_map<std::string, int32_t> DatasetIterator::GetColumnNameMap() const {
  164. return root_->column_name_id_map();
  165. }
  166. // Constructor of the ChildIterator
  167. ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
  168. : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
  169. ChildIterator::~ChildIterator() { current_op_ = nullptr; }
  170. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  171. // only from the child/worker id as given from the constructor.
  172. Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
  173. // Common code init and error checking in the base class.
  174. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  175. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  176. // want to iterate again.
  177. if (eof_handled_) {
  178. return Status::OK();
  179. }
  180. // Check if we need to get a new DataBuffer to iterate.
  181. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  182. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  183. // Unlike the DatasetIterator, this child iterator does not quit after eoe.
  184. // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
  185. // caller to decide what it wants to do next.
  186. if (curr_buffer_->eoe()) {
  187. MS_LOG(DEBUG) << "Child iterator picked up EOE.";
  188. end_epoch_ = true;
  189. return Status::OK();
  190. }
  191. if (curr_buffer_->eof()) {
  192. MS_LOG(DEBUG) << "Child iterator picked up EOF.";
  193. eof_handled_ = true;
  194. return Status::OK();
  195. }
  196. }
  197. // If we got this far, now it's time to pop that next row for return to caller
  198. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  199. return Status::OK();
  200. }
  201. // drain till the next eoe
  202. Status ChildIterator::Drain() {
  203. if (end_epoch_ == true) {
  204. // Calling drain against a child that is already at it's eoe state will not result in any action.
  205. // This allows you to do:
  206. // - fetch until empty row
  207. // - drain (will not actually drain because you are already at the end of the iteration)
  208. // However, the next time after that, it will perform it's normal draining activities.
  209. end_epoch_ = false;
  210. MS_LOG(DEBUG) << "No operation drain, already at end of epoch.";
  211. return Status::OK();
  212. }
  213. MS_LOG(DEBUG) << "Child draining buffers until eoe.";
  214. // else we drain until eoe or eof, eof here is for sanity check
  215. while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
  216. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  217. }
  218. if (curr_buffer_->eof()) {
  219. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
  220. }
  221. return Status::OK();
  222. }
  223. // Getter
  224. std::unordered_map<std::string, int32_t> ChildIterator::GetColumnNameMap() const {
  225. return current_op_->child(child_idx_)->column_name_id_map();
  226. }
  227. } // namespace dataset
  228. } // namespace mindspore