You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

samplers.h 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "include/api/status.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. class SamplerObj;
  25. // Abstract class to represent a sampler in the data pipeline.
  26. /// \class Sampler samplers.h
  27. /// \brief An abstract base class to represent a sampler in the data pipeline.
  28. class Sampler : std::enable_shared_from_this<Sampler> {
  29. friend class AlbumDataset;
  30. friend class MindDataDataset;
  31. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  32. public:
  33. /// \brief Constructor
  34. Sampler() {}
  35. /// \brief Destructor
  36. ~Sampler() = default;
  37. /// \brief A virtual function to add a child sampler.
  38. /// \param[in] child The child sampler to be added as a children of this sampler.
  39. virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(child); }
  40. protected:
  41. /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
  42. /// \return shared pointer to the newly created TensorOperation.
  43. virtual std::shared_ptr<SamplerObj> Parse() = 0;
  44. std::vector<std::shared_ptr<Sampler>> children_;
  45. };
  46. /// \brief A class to represent a Distributed Sampler in the data pipeline.
  47. /// \notes A Sampler that accesses a shard of the dataset.
  48. class DistributedSampler : public Sampler {
  49. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  50. public:
  51. /// \brief Constructor
  52. /// \param[in] num_shards - Number of shards to divide the dataset into.
  53. /// \param[in] shard_id - Shard ID of the current shard within num_shards.
  54. /// \param[in] shuffle - If true, the indices are shuffled.
  55. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  56. /// \param[in] seed - The seed in use when shuffle is true.
  57. /// \param[in] offset - The starting position where access to elements in the dataset begins.
  58. /// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
  59. /// If false the total rows returned by all the shards would not have overlap.
  60. explicit DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
  61. uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
  62. /// \brief Destructor.
  63. ~DistributedSampler() = default;
  64. protected:
  65. /// \brief Function to convert a Sampler into an IR SamplerObj.
  66. /// \return shared pointer to the newly created SamplerObj.
  67. std::shared_ptr<SamplerObj> Parse() override;
  68. private:
  69. int64_t num_shards_;
  70. int64_t shard_id_;
  71. bool shuffle_;
  72. int64_t num_samples_;
  73. uint32_t seed_;
  74. int64_t offset_;
  75. bool even_dist_;
  76. };
  77. /// \brief A class to represent a PK Sampler in the data pipeline.
  78. /// \notes Samples K elements for each P class in the dataset.
  79. /// This will sample all classes.
  80. class PKSampler : public Sampler {
  81. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  82. public:
  83. /// \brief Constructor
  84. /// \param[in] num_val - Number of elements to sample for each class.
  85. /// \param[in] shuffle - If true, the class IDs are shuffled.
  86. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  87. explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
  88. /// \brief Destructor.
  89. ~PKSampler() = default;
  90. protected:
  91. /// \brief Function to convert a Sampler into an IR SamplerObj.
  92. /// \return shared pointer to the newly created SamplerObj.
  93. std::shared_ptr<SamplerObj> Parse() override;
  94. private:
  95. int64_t num_val_;
  96. bool shuffle_;
  97. int64_t num_samples_;
  98. };
  99. /// \brief A class to represent a Random Sampler in the data pipeline.
  100. /// \notes Samples the elements randomly.
  101. class RandomSampler : public Sampler {
  102. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  103. public:
  104. /// \brief Constructor
  105. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  106. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  107. explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);
  108. /// \brief Destructor.
  109. ~RandomSampler() = default;
  110. protected:
  111. /// \brief Function to convert a Sampler into an IR SamplerObj.
  112. /// \return shared pointer to the newly created SamplerObj.
  113. std::shared_ptr<SamplerObj> Parse() override;
  114. private:
  115. bool replacement_;
  116. int64_t num_samples_;
  117. };
  118. /// \brief A class to represent a Sequential Sampler in the data pipeline.
  119. /// \notes Samples the dataset elements sequentially, same as not having a sampler.
  120. class SequentialSampler : public Sampler {
  121. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  122. public:
  123. /// \brief Constructor
  124. /// \param[in] start_index - Index to start sampling at (default to start at first id).
  125. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  126. explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
  127. /// \brief Destructor.
  128. ~SequentialSampler() = default;
  129. protected:
  130. /// \brief Function to convert a Sampler into an IR SamplerObj.
  131. /// \return shared pointer to the newly created SamplerObj.
  132. std::shared_ptr<SamplerObj> Parse() override;
  133. private:
  134. int64_t start_index_;
  135. int64_t num_samples_;
  136. };
  137. /// \brief A class to represent a Subset Sampler in the data pipeline.
  138. /// \notes Samples the elements from a sequence of indices.
  139. class SubsetSampler : public Sampler {
  140. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  141. public:
  142. /// \brief Constructor
  143. /// \param[in] indices - A vector sequence of indices.
  144. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  145. explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  146. /// \brief Destructor.
  147. ~SubsetSampler() = default;
  148. protected:
  149. /// \brief Function to convert a Sampler into an IR SamplerObj.
  150. /// \return shared pointer to the newly created SamplerObj.
  151. std::shared_ptr<SamplerObj> Parse() override;
  152. std::vector<int64_t> indices_;
  153. int64_t num_samples_;
  154. };
  155. /// \brief A class to represent a Subset Random Sampler in the data pipeline.
  156. /// \notes Samples the elements randomly from a sequence of indices.
  157. class SubsetRandomSampler : public SubsetSampler {
  158. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  159. public:
  160. /// \brief Constructor
  161. /// \param[in] indices - A vector sequence of indices.
  162. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  163. explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  164. /// \brief Destructor.
  165. ~SubsetRandomSampler() = default;
  166. protected:
  167. /// \brief Function to convert a Sampler into an IR SamplerObj.
  168. /// \return shared pointer to the newly created SamplerObj.
  169. std::shared_ptr<SamplerObj> Parse() override;
  170. };
  171. /// \brief A class to represent a Weighted Random Sampler in the data pipeline.
  172. /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
  173. /// weights (probabilities).
  174. class WeightedRandomSampler : public Sampler {
  175. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  176. public:
  177. /// \brief Constructor
  178. /// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1.
  179. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  180. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  181. explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
  182. /// \brief Destructor.
  183. ~WeightedRandomSampler() = default;
  184. protected:
  185. /// \brief Function to convert a Sampler into an IR SamplerObj.
  186. /// \return shared pointer to the newly created SamplerObj.
  187. std::shared_ptr<SamplerObj> Parse() override;
  188. private:
  189. std::vector<double> weights_;
  190. int64_t num_samples_;
  191. bool replacement_;
  192. };
  193. } // namespace dataset
  194. } // namespace mindspore
  195. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_