You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_iterator.cc 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/dataset_iterator.h"
  17. #include <unordered_map>
  18. #include <utility>
  19. #include "dataset/core/data_type.h"
  20. #include "dataset/core/tensor.h"
  21. #include "dataset/core/tensor_shape.h"
  22. #include "dataset/engine/data_buffer.h"
  23. #include "dataset/engine/execution_tree.h"
  24. #include "dataset/util/status.h"
  25. #include "dataset/engine/datasetops/dataset_op.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Constructor of the IteratorBase
  29. IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {}
  30. IteratorBase::~IteratorBase() = default;
  31. // Fetches one row of data from the iterator as a column map.
  32. Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
  33. if (out_map == nullptr) {
  34. RETURN_STATUS_UNEXPECTED("Null output map in iterator!");
  35. }
  36. out_map->clear();
  37. TensorRow curr_row;
  38. RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
  39. // Return empty map if there's no data
  40. if (curr_row.empty()) {
  41. return Status::OK();
  42. }
  43. // The column name mapping is needed to be able to produce the tensor map output.
  44. // The column name mapping comes from the source operator that is producing the data into the iterator.
  45. // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping
  46. // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping.
  47. if (col_name_id_map_.empty()) {
  48. // Determine the column name map by calling the derived class method to retrieve the column
  49. // name map
  50. col_name_id_map_ = this->GetColumnNameMap();
  51. }
  52. // Populate the out map from the row and return it
  53. for (auto colMap : col_name_id_map_) {
  54. (*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
  55. }
  56. return Status::OK();
  57. }
  58. // Fetches one row of data from the iterator.
  59. // The base class version simply performs error handling and returns empty row. Actual
  60. // functionality exists in the derived versions of this function.
  61. Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
  62. if (out_row == nullptr) {
  63. RETURN_STATUS_UNEXPECTED("Null output row in iterator!");
  64. }
  65. // clear the old tensor row
  66. out_row->clear();
  67. return Status::OK();
  68. }
  69. // Constructor of the DatasetIterator
  70. DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree)
  71. : IteratorBase(),
  72. root_(exe_tree->root()),
  73. tracing_(nullptr),
  74. cur_batch_num_(0),
  75. cur_connector_size_(0),
  76. cur_connector_capacity_(0) {
  77. std::shared_ptr<Tracing> node;
  78. Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
  79. if (s.IsOk()) {
  80. tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
  81. }
  82. }
  83. DatasetIterator::~DatasetIterator() = default;
  84. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  85. // from the tree root node directly.
  86. Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
  87. // Common code init and error checking in the base class.
  88. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  89. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  90. // want to iterate again.
  91. if (eof_handled_) {
  92. return Status::OK();
  93. }
  94. // Check if we need to get a new DataBuffer to iterate.
  95. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  96. if (tracing_ != nullptr) {
  97. cur_connector_size_ = root_->ConnectorSize();
  98. cur_connector_capacity_ = root_->ConnectorCapacity();
  99. }
  100. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  101. // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
  102. // handle eoe and eof messages here.
  103. //
  104. // An eoe buffer means we have iterated fully to the end of the tree.
  105. // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
  106. // all operators.
  107. if (curr_buffer_->eoe()) {
  108. MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
  109. // Before returning the last empty vector, fetch the eof buffer which should be the last
  110. // buffer, and then free it.
  111. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  112. if (!curr_buffer_->eof()) {
  113. RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
  114. }
  115. eof_handled_ = true;
  116. curr_buffer_.reset(); // explicitly free the eof buffer
  117. // Set tree to Finished state
  118. root_->Tree()->SetFinished();
  119. return Status::OK();
  120. }
  121. if (curr_buffer_->eof()) {
  122. // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
  123. // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
  124. // flow of an eof up the pipeline by itself.
  125. eof_handled_ = true;
  126. curr_buffer_.reset(); // explicitly free the eof buffer
  127. // Set tree to Finished state
  128. root_->Tree()->SetFinished();
  129. return Status::OK();
  130. }
  131. }
  132. // If we got this far, now it's time to pop that next row for return to caller
  133. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  134. if (tracing_ != nullptr) {
  135. cur_batch_num_++;
  136. tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
  137. }
  138. return Status::OK();
  139. }
  140. Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
  141. if (out_shapes == nullptr) {
  142. RETURN_STATUS_UNEXPECTED("Null output shape argument");
  143. }
  144. if (device_queue_row_.empty()) {
  145. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  146. }
  147. for (auto ts : device_queue_row_) {
  148. out_shapes->push_back(ts->shape());
  149. }
  150. return Status::OK();
  151. }
  152. Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
  153. if (out_types == nullptr) {
  154. RETURN_STATUS_UNEXPECTED("Null output type argument");
  155. }
  156. if (device_queue_row_.empty()) {
  157. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  158. }
  159. for (auto ts : device_queue_row_) {
  160. out_types->push_back(ts->type());
  161. }
  162. return Status::OK();
  163. }
  164. // Getter
  165. std::unordered_map<std::string, int32_t> DatasetIterator::GetColumnNameMap() const {
  166. return root_->column_name_id_map();
  167. }
  168. // Constructor of the ChildIterator
  169. ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
  170. : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
  171. ChildIterator::~ChildIterator() { current_op_ = nullptr; }
  172. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  173. // only from the child/worker id as given from the constructor.
  174. Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
  175. // Common code init and error checking in the base class.
  176. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  177. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  178. // want to iterate again.
  179. if (eof_handled_) {
  180. return Status::OK();
  181. }
  182. // Check if we need to get a new DataBuffer to iterate.
  183. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  184. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  185. // Unlike the DatasetIterator, this child iterator does not quit after eoe.
  186. // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
  187. // caller to decide what it wants to do next.
  188. if (curr_buffer_->eoe()) {
  189. MS_LOG(DEBUG) << "Child iterator picked up EOE.";
  190. end_epoch_ = true;
  191. return Status::OK();
  192. }
  193. if (curr_buffer_->eof()) {
  194. MS_LOG(DEBUG) << "Child iterator picked up EOF.";
  195. eof_handled_ = true;
  196. return Status::OK();
  197. }
  198. }
  199. // If we got this far, now it's time to pop that next row for return to caller
  200. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  201. return Status::OK();
  202. }
  203. // drain till the next eoe
  204. Status ChildIterator::Drain() {
  205. if (end_epoch_ == true) {
  206. // Calling drain against a child that is already at it's eoe state will not result in any action.
  207. // This allows you to do:
  208. // - fetch until empty row
  209. // - drain (will not actually drain because you are already at the end of the iteration)
  210. // However, the next time after that, it will perform it's normal draining activities.
  211. end_epoch_ = false;
  212. MS_LOG(DEBUG) << "No operation drain, already at end of epoch.";
  213. return Status::OK();
  214. }
  215. MS_LOG(DEBUG) << "Child draining buffers until eoe.";
  216. // else we drain until eoe or eof, eof here is for sanity check
  217. while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
  218. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  219. }
  220. if (curr_buffer_->eof()) {
  221. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
  222. }
  223. return Status::OK();
  224. }
  225. // Getter
  226. std::unordered_map<std::string, int32_t> ChildIterator::GetColumnNameMap() const {
  227. return current_op_->child(child_idx_)->column_name_id_map();
  228. }
  229. } // namespace dataset
  230. } // namespace mindspore