You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_iterator.cc 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/engine/dataset_iterator.h"
  17. #include <unordered_map>
  18. #include <utility>
  19. #include "dataset/core/data_type.h"
  20. #include "dataset/core/tensor.h"
  21. #include "dataset/core/tensor_shape.h"
  22. #include "dataset/engine/data_buffer.h"
  23. #include "dataset/engine/execution_tree.h"
  24. #include "dataset/util/status.h"
  25. #include "dataset/engine/datasetops/dataset_op.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Constructor of the IteratorBase
  29. IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false), first_row_(true) {}
  30. IteratorBase::~IteratorBase() = default;
  31. // Fetches one row of data from the iterator as a column map.
  32. Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
  33. if (out_map == nullptr) {
  34. RETURN_STATUS_UNEXPECTED("Null output map in iterator!");
  35. }
  36. out_map->clear();
  37. TensorRow curr_row;
  38. RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
  39. // Return empty map if there's no data
  40. if (curr_row.empty()) {
  41. return Status::OK();
  42. }
  43. // The column name mapping is needed to be able to produce the tensor map output.
  44. // The column name mapping comes from the source operator that is producing the data into the iterator.
  45. // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping
  46. // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping.
  47. // Note: This can only be done after the first row has been produced, as this guarantees the the child has
  48. // it's column mapping set up.
  49. if (first_row_) {
  50. // Determine the column name map by calling the derived class method to retrieve the column
  51. // name map
  52. col_name_id_map_ = this->GetColumnNameMap();
  53. first_row_ = false;
  54. }
  55. // Populate the out map from the row and return it
  56. for (auto colMap : col_name_id_map_) {
  57. (*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
  58. }
  59. return Status::OK();
  60. }
  61. // Fetches one row of data from the iterator.
  62. // The base class version simply performs error handling and returns empty row. Actual
  63. // functionality exists in the derived versions of this function.
  64. Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
  65. if (out_row == nullptr) {
  66. RETURN_STATUS_UNEXPECTED("Null output row in iterator!");
  67. }
  68. // clear the old tensor row
  69. out_row->clear();
  70. return Status::OK();
  71. }
  72. // Constructor of the DatasetIterator
  73. DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree)
  74. : IteratorBase(),
  75. root_(exe_tree->root()),
  76. tracing_(nullptr),
  77. cur_batch_num_(0),
  78. cur_connector_size_(0),
  79. cur_connector_capacity_(0) {
  80. std::shared_ptr<Tracing> node;
  81. Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
  82. if (s.IsOk()) {
  83. tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
  84. }
  85. }
  86. DatasetIterator::~DatasetIterator() = default;
  87. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  88. // from the tree root node directly.
  89. Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
  90. // Common code init and error checking in the base class.
  91. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  92. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  93. // want to iterate again.
  94. if (eof_handled_) {
  95. return Status::OK();
  96. }
  97. // Check if we need to get a new DataBuffer to iterate.
  98. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  99. if (tracing_ != nullptr) {
  100. cur_connector_size_ = root_->ConnectorSize();
  101. cur_connector_capacity_ = root_->ConnectorCapacity();
  102. }
  103. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  104. // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
  105. // handle eoe and eof messages here.
  106. //
  107. // An eoe buffer means we have iterated fully to the end of the tree.
  108. // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
  109. // all operators.
  110. if (curr_buffer_->eoe()) {
  111. MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
  112. // Before returning the last empty vector, fetch the eof buffer which should be the last
  113. // buffer, and then free it.
  114. RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
  115. if (!curr_buffer_->eof()) {
  116. RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
  117. }
  118. eof_handled_ = true;
  119. curr_buffer_.reset(); // explicitly free the eof buffer
  120. // Set tree to Finished state
  121. root_->Tree()->SetFinished();
  122. return Status::OK();
  123. }
  124. if (curr_buffer_->eof()) {
  125. // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
  126. // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
  127. // flow of an eof up the pipeline by itself.
  128. eof_handled_ = true;
  129. curr_buffer_.reset(); // explicitly free the eof buffer
  130. // Set tree to Finished state
  131. root_->Tree()->SetFinished();
  132. return Status::OK();
  133. }
  134. }
  135. // If we got this far, now it's time to pop that next row for return to caller
  136. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  137. if (tracing_ != nullptr) {
  138. cur_batch_num_++;
  139. tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
  140. }
  141. return Status::OK();
  142. }
  143. Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
  144. if (out_shapes == nullptr) {
  145. RETURN_STATUS_UNEXPECTED("Null output shape argument");
  146. }
  147. if (device_queue_row_.empty()) {
  148. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  149. }
  150. for (auto ts : device_queue_row_) {
  151. out_shapes->push_back(ts->shape());
  152. }
  153. return Status::OK();
  154. }
  155. Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
  156. if (out_types == nullptr) {
  157. RETURN_STATUS_UNEXPECTED("Null output type argument");
  158. }
  159. if (device_queue_row_.empty()) {
  160. RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
  161. }
  162. for (auto ts : device_queue_row_) {
  163. out_types->push_back(ts->type());
  164. }
  165. return Status::OK();
  166. }
  167. // Getter
  168. std::unordered_map<std::string, int32_t> DatasetIterator::GetColumnNameMap() const {
  169. return root_->column_name_id_map();
  170. }
  171. // Constructor of the ChildIterator
  172. ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
  173. : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
  174. ChildIterator::~ChildIterator() { current_op_ = nullptr; }
  175. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  176. // only from the child/worker id as given from the constructor.
  177. Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
  178. // Common code init and error checking in the base class.
  179. RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
  180. // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
  181. // want to iterate again.
  182. if (eof_handled_) {
  183. return Status::OK();
  184. }
  185. // Check if we need to get a new DataBuffer to iterate.
  186. if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
  187. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  188. // Unlike the DatasetIterator, this child iterator does not quit after eoe.
  189. // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
  190. // caller to decide what it wants to do next.
  191. if (curr_buffer_->eoe()) {
  192. MS_LOG(DEBUG) << "Child iterator picked up EOE.";
  193. end_epoch_ = true;
  194. return Status::OK();
  195. }
  196. if (curr_buffer_->eof()) {
  197. MS_LOG(DEBUG) << "Child iterator picked up EOF.";
  198. eof_handled_ = true;
  199. return Status::OK();
  200. }
  201. }
  202. // If we got this far, now it's time to pop that next row for return to caller
  203. RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
  204. return Status::OK();
  205. }
  206. // drain till the next eoe
  207. Status ChildIterator::Drain() {
  208. if (end_epoch_ == true) {
  209. // Calling drain against a child that is already at it's eoe state will not result in any action.
  210. // This allows you to do:
  211. // - fetch until empty row
  212. // - drain (will not actually drain because you are already at the end of the iteration)
  213. // However, the next time after that, it will perform it's normal draining activities.
  214. end_epoch_ = false;
  215. MS_LOG(DEBUG) << "No operation drain, already at end of epoch.";
  216. return Status::OK();
  217. }
  218. MS_LOG(DEBUG) << "Child draining buffers until eoe.";
  219. // else we drain until eoe or eof, eof here is for sanity check
  220. while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
  221. RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
  222. }
  223. if (curr_buffer_->eof()) {
  224. return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
  225. }
  226. return Status::OK();
  227. }
  228. // Getter
  229. std::unordered_map<std::string, int32_t> ChildIterator::GetColumnNameMap() const {
  230. return current_op_->child(child_idx_)->column_name_id_map();
  231. }
  232. } // namespace dataset
  233. } // namespace mindspore