You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset_iterator.h 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASET_ITERATOR_H_
  17. #define DATASET_ENGINE_DATASET_ITERATOR_H_
  18. #include <memory>
  19. #include <string>
  20. #include <unordered_map>
  21. #include <vector>
  22. #include "dataset/util/status.h"
  23. #include "dataset/core/tensor.h"
  24. #include "dataset/engine/datasetops/dataset_op.h"
  25. #include "dataset/engine/execution_tree.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. using TensorMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;
  29. // forward declare
  30. class ExecutionTree;
  31. class DataBuffer;
  32. // IteratorBase class is used to iterate data from an executionTree one row at a time.
  33. // The base class provides the general interface, whereas derived classes provide slightly
  34. // different implementations.
  35. class IteratorBase {
  36. public:
  37. // Constructor of IteratorBase
  38. IteratorBase();
  39. // Destructor
  40. virtual ~IteratorBase();
  41. // Fetches one row of data from the iterator.
  42. // the base class version simply performs error handling and returns empty row. Actual
  43. // functionality exists in the derived versions of this function.
  44. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  45. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  46. // @return Status - The error code return
  47. // @note The position of a Tensor/column might be different from the initial column order
  48. // in the storageOp. User must be aware that MapOp, ZipOps, and others might change
  49. // the column ordering.
  50. virtual Status FetchNextTensorRow(TensorRow *out_row);
  51. // Fetches one row of data from the iterator as a column map.
  52. // @return A unordered map from column name to shared pointer to Tensor.
  53. Status GetNextAsMap(TensorMap *out_map);
  54. // Getter
  55. // @return T/F if this iterator is completely done after getting an eof
  56. bool eof_handled() const { return eof_handled_; }
  57. // Getter
  58. // @return The string to column id mapping.
  59. virtual std::unordered_map<std::string, int32_t> GetColumnNameMap() const = 0;
  60. protected:
  61. std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer
  62. bool eof_handled_; // T/F if this op got an eof
  63. bool first_row_; // internal tracking for first row case
  64. std::unordered_map<std::string, int32_t> col_name_id_map_;
  65. };
  66. // The DatasetIterator derived class is for fetching rows off the end/root of the execution tree.
  67. class DatasetIterator : public IteratorBase {
  68. public:
  69. // Constructor of the DatasetIterator
  70. // @param exe_tree The execution tree we want to pull/iterate the data from using it's root node.
  71. explicit DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree);
  72. // Destructor
  73. ~DatasetIterator();
  74. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  75. // from the tree root node directly.
  76. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  77. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  78. // @return Status - The error code return
  79. Status FetchNextTensorRow(TensorRow *out_row) override;
  80. // Fetches the next tensor row into device row, and returns it's shape.
  81. // @param out_shapes - A vector of tensor shapes (one shape per column)
  82. // @return Status - The error code return
  83. Status GetOutputShapes(std::vector<TensorShape> *out_shapes);
  84. // Fetches the next tensor row into device row, and returns it's shape.
  85. // @param outShapes - A vector of tensor shapes (one shape per column)
  86. // @return Status - The error code return
  87. Status GetOutputTypes(std::vector<DataType> *out_types);
  88. // Getter
  89. // @return The string to column id mapping.
  90. std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
  91. private:
  92. std::shared_ptr<DatasetOp> root_; // saves the root of the executionTree
  93. TensorRow device_queue_row_;
  94. };
  95. // The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree.
  96. // This one should only be used by internal Dataset operators, rather than an end-user.
  97. class ChildIterator : public IteratorBase {
  98. public:
  99. // Constructor of the DatasetIterator
  100. // @param current_op - The parent op from which we'll fetch from it's children.
  101. // @param worker_id - The worker id to use when fetching from the children.
  102. // @param child_idx - The index to the child to fetch from.
  103. ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx);
  104. // Destructor
  105. ~ChildIterator();
  106. // Fetches one row of data from the iterator. Overrides the base class. This one fetches
  107. // only from the child/worker id as given from the constructor.
  108. // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
  109. // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
  110. // @return Status - The error code return
  111. Status FetchNextTensorRow(TensorRow *out_row) override;
  112. // This function drains buffer until next eoe has been received.
  113. // It will be a no-op if the previous row returned is empty.
  114. // @return Status - The error code return
  115. Status Drain();
  116. // Getter
  117. // @return The string to column id mapping.
  118. std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
  119. private:
  120. DatasetOp *current_op_; // The parent operator. We consume from it's children.
  121. int32_t child_idx_; // The specific child this iterator will fetch from.
  122. int32_t worker_id_; // The worker id uses for fetching the child data.
  123. bool end_epoch_; // the flag used when an empty row has been returned.
  124. };
  125. } // namespace dataset
  126. } // namespace mindspore
  127. #endif // DATASET_ENGINE_DATASET_ITERATOR_H_