You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

shuffle_op.h 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
  17. #define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
  18. #include <map>
  19. #include <memory>
  20. #include <queue>
  21. #include <random>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <vector>
  25. #include "dataset/core/tensor.h"
  26. #include "dataset/core/tensor_shape.h"
  27. #include "dataset/engine/dataset_iterator.h"
  28. #include "dataset/engine/datasetops/pipeline_op.h"
  29. #include "dataset/util/status.h"
  30. namespace mindspore {
  31. namespace dataset {
  32. // Forward declare
  33. class ExecutionTree;
  34. class DbConnector;
  35. class DataBuffer;
  36. class ShuffleOp : public PipelineOp {
  37. // Shuffle buffer state flags
  38. //
  39. // Shuffle buffer is in a state of being initialized
  40. static constexpr int32_t kShuffleStateInit = 0;
  41. // Shuffle buffer is in a state of being actively drained from, but refilling as well
  42. static constexpr int32_t kShuffleStateActive = 1;
  43. // Shuffle buffer is in a state of being drained
  44. static constexpr int32_t kShuffleStateDrain = 2;
  45. public:
  46. // The nested builder class inside of the ShuffleOp is used to help manage all of the arguments
  47. // for constructing it. The shuffle op is fairly simple though, but the builder provides a
  48. // consistent look and feel for creators of Dataset operators overall.
  49. class Builder {
  50. public:
  51. // Builder constructor. Creates the builder object.
  52. // @note No default args
  53. // @return This is a constructor.
  54. Builder();
  55. // Default destructor
  56. ~Builder() = default;
  57. // Setter method.
  58. // @return Builder setter method returns reference to the builder.
  59. Builder &SetShuffleSize(int32_t shuffle_size) {
  60. build_shuffle_size_ = shuffle_size;
  61. return *this;
  62. }
  63. // Setter method.
  64. // @return Builder setter method returns reference to the builder.
  65. Builder &SetShuffleSeed(uint32_t shuffle_seed) {
  66. build_shuffle_seed_ = shuffle_seed;
  67. return *this;
  68. }
  69. // Setter method.
  70. // @return Builder setter method returns reference to the builder.
  71. Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
  72. build_rows_per_buffer_ = rows_per_buffer;
  73. return *this;
  74. }
  75. // Setter method.
  76. // @return Builder setter method returns reference to the builder.
  77. Builder &SetReshuffleEachEpoch(bool reshuffle_each_epoch) {
  78. build_reshuffle_each_epoch_ = reshuffle_each_epoch;
  79. return *this;
  80. }
  81. // Setter method.
  82. // @return Builder setter method returns reference to the builder.
  83. Builder &SetOpConnectorSize(int32_t op_connector_size) {
  84. build_op_connector_size_ = op_connector_size;
  85. return *this;
  86. }
  87. // The builder "build" method creates the final object.
  88. // @return shared_ptr to the new ShuffleOp object
  89. Status Build(std::shared_ptr<ShuffleOp> *);
  90. private:
  91. // The builder saves all ShuffleOp construction arguments internally.
  92. // The following are the arguments.
  93. int32_t build_shuffle_size_;
  94. uint32_t build_shuffle_seed_;
  95. int32_t build_rows_per_buffer_;
  96. bool build_reshuffle_each_epoch_;
  97. int32_t build_op_connector_size_;
  98. Status SanityCheck() const;
  99. };
  100. // Constructor of the ShuffleOp
  101. // @note The builder class should be used to call it
  102. // @param shuffle_size - The size for the shuffle buffer
  103. // @param shuffle_seed - The seed to use for random number generation
  104. // @param op_connector_size - The output connector queue size
  105. // @param rows_per_buffer - The requested number of rows per buffer
  106. ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
  107. int32_t rows_per_buffer);
  108. // Destructor
  109. ~ShuffleOp() = default;
  110. // A print method typically used for debugging
  111. // @param out - The output stream to write output to
  112. // @param show_all - A bool to control if you want to show all info or just a summary
  113. void Print(std::ostream &out, bool show_all) const override;
  114. // << Stream output operator overload
  115. // @notes This allows you to write the debug print info using stream operators
  116. // @param out - reference to the output stream being overloaded
  117. // @param so - reference to the ShuffleOp to display
  118. // @return - the output stream must be returned
  119. friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) {
  120. so.Print(out, false);
  121. return out;
  122. }
  123. // Class functor operator () override.
  124. // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  125. // provide the master loop that drives the logic for performing the work
  126. // @return Status - The error code return
  127. Status operator()() override;
  128. // Base-class override for special eoe handler.
  129. // ShuffleOp must override this because it shall not perform default handling of eoe. Instead
  130. // the ShuffleOp needs to manage actions related to the end of the epoch itself.
  131. // @return Status - The error code return
  132. Status EoeReceived(int32_t worker_id) override;
  133. // Base-class override for NodePass visitor acceptor.
  134. // @param p - Pointer to the NodePass to be accepted.
  135. // @param modified - Whether this node visit modified the pipeline.
  136. // @return - Status of the node visit.
  137. Status Accept(NodePass *p, bool *modified) override;
  138. // Op name getter
  139. // @return Name of the current Op
  140. std::string Name() const override { return "ShuffleOp"; }
  141. private:
  142. // Private function to add a new row to the shuffle buffer.
  143. // @return Status - The error code return
  144. Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);
  145. // Private function to populate the shuffle buffer initially by fetching from the child output
  146. // connector until the shuffle buffer is full (or there is no more data coming).
  147. // @return Status - The error code return
  148. Status InitShuffleBuffer();
  149. // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
  150. // itself rather than waiting for the reset driven from operators above it in the pipeline.
  151. // @return Status - The error code return
  152. Status SelfReset();
  153. int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows)
  154. uint32_t shuffle_seed_;
  155. bool reshuffle_each_epoch_;
  156. // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period.
  157. // specifically mt19937_64 is used to generate larger random numbers to reduce bias when
  158. // modding to fit within our desired range. we dont use a distribution
  159. // (ie uniform_int_distribution) because we will need to create up to |dataset| instances
  160. // of the distribution object in the common case of a perfect shuffle
  161. std::mt19937_64 rng_;
  162. int32_t buffer_counter_; // For creating new buffer id's
  163. int32_t rows_per_buffer_; // Number of rows to pack into output buffer
  164. // A single (potentially large) buffer of tensor rows for performing shuffling.
  165. std::unique_ptr<TensorTable> shuffle_buffer_;
  166. int32_t shuffle_last_row_idx_; // Internal tracking of the last slot of our shuffle buffer
  167. int32_t shuffle_buffer_state_; // State tracking for the shuffle buffer phases of work
  168. std::unique_ptr<ChildIterator> child_iterator_; // An iterator for fetching.
  169. };
  170. } // namespace dataset
  171. } // namespace mindspore
  172. #endif // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_