You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

dataset_iterator.h 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASET_ITERATOR_H_
  17. #define DATASET_ENGINE_DATASET_ITERATOR_H_
  18. #include <memory>
  19. #include <string>
  20. #include <unordered_map>
  21. #include <vector>
  22. #include "dataset/util/status.h"
  23. #include "dataset/core/tensor.h"
  24. #include "dataset/engine/datasetops/dataset_op.h"
  25. #include "dataset/engine/execution_tree.h"
  26. #include "dataset/engine/perf/dataset_iterator_tracing.h"
  27. namespace mindspore {
  28. namespace dataset {
  29. using TensorMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;
  30. // forward declare
  31. class ExecutionTree;
  32. class DataBuffer;
  33. // IteratorBase class is used to iterate data from an executionTree one row at a time.
  34. // The base class provides the general interface, whereas derived classes provide slightly
  35. // different implementations.
  36. class IteratorBase {
  37. public:
  38. // Constructor of IteratorBase
  39. IteratorBase();
  40. // Destructor
  41. virtual ~IteratorBase();
  42. // Fetches one row of data from the iterator.
  43. // the base class version simply performs error handling and returns empty row. Actual
  44. // functionality exists in the derived versions of this function.
  45. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  46. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  47. // @return Status - The error code return
  48. // @note The position of a Tensor/column might be different from the initial column order
  49. // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
  50. // the column ordering.
  51. virtual Status FetchNextTensorRow(TensorRow *out_row);
  52. // Fetches one row of data from the iterator as a column map.
  53. // @return A unordered map from column name to shared pointer to Tensor.
  54. Status GetNextAsMap(TensorMap *out_map);
  55. // Getter
  56. // @return T/F if this iterator is completely done after getting an eof
  57. bool eof_handled() const { return eof_handled_; }
  58. // Getter
  59. // @return The string to column id mapping.
  60. virtual std::unordered_map<std::string, int32_t> GetColumnNameMap() const = 0;
  61. protected:
  62. std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer
  63. bool eof_handled_; // T/F if this op got an eof
  64. bool first_row_; // internal tracking for first row case
  65. std::unordered_map<std::string, int32_t> col_name_id_map_;
  66. };
  67. // The DatasetIterator derived class is for fetching rows off the end/root of the execution tree.
  68. class DatasetIterator : public IteratorBase {
  69. public:
  70. // Constructor of the DatasetIterator
  71. // @param exe_tree The execution tree we want to pull/iterate the data from using it's root node.
  72. explicit DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree);
  73. // Destructor
  74. ~DatasetIterator();
  75. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  76. // from the tree root node directly.
  77. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  78. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  79. // @return Status - The error code return
  80. Status FetchNextTensorRow(TensorRow *out_row) override;
  81. // Fetches the next tensor row into device row, and returns it's shape.
  82. // @param out_shapes - A vector of tensor shapes (one shape per column)
  83. // @return Status - The error code return
  84. Status GetOutputShapes(std::vector<TensorShape> *out_shapes);
  85. // Fetches the next tensor row into device row, and returns it's shape.
  86. // @param outShapes - A vector of tensor shapes (one shape per column)
  87. // @return Status - The error code return
  88. Status GetOutputTypes(std::vector<DataType> *out_types);
  89. // Getter
  90. // @return The string to column id mapping.
  91. std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
  92. private:
  93. std::shared_ptr<DatasetOp> root_; // saves the root of the executionTree
  94. TensorRow device_queue_row_;
  95. std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data
  96. int32_t cur_batch_num_; // current batch number,used for profiling
  97. int32_t cur_connector_size_; // current connector size of root op,used for profiling
  98. int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling
  99. };
  100. // The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree.
  101. // This one should only be used by internal Dataset operators, rather than an end-user.
  102. class ChildIterator : public IteratorBase {
  103. public:
  104. // Constructor of the DatasetIterator
  105. // @param current_op - The parent op from which we'll fetch from it's children.
  106. // @param worker_id - The worker id to use when fetching from the children.
  107. // @param child_idx - The index to the child to fetch from.
  108. ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx);
  109. // Destructor
  110. ~ChildIterator();
  111. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  112. // only from the child/worker id as given from the constructor.
  113. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  114. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  115. // @return Status - The error code return
  116. Status FetchNextTensorRow(TensorRow *out_row) override;
  117. // This function drains buffer until next eoe has been received.
  118. // It will be a no-op if the previous row returned is empty.
  119. // @return Status - The error code return
  120. Status Drain();
  121. // Getter
  122. // @return The string to column id mapping.
  123. std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
  124. private:
  125. DatasetOp *current_op_; // The parent operator. We consume from it's children.
  126. int32_t child_idx_; // The specific child this iterator will fetch from.
  127. int32_t worker_id_; // The worker id uses for fetching the child data.
  128. bool end_epoch_; // the flag used when an empty row has been returned.
  129. };
  130. } // namespace dataset
  131. } // namespace mindspore
  132. #endif // DATASET_ENGINE_DATASET_ITERATOR_H_