You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

samplers.py 12 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. The sampler module provides several samplers to generate sampling data from a dataset.
  17. The following samplers are available: DistributedSampler, PKSampler, RandomSampler,
  18. SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
  19. Users can also define a custom sampler by extending the Sampler class.
  20. """
  21. import numpy as np
  22. import mindspore._c_dataengine as cde
  23. class Sampler:
  24. """
  25. Base class for user defined sampler.
  26. User defined sampler can be used with any existing dataset with sampler support.
  27. An required _iter_() method should by overridden by user for sample index generation.
  28. An optional reset() method can be overridden for per repeat reset,
  29. dataset_size and num_samples will be set by dataset once a dataset iterator is created.
  30. Examples:
  31. >>> import mindspore.dataset as ds
  32. >>>
  33. >>> class ReverseSampler(ds,Sampler):
  34. >>> def __iter__(self):
  35. >>> for i in range(self.dataset_size - 1, -1, -1):
  36. >>> yield i
  37. >>>
  38. >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
  39. """
  40. def __init__(self):
  41. self.dataset_size = 0
  42. self.num_samples = 0
  43. def __iter__(self):
  44. """
  45. User defined iterator, must be overridden.
  46. _handshake is guaranteed to be called prior to iterator construction
  47. """
  48. raise NotImplementedError
  49. def reset(self):
  50. """
  51. Per repeat reset callback, override this method if necessary
  52. """
  53. # Initialization handshake callback
  54. # Do not override this method!
  55. def _handshake(self, ds_size, num_samples):
  56. self.dataset_size = ds_size
  57. self.num_samples = num_samples
  58. # Indices fetcher
  59. # Do not override this method!
  60. def _get_indices(self):
  61. sampler_iter = iter(self)
  62. ret = []
  63. for _ in range(self.num_samples):
  64. try:
  65. idx = next(sampler_iter)
  66. ret.append(idx)
  67. except StopIteration:
  68. break
  69. return np.array(ret)
  70. # Instance fetcher
  71. # Do not override this method!
  72. def create(self):
  73. return cde.PythonSampler(self)
  74. class BuiltinSampler:
  75. """
  76. Base class for BuiltinSampler.
  77. User should not extend this class.
  78. """
  79. def __init__(self):
  80. pass
  81. def create(self):
  82. pass
  83. class DistributedSampler(BuiltinSampler):
  84. """
  85. Sampler that access a shard of the dataset.
  86. Args:
  87. num_shards (int): Number of shards to divide the dataset into.
  88. shard_id (int): Shard ID of the current shard within num_shards.
  89. shuffle (bool, optional): If true, the indices are shuffled (default=True).
  90. Examples:
  91. >>> import mindspore.dataset as ds
  92. >>>
  93. >>> dataset_dir = "path/to/imagefolder_directory"
  94. >>>
  95. >>> # creates a distributed sampler with 10 shards total. This shard is shard 5
  96. >>> sampler = ds.DistributedSampler(10, 5)
  97. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  98. Raises:
  99. ValueError: If num_shards is not positive.
  100. ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
  101. ValueError: If shuffle is not a boolean value.
  102. """
  103. def __init__(self, num_shards, shard_id, shuffle=True):
  104. if num_shards <= 0:
  105. raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
  106. if shard_id < 0 or shard_id >= num_shards:
  107. raise ValueError("shard_id is invalid, shard_id={}".format(shard_id))
  108. if not isinstance(shuffle, bool):
  109. raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
  110. self.num_shards = num_shards
  111. self.shard_id = shard_id
  112. self.shuffle = shuffle
  113. self.seed = 0
  114. super().__init__()
  115. def create(self):
  116. # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
  117. self.seed += 1
  118. return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
  119. class PKSampler(BuiltinSampler):
  120. """
  121. Samples K elements for each P class in the dataset.
  122. Args:
  123. num_val (int): Number of elements to sample for each class.
  124. num_class (int, optional): Number of classes to sample (default=None, all classes).
  125. shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
  126. class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
  127. Examples:
  128. >>> import mindspore.dataset as ds
  129. >>>
  130. >>> dataset_dir = "path/to/imagefolder_directory"
  131. >>>
  132. >>> # creates a PKSampler that will get 3 samples from every class.
  133. >>> sampler = ds.PKSampler(3)
  134. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  135. Raises:
  136. ValueError: If num_val is not positive.
  137. NotImplementedError: If num_class is not None.
  138. ValueError: If shuffle is not boolean.
  139. """
  140. def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
  141. if num_val <= 0:
  142. raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
  143. if num_class is not None:
  144. raise NotImplementedError
  145. if not isinstance(shuffle, bool):
  146. raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
  147. self.num_val = num_val
  148. self.shuffle = shuffle
  149. self.class_column = class_column # work for minddataset
  150. super().__init__()
  151. def create(self):
  152. return cde.PKSampler(self.num_val, self.shuffle)
  153. def _create_for_minddataset(self):
  154. if not self.class_column or not isinstance(self.class_column, str):
  155. raise ValueError("class_column should be a not empty string value, \
  156. but got class_column={}".format(class_column))
  157. return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
  158. class RandomSampler(BuiltinSampler):
  159. """
  160. Samples the elements randomly.
  161. Args:
  162. replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
  163. num_samples (int, optional): Number of elements to sample (default=None, all elements).
  164. Examples:
  165. >>> import mindspore.dataset as ds
  166. >>>
  167. >>> dataset_dir = "path/to/imagefolder_directory"
  168. >>>
  169. >>> # creates a RandomSampler
  170. >>> sampler = ds.RandomSampler()
  171. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  172. Raises:
  173. ValueError: If replacement is not boolean.
  174. ValueError: If num_samples is not positive.
  175. """
  176. def __init__(self, replacement=False, num_samples=None):
  177. if not isinstance(replacement, bool):
  178. raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
  179. if num_samples is not None:
  180. if num_samples <= 0:
  181. raise ValueError("num_samples should be a positive integer "
  182. "value, but got num_samples={}".format(num_samples))
  183. self.replacement = replacement
  184. self.num_samples = num_samples
  185. super().__init__()
  186. def create(self):
  187. # If num_samples is not specified, then call constructor #2
  188. if self.num_samples is None:
  189. return cde.RandomSampler(self.replacement)
  190. return cde.RandomSampler(self.replacement, self.num_samples)
  191. class SequentialSampler(BuiltinSampler):
  192. """
  193. Samples the dataset elements sequentially, same as not having a sampler.
  194. Examples:
  195. >>> import mindspore.dataset as ds
  196. >>>
  197. >>> dataset_dir = "path/to/imagefolder_directory"
  198. >>>
  199. >>> # creates a SequentialSampler
  200. >>> sampler = ds.SequentialSampler()
  201. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  202. """
  203. def create(self):
  204. return cde.SequentialSampler()
  205. class SubsetRandomSampler(BuiltinSampler):
  206. """
  207. Samples the elements randomly from a sequence of indices.
  208. Args:
  209. indices (list[int]): A sequence of indices.
  210. Examples:
  211. >>> import mindspore.dataset as ds
  212. >>>
  213. >>> dataset_dir = "path/to/imagefolder_directory"
  214. >>>
  215. >>> indices = [0, 1, 2, 3, 7, 88, 119]
  216. >>>
  217. >>> # creates a SubsetRandomSampler, will sample from the provided indices
  218. >>> sampler = ds.SubsetRandomSampler()
  219. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  220. """
  221. def __init__(self, indices):
  222. if not isinstance(indices, list):
  223. indices = [indices]
  224. self.indices = indices
  225. super().__init__()
  226. def create(self):
  227. return cde.SubsetRandomSampler(self.indices)
  228. def _create_for_minddataset(self):
  229. return cde.MindrecordSubsetRandomSampler(self.indices)
  230. class WeightedRandomSampler(BuiltinSampler):
  231. """
  232. Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
  233. Args:
  234. weights (list[float]): A sequence of weights, not necessarily summing up to 1.
  235. num_samples (int): Number of elements to sample.
  236. replacement (bool, optional): If True, put the sample ID back for the next draw (default=True).
  237. Examples:
  238. >>> import mindspore.dataset as ds
  239. >>>
  240. >>> dataset_dir = "path/to/imagefolder_directory"
  241. >>>
  242. >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
  243. >>>
  244. >>> # creates a WeightedRandomSampler that will sample 4 elements without replacement
  245. >>> sampler = ds.WeightedRandomSampler(weights, 4)
  246. >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler)
  247. Raises:
  248. ValueError: If num_samples is not positive.
  249. ValueError: If replacement is not boolean.
  250. """
  251. def __init__(self, weights, num_samples, replacement=True):
  252. if not isinstance(weights, list):
  253. weights = [weights]
  254. if num_samples <= 0:
  255. raise ValueError("num_samples should be a positive integer "
  256. "value, but got num_samples={}".format(num_samples))
  257. if not isinstance(replacement, bool):
  258. raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
  259. self.weights = weights
  260. self.num_samples = num_samples
  261. self.replacement = replacement
  262. super().__init__()
  263. def create(self):
  264. return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)