You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

samplers.h 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  17. #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
  #include <cstdint>
  #include <memory>
  #include <utility>
  #include <vector>
  20. namespace mindspore {
  21. namespace dataset {
  // Forward declaration: SamplerObj is the IR-level sampler object produced by Parse().
  class SamplerObj;

  /// \class Sampler samplers.h
  /// \brief An abstract base class to represent a sampler in the data pipeline.
  /// \note std::enable_shared_from_this must be a *public* base. With implicit (private)
  ///     inheritance std::shared_ptr cannot see the base, never initializes its internal weak
  ///     reference, and shared_from_this() fails with std::bad_weak_ptr.
  class Sampler : public std::enable_shared_from_this<Sampler> {
    friend class AlbumDataset;
    friend class CelebADataset;
    friend class Cifar10Dataset;
    friend class Cifar100Dataset;
    friend class CLUEDataset;
    friend class CocoDataset;
    friend class CSVDataset;
    friend class ImageFolderDataset;
    friend class ManifestDataset;
    friend class MindDataDataset;
    friend class MnistDataset;
    friend class RandomDataDataset;
    friend class TextFileDataset;
    friend class TFRecordDataset;
    friend class VOCDataset;
    friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);

   public:
    /// \brief Constructor.
    Sampler() = default;

    /// \brief Destructor.
    /// Virtual because Sampler is a polymorphic base class: without it, deleting a derived
    /// sampler through a Sampler pointer is undefined behavior.
    virtual ~Sampler() = default;

    /// \brief A virtual function to add a child sampler.
    /// \param[in] child The child sampler to be added as a child of this sampler. It is
    ///     appended to children_; the by-value argument is moved in to avoid an extra
    ///     reference-count increment.
    virtual void AddChild(std::shared_ptr<Sampler> child) { children_.push_back(std::move(child)); }

   protected:
    /// \brief Pure virtual function to convert a Sampler class into an IR Sampler object.
    /// \return Shared pointer to the newly created SamplerObj.
    virtual std::shared_ptr<SamplerObj> Parse() const = 0;

    /// Child samplers registered through AddChild(), in insertion order.
    std::vector<std::shared_ptr<Sampler>> children_;
  };
  58. /// \brief A class to represent a Distributed Sampler in the data pipeline.
  59. /// \notes A Sampler that accesses a shard of the dataset.
  60. class DistributedSampler final : public Sampler {
  61. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  62. public:
  63. /// \brief Constructor
  64. /// \param[in] num_shards - Number of shards to divide the dataset into.
  65. /// \param[in] shard_id - Shard ID of the current shard within num_shards.
  66. /// \param[in] shuffle - If true, the indices are shuffled.
  67. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  68. /// \param[in] seed - The seed in use when shuffle is true.
  69. /// \param[in] offset - The starting position where access to elements in the dataset begins.
  70. /// \param[in] even_dist - If true, each shard would return the same number of rows (default to true).
  71. /// If false the total rows returned by all the shards would not have overlap.
  72. explicit DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
  73. uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
  74. /// \brief Destructor.
  75. ~DistributedSampler() = default;
  76. protected:
  77. /// \brief Function to convert a Sampler into an IR SamplerObj.
  78. /// \return shared pointer to the newly created SamplerObj.
  79. std::shared_ptr<SamplerObj> Parse() const override;
  80. private:
  81. int64_t num_shards_;
  82. int64_t shard_id_;
  83. bool shuffle_;
  84. int64_t num_samples_;
  85. uint32_t seed_;
  86. int64_t offset_;
  87. bool even_dist_;
  88. };
  89. /// \brief A class to represent a PK Sampler in the data pipeline.
  90. /// \notes Samples K elements for each P class in the dataset.
  91. /// This will sample all classes.
  92. class PKSampler final : public Sampler {
  93. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  94. public:
  95. /// \brief Constructor
  96. /// \param[in] num_val - Number of elements to sample for each class.
  97. /// \param[in] shuffle - If true, the class IDs are shuffled.
  98. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  99. explicit PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0);
  100. /// \brief Destructor.
  101. ~PKSampler() = default;
  102. protected:
  103. /// \brief Function to convert a Sampler into an IR SamplerObj.
  104. /// \return shared pointer to the newly created SamplerObj.
  105. std::shared_ptr<SamplerObj> Parse() const override;
  106. private:
  107. int64_t num_val_;
  108. bool shuffle_;
  109. int64_t num_samples_;
  110. };
  111. /// \brief A class to represent a Random Sampler in the data pipeline.
  112. /// \notes Samples the elements randomly.
  113. class RandomSampler final : public Sampler {
  114. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  115. public:
  116. /// \brief Constructor
  117. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  118. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  119. explicit RandomSampler(bool replacement = false, int64_t num_samples = 0);
  120. /// \brief Destructor.
  121. ~RandomSampler() = default;
  122. protected:
  123. /// \brief Function to convert a Sampler into an IR SamplerObj.
  124. /// \return shared pointer to the newly created SamplerObj.
  125. std::shared_ptr<SamplerObj> Parse() const override;
  126. private:
  127. bool replacement_;
  128. int64_t num_samples_;
  129. };
  130. /// \brief A class to represent a Sequential Sampler in the data pipeline.
  131. /// \notes Samples the dataset elements sequentially, same as not having a sampler.
  132. class SequentialSampler final : public Sampler {
  133. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  134. public:
  135. /// \brief Constructor
  136. /// \param[in] start_index - Index to start sampling at (default to start at first id).
  137. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  138. explicit SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0);
  139. /// \brief Destructor.
  140. ~SequentialSampler() = default;
  141. protected:
  142. /// \brief Function to convert a Sampler into an IR SamplerObj.
  143. /// \return shared pointer to the newly created SamplerObj.
  144. std::shared_ptr<SamplerObj> Parse() const override;
  145. private:
  146. int64_t start_index_;
  147. int64_t num_samples_;
  148. };
  149. /// \brief A class to represent a Subset Sampler in the data pipeline.
  150. /// \notes Samples the elements from a sequence of indices.
  151. class SubsetSampler : public Sampler {
  152. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  153. public:
  154. /// \brief Constructor
  155. /// \param[in] indices - A vector sequence of indices.
  156. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  157. explicit SubsetSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  158. /// \brief Destructor.
  159. ~SubsetSampler() = default;
  160. protected:
  161. /// \brief Function to convert a Sampler into an IR SamplerObj.
  162. /// \return shared pointer to the newly created SamplerObj.
  163. std::shared_ptr<SamplerObj> Parse() const override;
  164. std::vector<int64_t> indices_;
  165. int64_t num_samples_;
  166. };
  167. /// \brief A class to represent a Subset Random Sampler in the data pipeline.
  168. /// \notes Samples the elements randomly from a sequence of indices.
  169. class SubsetRandomSampler final : public SubsetSampler {
  170. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  171. public:
  172. /// \brief Constructor
  173. /// \param[in] indices - A vector sequence of indices.
  174. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  175. explicit SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
  176. /// \brief Destructor.
  177. ~SubsetRandomSampler() = default;
  178. protected:
  179. /// \brief Function to convert a Sampler into an IR SamplerObj.
  180. /// \return shared pointer to the newly created SamplerObj.
  181. std::shared_ptr<SamplerObj> Parse() const override;
  182. };
  183. /// \brief A class to represent a Weighted Random Sampler in the data pipeline.
  184. /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
  185. /// weights (probabilities).
  186. class WeightedRandomSampler final : public Sampler {
  187. friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
  188. public:
  189. /// \brief Constructor
  190. /// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1.
  191. /// \param[in] num_samples - The number of samples to draw (default to all elements).
  192. /// \param[in] replacement - If true, put the sample ID back for the next draw.
  193. explicit WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
  194. /// \brief Destructor.
  195. ~WeightedRandomSampler() = default;
  196. protected:
  197. /// \brief Function to convert a Sampler into an IR SamplerObj.
  198. /// \return shared pointer to the newly created SamplerObj.
  199. std::shared_ptr<SamplerObj> Parse() const override;
  200. private:
  201. std::vector<double> weights_;
  202. int64_t num_samples_;
  203. bool replacement_;
  204. };
  205. } // namespace dataset
  206. } // namespace mindspore
  207. #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_