You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_iterator.cc 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "dataset/engine/dataset_iterator.h"
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include "dataset/core/data_type.h"
  #include "dataset/core/tensor.h"
  #include "dataset/core/tensor_shape.h"
  #include "dataset/engine/data_buffer.h"
  #include "dataset/engine/execution_tree.h"
  #include "dataset/util/status.h"
  #include "dataset/engine/datasetops/dataset_op.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Constructor of the IteratorBase
  29. IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false), first_row_(true) {}
  30. IteratorBase::~IteratorBase() = default;
  31. // Fetches one row of data from the iterator as a column map.
  32. Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
  33. if (out_map == nullptr) {
  34. RETURN_STATUS_UNEXPECTED("Null output map in iterator!");
  35. }
  36. out_map->clear();
  37. TensorRow curr_row;
  38. RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
  39. // Return empty map if there's no data
  40. if (curr_row.empty()) {
  41. return Status::OK();
  42. }
  43. // The column name mapping is needed to be able to produce the tensor map output.
  44. // The column name mapping comes from the source operator that is producing the data into the iterator.
  45. // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping
  46. // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping.
  47. // Note: This can only be done after the first row has been produced, as this guarantees the the child has
  48. // it's column mapping set up.
  49. if (first_row_) {
  50. // Determine the column name map by calling the derived class method to retrieve the column
  51. // name map
  52. col_name_id_map_ = this->GetColumnNameMap();
  53. first_row_ = false;
  54. }
  55. // Populate the out map from the row and return it
  56. for (auto colMap : col_name_id_map_) {
  57. (*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
  58. }
  59. return Status::OK();
  60. }
  61. // Fetches one row of data from the iterator.
  62. // The base class version simply performs error handling and returns empty row. Actual
  63. // functionality exists in the derived versions of this function.
  64. Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
  65. if (out_row == nullptr) {
  66. RETURN_STATUS_UNEXPECTED("Null output row in iterator!");
  67. }
  68. // clear the old tensor row
  69. out_row->clear();
  70. return Status::OK();
  71. }
  72. // Constructor of the DatasetIterator
  73. DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree) : IteratorBase(), root_(exe_tree->root()) {}
  74. DatasetIterator::~DatasetIterator() = default;
  75. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  76. // from the tree root node directly.
  77. Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
  78. // Common code init and error checking in the base class.
  79. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  80. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  81. // want to iterate again.
  82. if (eof_handled_) {
  83. return Status::OK();
  84. }
  85. // Check if we need to get a new DataBuffer to iterate.
  86. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  87. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  88. // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
  89. // handle eoe and eof messages here.
  90. //
  91. // An eoe buffer means we have iterated fully to the end of the tree.
  92. // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
  93. // all operators.
  94. if (curr_buffer_->eoe()) {
  95. MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
  96. // Before returning the last empty vector, fetch the eof buffer which should be the last
  97. // buffer, and then free it.
  98. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  99. if (!curr_buffer_->eof()) {
  100. RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
  101. }
  102. eof_handled_ = true;
  103. curr_buffer_.reset(); // explicitly free the eof buffer
  104. return Status::OK();
  105. }
  106. if (curr_buffer_->eof()) {
  107. // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
  108. // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
  109. // flow of an eof up the pipeline by itself.
  110. eof_handled_ = true;
  111. curr_buffer_.reset(); // explicitly free the eof buffer
  112. return Status::OK();
  113. }
  114. }
  115. // If we got this far, now it's time to pop that next row for return to caller
  116. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  117. return Status::OK();
  118. }
  119. Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
  120. if (out_shapes == nullptr) {
  121. RETURN_STATUS_UNEXPECTED("Null output shape argument");
  122. }
  123. if (device_queue_row_.empty()) {
  124. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  125. }
  126. for (auto ts : device_queue_row_) {
  127. out_shapes->push_back(ts->shape());
  128. }
  129. return Status::OK();
  130. }
  131. Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
  132. if (out_types == nullptr) {
  133. RETURN_STATUS_UNEXPECTED("Null output type argument");
  134. }
  135. if (device_queue_row_.empty()) {
  136. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  137. }
  138. for (auto ts : device_queue_row_) {
  139. out_types->push_back(ts->type());
  140. }
  141. return Status::OK();
  142. }
  143. // Getter
  144. std::unordered_map<std::string, int32_t> DatasetIterator::GetColumnNameMap() const {
  145. return root_->column_name_id_map();
  146. }
  147. // Constructor of the ChildIterator
  148. ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
  149. : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
  150. ChildIterator::~ChildIterator() { current_op_ = nullptr; }
  151. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  152. // only from the child/worker id as given from the constructor.
  153. Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
  154. // Common code init and error checking in the base class.
  155. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  156. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  157. // want to iterate again.
  158. if (eof_handled_) {
  159. return Status::OK();
  160. }
  161. // Check if we need to get a new DataBuffer to iterate.
  162. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  163. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  164. // Unlike the DatasetIterator, this child iterator does not quit after eoe.
  165. // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
  166. // caller to decide what it wants to do next.
  167. if (curr_buffer_->eoe()) {
  168. MS_LOG(DEBUG) << "Child iterator picked up EOE.";
  169. end_epoch_ = true;
  170. return Status::OK();
  171. }
  172. if (curr_buffer_->eof()) {
  173. MS_LOG(DEBUG) << "Child iterator picked up EOF.";
  174. eof_handled_ = true;
  175. return Status::OK();
  176. }
  177. }
  178. // If we got this far, now it's time to pop that next row for return to caller
  179. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  180. return Status::OK();
  181. }
  182. // drain till the next eoe
  183. Status ChildIterator::Drain() {
  184. if (end_epoch_ == true) {
  185. // Calling drain against a child that is already at it's eoe state will not result in any action.
  186. // This allows you to do:
  187. // - fetch until empty row
  188. // - drain (will not actually drain because you are already at the end of the iteration)
  189. // However, the next time after that, it will perform it's normal draining activities.
  190. end_epoch_ = false;
  191. MS_LOG(DEBUG) << "No operation drain, already at end of epoch.";
  192. return Status::OK();
  193. }
  194. MS_LOG(DEBUG) << "Child draining buffers until eoe.";
  195. // else we drain until eoe or eof, eof here is for sanity check
  196. while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
  197. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  198. }
  199. if (curr_buffer_->eof()) {
  200. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
  201. }
  202. return Status::OK();
  203. }
  204. // Getter
  205. std::unordered_map<std::string, int32_t> ChildIterator::GetColumnNameMap() const {
  206. return current_op_->child(child_idx_)->column_name_id_map();
  207. }
  208. } // namespace dataset
  209. } // namespace mindspore