You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_iterator.cc 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/dataset_iterator.h"
  17. #include <utility>
  18. #include "dataset/core/data_type.h"
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/core/tensor_shape.h"
  21. #include "dataset/engine/data_buffer.h"
  22. #include "dataset/engine/execution_tree.h"
  23. #include "dataset/util/status.h"
  24. #include "dataset/engine/datasetops/dataset_op.h"
  25. namespace mindspore {
  26. namespace dataset {
  27. // Constructor of the IteratorBase
  28. IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {}
  29. IteratorBase::~IteratorBase() = default;
  30. // Fetches one row of data from the iterator as a column map.
  31. Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
  32. if (out_map == nullptr) {
  33. RETURN_STATUS_UNEXPECTED("Null output map in iterator!");
  34. }
  35. out_map->clear();
  36. TensorRow curr_row;
  37. RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
  38. // Return empty map if there's no data
  39. if (curr_row.empty()) {
  40. return Status::OK();
  41. }
  42. // Populate the out map from the row and return it
  43. for (auto colMap : col_name_id_map_) {
  44. (*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
  45. }
  46. return Status::OK();
  47. }
  48. // Fetches one row of data from the iterator.
  49. // The base class version simply performs error handling and returns empty row. Actual
  50. // functionality exists in the derived versions of this function.
  51. Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
  52. if (out_row == nullptr) {
  53. RETURN_STATUS_UNEXPECTED("Null output row in iterator!");
  54. }
  55. // clear the old tensor row
  56. out_row->clear();
  57. return Status::OK();
  58. }
  59. // Constructor of the DatasetIterator
  60. DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree) : IteratorBase(), root_(exe_tree->root()) {}
  61. DatasetIterator::~DatasetIterator() = default;
  62. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  63. // from the tree root node directly.
  64. Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
  65. // Common code init and error checking in the base class.
  66. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  67. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  68. // want to iterate again.
  69. if (eof_handled_) {
  70. return Status::OK();
  71. }
  72. // Check if we need to get a new DataBuffer to iterate.
  73. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  74. col_name_id_map_.clear();
  75. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  76. // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
  77. // handle eoe and eof messages here.
  78. //
  79. // An eoe buffer means we have iterated fully to the end of the tree.
  80. // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
  81. // all operators.
  82. if (curr_buffer_->eoe()) {
  83. MS_LOG(INFO) << "End of data iteration. Fetch eof and then return empty row.";
  84. // Before returning the last empty vector, fetch the eof buffer which should be the last
  85. // buffer, and then free it.
  86. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  87. if (!curr_buffer_->eof()) {
  88. RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
  89. }
  90. eof_handled_ = true;
  91. curr_buffer_.reset(); // explicitly free the eof buffer
  92. return Status::OK();
  93. }
  94. if (curr_buffer_->eof()) {
  95. // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
  96. // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
  97. // flow of an eof up the pipeline by itself.
  98. eof_handled_ = true;
  99. curr_buffer_.reset(); // explicitly free the eof buffer
  100. return Status::OK();
  101. }
  102. col_name_id_map_ = curr_buffer_->column_name_map();
  103. }
  104. // If we got this far, now it's time to pop that next row for return to caller
  105. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  106. return Status::OK();
  107. }
  108. Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
  109. if (out_shapes == nullptr) {
  110. RETURN_STATUS_UNEXPECTED("Null output shape argument");
  111. }
  112. if (device_queue_row_.empty()) {
  113. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  114. }
  115. for (auto ts : device_queue_row_) {
  116. out_shapes->push_back(ts->shape());
  117. }
  118. return Status::OK();
  119. }
  120. Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
  121. if (out_types == nullptr) {
  122. RETURN_STATUS_UNEXPECTED("Null output type argument");
  123. }
  124. if (device_queue_row_.empty()) {
  125. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  126. }
  127. for (auto ts : device_queue_row_) {
  128. out_types->push_back(ts->type());
  129. }
  130. return Status::OK();
  131. }
  132. // Constructor of the ChildIterator
  133. ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
  134. : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
  135. ChildIterator::~ChildIterator() { current_op_ = nullptr; }
  136. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  137. // only from the child/worker id as given from the constructor.
  138. Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
  139. // Common code init and error checking in the base class.
  140. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  141. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  142. // want to iterate again.
  143. if (eof_handled_) {
  144. return Status::OK();
  145. }
  146. // Check if we need to get a new DataBuffer to iterate.
  147. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  148. col_name_id_map_.clear();
  149. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  150. // Unlike the DatasetIterator, this child iterator does not quit after eoe.
  151. // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
  152. // caller to decide what it wants to do next.
  153. if (curr_buffer_->eoe()) {
  154. MS_LOG(INFO) << "Child iterator picked up EOE.";
  155. end_epoch_ = true;
  156. return Status::OK();
  157. }
  158. if (curr_buffer_->eof()) {
  159. MS_LOG(INFO) << "Child iterator picked up EOF.";
  160. eof_handled_ = true;
  161. return Status::OK();
  162. }
  163. col_name_id_map_ = curr_buffer_->column_name_map();
  164. }
  165. // If we got this far, now it's time to pop that next row for return to caller
  166. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  167. return Status::OK();
  168. }
  169. // drain till the next eoe
  170. Status ChildIterator::Drain() {
  171. if (end_epoch_ == true) {
  172. // Calling drain against a child that is already at it's eoe state will not result in any action.
  173. // This allows you to do:
  174. // - fetch until empty row
  175. // - drain (will not actually drain because you are already at the end of the iteration)
  176. // However, the next time after that, it will perform it's normal draining activities.
  177. end_epoch_ = false;
  178. MS_LOG(INFO) << "No operation drain, already at end of epoch.";
  179. return Status::OK();
  180. }
  181. MS_LOG(INFO) << "Child draining buffers until eoe.";
  182. // else we drain until eoe or eof, eof here is for sanity check
  183. while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
  184. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  185. }
  186. if (curr_buffer_->eof()) {
  187. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
  188. }
  189. return Status::OK();
  190. }
  191. } // namespace dataset
  192. } // namespace mindspore