You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

iterator.h 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <vector>
  23. #include "include/api/dual_abi_helper.h"
  24. #include "include/api/status.h"
  25. #include "include/api/types.h"
  26. namespace mindspore {
  27. namespace dataset {
  28. // Forward declare
  29. class ExecutionTree;
  30. class DatasetIterator;
  31. class DatasetOp;
  32. class Tensor;
  33. class NativeRuntimeContext;
  34. class IteratorConsumer;
  35. class PullBasedIteratorConsumer;
  36. class Dataset;
  37. using MSTensorMap = std::unordered_map<std::string, mindspore::MSTensor>;
  38. using MSTensorMapChar = std::map<std::vector<char>, mindspore::MSTensor>;
  39. using MSTensorVec = std::vector<mindspore::MSTensor>;
  40. // Abstract class for iterating over the dataset.
  41. class Iterator {
  42. public:
  43. /// \brief Constructor
  44. Iterator();
  45. /// \brief Destructor
  46. ~Iterator();
  47. /// \brief Method for building and launching the pipeline.
  48. /// \param[in] ops - a vector of DatasetOp in the data pipeline.
  49. /// \param[in] num_epochs Number of epochs passed down to EpochCtrlNode, default -1, infinite epochs
  50. /// \return - a Status error code, returns OK if no error encountered.
  51. Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds, int32_t num_epochs);
  52. /// \brief Function to get the next row from the data pipeline.
  53. /// \note Type of return data is a map(with column name).
  54. /// \param[out] row - the output tensor row.
  55. /// \return - a Status error code, returns OK if no error encountered.
  56. Status GetNextRow(MSTensorMap *row) {
  57. MSTensorMapChar row_;
  58. row_.clear();
  59. row->clear();
  60. Status s = GetNextRowCharIF(&row_);
  61. TensorMapCharToString(&row_, row);
  62. return s;
  63. }
  64. // Char interface(CharIF) of GetNextRow
  65. // This api exists because std::string will constrained by ABI compile macro but char don't.
  66. Status GetNextRowCharIF(MSTensorMapChar *row);
  67. /// \brief Function to get the next row from the data pipeline.
  68. /// \note Type of return data is a vector(without column name).
  69. /// \param[out] row - the output tensor row.
  70. /// \return - a Status error code, returns OK if no error encountered.
  71. virtual Status GetNextRow(MSTensorVec *row);
  72. /// \brief Function to shut down the data pipeline.
  73. void Stop();
  74. class _Iterator {
  75. public:
  76. explicit _Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr} {
  77. if (lt_) {
  78. cur_row_ = new MSTensorMap();
  79. lt_->GetNextRow(cur_row_);
  80. }
  81. }
  82. // Destructor
  83. ~_Iterator() {
  84. if (cur_row_) {
  85. delete cur_row_;
  86. }
  87. }
  88. _Iterator &operator++() {
  89. if (lt_) {
  90. ++ind_;
  91. lt_->GetNextRow(cur_row_);
  92. }
  93. if (cur_row_ && cur_row_->size() == 0) {
  94. delete cur_row_;
  95. cur_row_ = nullptr;
  96. }
  97. return *this;
  98. } // prefix ++ overload
  99. MSTensorMap &operator*() { return *cur_row_; } // dereference operator
  100. MSTensorMap *operator->() { return cur_row_; }
  101. bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; }
  102. private:
  103. int ind_; // the cur node our Iterator points to
  104. Iterator *lt_;
  105. MSTensorMap *cur_row_;
  106. };
  107. _Iterator begin() { return _Iterator(this); }
  108. _Iterator end() { return _Iterator(nullptr); }
  109. private:
  110. std::unique_ptr<NativeRuntimeContext> runtime_context_;
  111. IteratorConsumer *consumer_;
  112. };
  113. class PullIterator : public Iterator {
  114. public:
  115. /// \brief Constructor
  116. PullIterator();
  117. /// \brief Function to get next row from the data pipeline.
  118. /// \note Type of return data is a vector(without column name).
  119. /// \param[out] row - the output tensor row.
  120. /// \return Returns true if no error encountered else false.
  121. Status GetNextRow(MSTensorVec *row) override;
  122. /// \brief Function to get specified rows from the data pipeline.
  123. /// \note Type of return data is a vector(without column name).
  124. /// \note This behavior is subject to change
  125. /// \param[in] num_rows - the number of rows to fetch.
  126. /// \param[out] row - the output tensor row.
  127. /// \return Returns true if no error encountered else false.
  128. Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *row);
  129. /// \brief Method for building and launching the pipeline.
  130. /// \note Consider making this function protected.
  131. /// \param[in] ds - The root node that calls the function
  132. /// \return - a Status error code, returns OK if no error encountered.
  133. Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
  134. private:
  135. std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
  136. };
  137. } // namespace dataset
  138. } // namespace mindspore
  139. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_