You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

samplers.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include "include/api/status.h"
  22. namespace mindspore {
  23. namespace dataset {
  24. // Internal Sampler class forward declaration
  25. class SamplerRT;
  26. class SamplerObj {
  27. public:
  28. /// \brief Constructor
  29. SamplerObj();
  30. /// \brief Destructor
  31. ~SamplerObj() = default;
  32. /// \brief Pure virtual function for derived class to implement parameters validation
  33. /// \return The Status code of the function. It returns OK status if parameters are valid.
  34. virtual Status ValidateParams() = 0;
  35. /// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
  36. /// \return Shared pointers to the newly created Sampler
  37. virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
  38. /// \brief Pure virtual function to copy a SamplerObj class
  39. /// \return Shared pointers to the newly copied SamplerObj
  40. virtual std::shared_ptr<SamplerObj> SamplerCopy() = 0;
  41. /// \brief Function for derived class to get the shard id of sampler
  42. /// \return The shard id of the derived sampler
  43. virtual int64_t ShardId() { return 0; }
  44. /// \brief Adds a child to the sampler
  45. /// \param[in] child The sampler to be added as child
  46. /// \return the Status code returned
  47. Status AddChildSampler(std::shared_ptr<SamplerObj> child);
  48. std::vector<std::shared_ptr<SamplerObj>> GetChild() { return children_; }
  49. protected:
  50. /// \brief A function that calls build on the children of this sampler
  51. /// \param[in] sampler The samplerRT object built from this sampler
  52. void BuildChildren(std::shared_ptr<SamplerRT> sampler);
  53. std::vector<std::shared_ptr<SamplerObj>> children_;
  54. };
  // Forward declarations of the concrete sampler classes that are defined later in this header.
  class DistributedSamplerObj;
  class PKSamplerObj;
  class PreBuiltSamplerObj;
  class RandomSamplerObj;
  class SequentialSamplerObj;
  class SubsetSamplerObj;
  class SubsetRandomSamplerObj;
  class WeightedRandomSamplerObj;
  63. /// Function to create a Distributed Sampler.
  64. /// \notes A Sampler that access a shard of the dataset.
  65. /// \param[in] num_shards - Number of shards to divide the dataset into.
  66. /// \param[in] shard_id - Shard ID of the current shard within num_shards.
  67. /// \param[in] shuffle - If true, the indices are shuffled.
  68. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  69. /// \param[in] seed - The seed in use when shuffle is true.
  70. /// \param[in] offset - The starting position where access to elements in the dataset begins.
  71. /// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
  72. /// If false the total rows returned by all the shards would not have overlap.
  73. /// \return Shared pointer to the current Sampler.
  74. std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true,
  75. int64_t num_samples = 0, uint32_t seed = 1,
  76. int64_t offset = -1, bool even_dist = true);
  77. /// Function to create a PK Sampler.
  78. /// \notes Samples K elements for each P class in the dataset.
  79. /// This will sample all classes.
  80. /// \param[in] num_val - Number of elements to sample for each class.
  81. /// \param[in] shuffle - If true, the class IDs are shuffled.
  82. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  83. /// \return Shared pointer to the current Sampler.
  84. std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
  85. /// Function to create a Random Sampler.
  86. /// \notes Samples the elements randomly.
  87. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  88. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  89. /// \return Shared pointer to the current Sampler.
  90. std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement = false, int64_t num_samples = 0);
  91. /// Function to create a Sequential Sampler.
  92. /// \notes Samples the dataset elements sequentially, same as not having a sampler.
  93. /// \param[in] start_index - Index to start sampling at (default to start at first id).
  94. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  95. /// \return Shared pointer to the current Sampler.
  96. std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
  97. /// Function to create a Subset Sampler.
  98. /// \notes Samples the elements from a sequence of indices.
  99. /// \param[in] indices - A vector sequence of indices.
  100. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  101. /// \return Shared pointer to the current Sampler.
  102. std::shared_ptr<SubsetSamplerObj> SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  103. /// Function to create a Subset Random Sampler.
  104. /// \notes Samples the elements randomly from a sequence of indices.
  105. /// \param[in] indices - A vector sequence of indices.
  106. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  107. /// \return Shared pointer to the current Sampler.
  108. std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  109. /// Function to create a Weighted Random Sampler.
  110. /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
  111. /// weights (probabilities).
  112. /// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1.
  113. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  114. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  115. /// \return Shared pointer to the current Sampler.
  116. std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
  117. bool replacement = true);
  118. /* ####################################### Derived Sampler classes ################################# */
  119. class DistributedSamplerObj : public SamplerObj {
  120. public:
  121. DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed,
  122. int64_t offset, bool even_dist);
  123. virtual ~DistributedSamplerObj() = default;
  124. std::shared_ptr<SamplerRT> SamplerBuild() override;
  125. std::shared_ptr<SamplerObj> SamplerCopy() override {
  126. auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
  127. offset_, even_dist_);
  128. for (auto child : children_) {
  129. sampler->AddChildSampler(child);
  130. }
  131. return sampler;
  132. }
  133. Status ValidateParams() override;
  134. /// \brief Function to get the shard id of sampler
  135. /// \return The shard id of sampler
  136. int64_t ShardId() override { return shard_id_; }
  137. private:
  138. int64_t num_shards_;
  139. int64_t shard_id_;
  140. bool shuffle_;
  141. int64_t num_samples_;
  142. uint32_t seed_;
  143. int64_t offset_;
  144. bool even_dist_;
  145. };
  146. class PKSamplerObj : public SamplerObj {
  147. public:
  148. PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples);
  149. virtual ~PKSamplerObj() = default;
  150. std::shared_ptr<SamplerRT> SamplerBuild() override;
  151. std::shared_ptr<SamplerObj> SamplerCopy() override {
  152. auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
  153. for (auto child : children_) {
  154. sampler->AddChildSampler(child);
  155. }
  156. return sampler;
  157. }
  158. Status ValidateParams() override;
  159. private:
  160. int64_t num_val_;
  161. bool shuffle_;
  162. int64_t num_samples_;
  163. };
  164. class PreBuiltSamplerObj : public SamplerObj {
  165. public:
  166. explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
  167. ~PreBuiltSamplerObj() = default;
  168. std::shared_ptr<SamplerRT> SamplerBuild() override;
  169. std::shared_ptr<SamplerObj> SamplerCopy() override;
  170. Status ValidateParams() override;
  171. private:
  172. std::shared_ptr<SamplerRT> sp_;
  173. };
  174. class RandomSamplerObj : public SamplerObj {
  175. public:
  176. RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch = true);
  177. virtual ~RandomSamplerObj() = default;
  178. std::shared_ptr<SamplerRT> SamplerBuild() override;
  179. std::shared_ptr<SamplerObj> SamplerCopy() override {
  180. auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
  181. for (auto child : children_) {
  182. sampler->AddChildSampler(child);
  183. }
  184. return sampler;
  185. }
  186. Status ValidateParams() override;
  187. private:
  188. bool replacement_;
  189. int64_t num_samples_;
  190. bool reshuffle_each_epoch_;
  191. };
  192. class SequentialSamplerObj : public SamplerObj {
  193. public:
  194. SequentialSamplerObj(int64_t start_index, int64_t num_samples);
  195. virtual ~SequentialSamplerObj() = default;
  196. std::shared_ptr<SamplerRT> SamplerBuild() override;
  197. std::shared_ptr<SamplerObj> SamplerCopy() override {
  198. auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
  199. for (auto child : children_) {
  200. sampler->AddChildSampler(child);
  201. }
  202. return sampler;
  203. }
  204. Status ValidateParams() override;
  205. private:
  206. int64_t start_index_;
  207. int64_t num_samples_;
  208. };
  209. class SubsetSamplerObj : public SamplerObj {
  210. public:
  211. SubsetSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
  212. virtual ~SubsetSamplerObj() = default;
  213. std::shared_ptr<SamplerRT> SamplerBuild() override;
  214. std::shared_ptr<SamplerObj> SamplerCopy() override {
  215. auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
  216. for (auto child : children_) {
  217. sampler->AddChildSampler(child);
  218. }
  219. return sampler;
  220. }
  221. Status ValidateParams() override;
  222. protected:
  223. const std::vector<int64_t> indices_;
  224. int64_t num_samples_;
  225. };
  226. class SubsetRandomSamplerObj : public SubsetSamplerObj {
  227. public:
  228. SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
  229. ~SubsetRandomSamplerObj() = default;
  230. std::shared_ptr<SamplerRT> SamplerBuild() override;
  231. std::shared_ptr<SamplerObj> SamplerCopy() override {
  232. auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
  233. for (auto child : children_) {
  234. sampler->AddChildSampler(child);
  235. }
  236. return sampler;
  237. }
  238. private:
  239. };
  240. class WeightedRandomSamplerObj : public SamplerObj {
  241. public:
  242. explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
  243. virtual ~WeightedRandomSamplerObj() = default;
  244. std::shared_ptr<SamplerRT> SamplerBuild() override;
  245. std::shared_ptr<SamplerObj> SamplerCopy() override {
  246. auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
  247. for (auto child : children_) {
  248. sampler->AddChildSampler(child);
  249. }
  250. return sampler;
  251. }
  252. Status ValidateParams() override;
  253. private:
  254. const std::vector<double> weights_;
  255. int64_t num_samples_;
  256. bool replacement_;
  257. };
  258. } // namespace dataset
  259. } // namespace mindspore
  260. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_