You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

samplers.h 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_
  #include <cstdint>
  #include <memory>
  #include <vector>
  20. namespace mindspore {
  21. namespace dataset {
  22. // Internal Sampler class forward declaration
  23. class Sampler;
  24. namespace api {
  25. class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
  26. public:
  27. SamplerObj();
  28. ~SamplerObj() = default;
  29. virtual std::shared_ptr<Sampler> Build() = 0;
  30. virtual bool ValidateParams() = 0;
  31. };
  // Forward declarations of the concrete sampler object classes, so that the
  // factory functions declared below can name them in their return types
  // before the classes themselves are defined later in this header.
  class DistributedSamplerObj;
  class PKSamplerObj;
  class RandomSamplerObj;
  class SequentialSamplerObj;
  class SubsetRandomSamplerObj;
  class WeightedRandomSamplerObj;
  /// Function to create a Distributed Sampler.
  /// \note A Sampler that accesses a shard of the dataset.
  /// \param[in] num_shards - Number of shards to divide the dataset into.
  /// \param[in] shard_id - Shard ID of the current shard within num_shards.
  /// \param[in] shuffle - If true, the indices are shuffled (default: true).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \param[in] seed - The seed in use when shuffle is true (default: 1).
  /// \param[in] even_dist - If true, each shard returns the same number of rows (default: true).
  ///     If false, the rows returned by all the shards together do not overlap.
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
                                                            int64_t num_samples = 0, uint32_t seed = 1,
                                                            bool even_dist = true);
  /// Function to create a PK Sampler.
  /// \note Samples K elements for each P class in the dataset.
  ///     This will sample all classes.
  /// \param[in] num_val - Number of elements to sample for each class.
  /// \param[in] shuffle - If true, the class IDs are shuffled (default: false).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
  /// Function to create a Random Sampler.
  /// \note Samples the elements randomly.
  /// \param[in] replacement - If true, put the sample ID back for the next draw (default: false).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_t num_samples = 0);
  /// Function to create a Sequential Sampler.
  /// \note Samples the dataset elements sequentially, same as not having a sampler.
  /// \param[in] start_index - Index to start sampling at (default: 0, i.e. start at the first ID).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
  /// Function to create a Subset Random Sampler.
  /// \note Samples the elements randomly from a sequence of indices.
  /// \param[in] indices - A vector sequence of indices (taken by value).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  /// Function to create a Weighted Random Sampler.
  /// \note Samples the elements from [0, len(weights) - 1] randomly with the given
  ///     weights (probabilities).
  /// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1 (taken by value).
  /// \param[in] num_samples - The number of samples to draw (default: 0, meaning all elements).
  /// \param[in] replacement - If true, put the sample ID back for the next draw (default: true).
  /// \return Shared pointer to the current Sampler.
  std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
                                                                  bool replacement = true);
  86. /* ####################################### Derived Sampler classes ################################# */
  /// Sampler object for the Distributed Sampler; concrete subclass of SamplerObj.
  class DistributedSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, in the same order, as the
    /// DistributedSampler() factory function, but without default values.
    DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
                          bool even_dist);

    /// Destructor.
    ~DistributedSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // One member per constructor parameter; presumably populated by the
    // constructor, which is defined outside this header.
    int64_t num_shards_;
    int64_t shard_id_;
    bool shuffle_;
    int64_t num_samples_;
    uint32_t seed_;
    bool even_dist_;
  };
  /// Sampler object for the PK Sampler; concrete subclass of SamplerObj.
  class PKSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, in the same order, as the
    /// PKSampler() factory function, but without default values.
    PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);

    /// Destructor.
    ~PKSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // One member per constructor parameter; presumably populated by the
    // constructor, which is defined outside this header.
    int64_t num_val_;
    bool shuffle_;
    int64_t num_samples_;
  };
  /// Sampler object for the Random Sampler; concrete subclass of SamplerObj.
  class RandomSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, in the same order, as the
    /// RandomSampler() factory function, but without default values.
    RandomSamplerObj(bool replacement, int64_t num_samples);

    /// Destructor.
    ~RandomSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // One member per constructor parameter; presumably populated by the
    // constructor, which is defined outside this header.
    bool replacement_;
    int64_t num_samples_;
  };
  /// Sampler object for the Sequential Sampler; concrete subclass of SamplerObj.
  class SequentialSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, in the same order, as the
    /// SequentialSampler() factory function, but without default values.
    SequentialSamplerObj(int64_t start_index, int64_t num_samples);

    /// Destructor.
    ~SequentialSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // One member per constructor parameter; presumably populated by the
    // constructor, which is defined outside this header.
    int64_t start_index_;
    int64_t num_samples_;
  };
  /// Sampler object for the Subset Random Sampler; concrete subclass of SamplerObj.
  class SubsetRandomSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, in the same order, as the
    /// SubsetRandomSampler() factory function, but without default values.
    /// The indices vector is taken by value.
    SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);

    /// Destructor.
    ~SubsetRandomSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // Declared const, so it can only be set in the constructor's initializer
    // list and makes the class non-assignable.
    const std::vector<int64_t> indices_;
    int64_t num_samples_;
  };
  /// Sampler object for the Weighted Random Sampler; concrete subclass of SamplerObj.
  class WeightedRandomSamplerObj : public SamplerObj {
   public:
    /// Constructor. Takes the same parameters, with the same default values, as
    /// the WeightedRandomSampler() factory function. It is marked explicit, so a
    /// std::vector<double> is never implicitly converted into a sampler object.
    /// \param[in] weights - Vector of weights (taken by value).
    /// \param[in] num_samples - Number of samples to draw (default: 0).
    /// \param[in] replacement - Replacement flag (default: true).
    explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);

    /// Destructor.
    ~WeightedRandomSamplerObj() = default;

    /// Overrides SamplerObj::Build().
    /// \return Shared pointer to the internal Sampler.
    std::shared_ptr<Sampler> Build() override;

    /// Overrides SamplerObj::ValidateParams().
    bool ValidateParams() override;

   private:
    // Declared const, so it can only be set in the constructor's initializer
    // list and makes the class non-assignable.
    const std::vector<double> weights_;
    int64_t num_samples_;
    bool replacement_;
  };
  154. } // namespace api
  155. } // namespace dataset
  156. } // namespace mindspore
  157. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_SAMPLERS_H_