You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

samplers.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. The sampler module provides several samplers to generate data from datasets.
  17. The provided samplers include: DistributedSampler, PKSampler, RandomSampler,
  18. SequentialSampler, SubsetRandomSampler, and WeightedRandomSampler.
  19. Users can also define a custom sampler by extending from the Sampler class.
  20. """
  21. import numpy as np
  22. import mindspore._c_dataengine as cde
  23. import mindspore.dataset as ds
  24. class Sampler:
  25. """
  26. Base class for user defined sampler.
  27. A user defined sampler can be used with any existing dataset with sampler support.
  28. A required _iter_() method should by overridden by the user for sample index generation.
  29. An optional reset() method can be overridden for per repeat reset,
  30. dataset_size and num_samples will be set by dataset once a dataset iterator is created.
  31. Examples:
  32. >>> import mindspore.dataset as ds
  33. >>>
  34. >>> class ReverseSampler(ds,Sampler):
  35. >>> def __iter__(self):
  36. >>> for i in range(self.dataset_size - 1, -1, -1):
  37. >>> yield i
  38. >>>
  39. >>> ds = ds.ImageFolderDataset(path, sampler=ReverseSampler())
  40. """
  41. def __init__(self, num_samples=None):
  42. self.dataset_size = 0
  43. self.child_sampler = None
  44. self.num_samples = num_samples
  45. def __iter__(self):
  46. """
  47. User defined iterator, must be overridden.
  48. _handshake is guaranteed to be called prior to iterator construction.
  49. """
  50. raise NotImplementedError
  51. def reset(self):
  52. """
  53. Per repeat reset callback, override this method if necessary
  54. """
  55. # Initialization handshake callback
  56. # Do not override this method!
  57. def _handshake(self, ds_size, num_samples):
  58. self.dataset_size = ds_size
  59. self.num_samples = num_samples
  60. # Indices fetcher
  61. # Do not override this method!
  62. def _get_indices(self):
  63. sampler_iter = iter(self)
  64. ret = []
  65. for _ in range(self.num_samples):
  66. try:
  67. idx = next(sampler_iter)
  68. ret.append(idx)
  69. except StopIteration:
  70. break
  71. return np.array(ret)
  72. # Instance fetcher
  73. # Do not override this method!
  74. def create(self):
  75. num_samples = self.num_samples if self.num_samples is not None else 0
  76. c_sampler = cde.PythonSampler(num_samples, self)
  77. c_child_sampler = self.create_child()
  78. c_sampler.add_child(c_child_sampler)
  79. return c_sampler
  80. def add_child(self, sampler):
  81. self.child_sampler = sampler
  82. def get_child(self):
  83. return self.child_sampler
  84. def create_child(self):
  85. c_child_sampler = None
  86. if self.child_sampler is not None:
  87. c_child_sampler = self.child_sampler.create()
  88. return c_child_sampler
  89. def is_shuffled(self):
  90. if self.child_sampler is None:
  91. return False
  92. return self.child_sampler.is_shuffled()
  93. def is_sharded(self):
  94. if self.child_sampler is None:
  95. return False
  96. return self.child_sampler.is_sharded()
  97. def get_num_samples(self):
  98. if self.num_samples is None:
  99. return None
  100. return self._get_indices().size
  101. class BuiltinSampler:
  102. """
  103. Base class for BuiltinSampler.
  104. User should not extend this class.
  105. """
  106. def __init__(self, num_samples=None):
  107. self.child_sampler = None
  108. self.num_samples = num_samples
  109. def create(self):
  110. pass
  111. def add_child(self, sampler):
  112. self.child_sampler = sampler
  113. def get_child(self):
  114. return self.child_sampler
  115. def create_child(self):
  116. c_child_sampler = None
  117. if self.child_sampler is not None:
  118. c_child_sampler = self.child_sampler.create()
  119. return c_child_sampler
  120. def create_child_for_minddataset(self):
  121. c_child_sampler = None
  122. if self.child_sampler is not None:
  123. c_child_sampler = self.child_sampler.create_for_minddataset()
  124. return c_child_sampler
  125. def is_shuffled(self):
  126. raise NotImplementedError("Sampler must implement is_shuffled.")
  127. def is_sharded(self):
  128. raise NotImplementedError("Sampler must implement is_sharded.")
  129. def get_num_samples(self):
  130. """
  131. All samplers can contain a numeric num_samples value (or it can be set to None).
  132. A child sampler can exist or be None.
  133. If a child sampler exists, then the child sampler count can be a numeric value or None.
  134. These conditions impact the resultant sampler count that is used.
  135. The following table shows the possible results from calling this function.
  136. .. list-table::
  137. :widths: 25 25 25 25
  138. :header-rows: 1
  139. * - child sampler
  140. - num_samples
  141. - child_samples
  142. - result
  143. * - T
  144. - x
  145. - y
  146. - min(x, y)
  147. * - T
  148. - x
  149. - None
  150. - x
  151. * - T
  152. - None
  153. - y
  154. - y
  155. * - T
  156. - None
  157. - None
  158. - None
  159. * - None
  160. - x
  161. - n/a
  162. - x
  163. * - None
  164. - None
  165. - n/a
  166. - None
  167. Returns:
  168. int, The number of samples, or None
  169. """
  170. if self.child_sampler is not None:
  171. child_samples = self.child_sampler.get_num_samples()
  172. if self.num_samples is not None:
  173. if child_samples is not None:
  174. return min(self.num_samples, child_samples)
  175. return self.num_samples
  176. return child_samples
  177. return self.num_samples
  178. class DistributedSampler(BuiltinSampler):
  179. """
  180. A sampler that accesses a shard of the dataset.
  181. Args:
  182. num_shards (int): Number of shards to divide the dataset into.
  183. shard_id (int): Shard ID of the current shard within num_shards.
  184. shuffle (bool, optional): If True, the indices are shuffled (default=True).
  185. num_samples (int, optional): The number of samples to draw (default=None, all elements).
  186. offset(int, optional): The starting sample ID where access to elements in the dataset begins (default=-1).
  187. Examples:
  188. >>> import mindspore.dataset as ds
  189. >>>
  190. >>> dataset_dir = "path/to/imagefolder_directory"
  191. >>>
  192. >>> # creates a distributed sampler with 10 shards in total. This shard is shard 5.
  193. >>> sampler = ds.DistributedSampler(10, 5)
  194. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  195. Raises:
  196. ValueError: If num_shards is not positive.
  197. ValueError: If shard_id is smaller than 0 or equal to num_shards or larger than num_shards.
  198. ValueError: If shuffle is not a boolean value.
  199. """
  200. def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None, offset=-1):
  201. if num_shards <= 0:
  202. raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards))
  203. if shard_id < 0 or shard_id >= num_shards:
  204. raise ValueError("shard_id is invalid, shard_id={}".format(shard_id))
  205. if not isinstance(shuffle, bool):
  206. raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
  207. if num_samples is not None:
  208. if num_samples <= 0:
  209. raise ValueError("num_samples should be a positive integer "
  210. "value, but got num_samples={}".format(num_samples))
  211. self.num_shards = num_shards
  212. self.shard_id = shard_id
  213. self.shuffle = shuffle
  214. self.seed = 0
  215. self.offset = offset
  216. super().__init__(num_samples)
  217. def create(self):
  218. num_samples = self.num_samples if self.num_samples is not None else 0
  219. # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle
  220. self.seed += 1
  221. c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id,
  222. self.shuffle, self.seed, self.offset)
  223. c_child_sampler = self.create_child()
  224. c_sampler.add_child(c_child_sampler)
  225. return c_sampler
  226. def create_for_minddataset(self):
  227. num_samples = self.num_samples if self.num_samples is not None else 0
  228. c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
  229. self.seed, num_samples, self.offset)
  230. c_child_sampler = self.create_child_for_minddataset()
  231. c_sampler.add_child(c_child_sampler)
  232. return c_sampler
  233. def is_shuffled(self):
  234. if self.child_sampler is None:
  235. return self.shuffle
  236. return self.child_sampler.is_shuffled()
  237. def is_sharded(self):
  238. if self.child_sampler is None:
  239. return self.num_shards > 1
  240. return self.child_sampler.is_sharded()
  241. def set_offset(self, offset):
  242. self.offset = offset
  243. return self
  244. class PKSampler(BuiltinSampler):
  245. """
  246. Samples K elements for each P class in the dataset.
  247. Args:
  248. num_val (int): Number of elements to sample for each class.
  249. num_class (int, optional): Number of classes to sample (default=None, all classes).
  250. shuffle (bool, optional): If True, the class IDs are shuffled (default=False).
  251. class_column (str, optional): Name of column with class labels for MindDataset (default='label').
  252. num_samples (int, optional): The number of samples to draw (default=None, all elements).
  253. Examples:
  254. >>> import mindspore.dataset as ds
  255. >>>
  256. >>> dataset_dir = "path/to/imagefolder_directory"
  257. >>>
  258. >>> # creates a PKSampler that will get 3 samples from every class.
  259. >>> sampler = ds.PKSampler(3)
  260. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  261. Raises:
  262. ValueError: If num_val is not positive.
  263. NotImplementedError: If num_class is not None.
  264. ValueError: If shuffle is not boolean.
  265. """
  266. def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None):
  267. if num_val <= 0:
  268. raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
  269. if num_class is not None:
  270. raise NotImplementedError("Not support specify num_class")
  271. if not isinstance(shuffle, bool):
  272. raise ValueError("shuffle should be a boolean value, but got shuffle={}".format(shuffle))
  273. if num_samples is not None:
  274. if num_samples <= 0:
  275. raise ValueError("num_samples should be a positive integer "
  276. "value, but got num_samples={}".format(num_samples))
  277. self.num_val = num_val
  278. self.shuffle = shuffle
  279. self.class_column = class_column # work for minddataset
  280. super().__init__(num_samples)
  281. def create(self):
  282. num_samples = self.num_samples if self.num_samples is not None else 0
  283. c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle)
  284. c_child_sampler = self.create_child()
  285. c_sampler.add_child(c_child_sampler)
  286. return c_sampler
  287. def is_shuffled(self):
  288. if self.child_sampler is None:
  289. return self.shuffle
  290. return self.child_sampler.is_shuffled()
  291. def is_sharded(self):
  292. if self.child_sampler is None:
  293. return False
  294. return self.child_sampler.is_sharded()
  295. def create_for_minddataset(self):
  296. if not self.class_column or not isinstance(self.class_column, str):
  297. raise ValueError("class_column should be a not empty string value, \
  298. but got class_column={}".format(class_column))
  299. num_samples = self.num_samples if self.num_samples is not None else 0
  300. c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
  301. c_child_sampler = self.create_child_for_minddataset()
  302. c_sampler.add_child(c_child_sampler)
  303. return c_sampler
  304. class RandomSampler(BuiltinSampler):
  305. """
  306. Samples the elements randomly.
  307. Args:
  308. replacement (bool, optional): If True, put the sample ID back for the next draw (default=False).
  309. num_samples (int, optional): Number of elements to sample (default=None, all elements).
  310. Examples:
  311. >>> import mindspore.dataset as ds
  312. >>>
  313. >>> dataset_dir = "path/to/imagefolder_directory"
  314. >>>
  315. >>> # creates a RandomSampler
  316. >>> sampler = ds.RandomSampler()
  317. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  318. Raises:
  319. ValueError: If replacement is not boolean.
  320. ValueError: If num_samples is not positive.
  321. """
  322. def __init__(self, replacement=False, num_samples=None):
  323. if not isinstance(replacement, bool):
  324. raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
  325. if num_samples is not None:
  326. if num_samples <= 0:
  327. raise ValueError("num_samples should be a positive integer "
  328. "value, but got num_samples={}".format(num_samples))
  329. self.deterministic = False
  330. self.replacement = replacement
  331. self.reshuffle_each_epoch = True
  332. super().__init__(num_samples)
  333. def create(self):
  334. num_samples = self.num_samples if self.num_samples is not None else 0
  335. c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
  336. c_child_sampler = self.create_child()
  337. c_sampler.add_child(c_child_sampler)
  338. return c_sampler
  339. def create_for_minddataset(self):
  340. num_samples = self.num_samples if self.num_samples is not None else 0
  341. c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
  342. c_child_sampler = self.create_child_for_minddataset()
  343. c_sampler.add_child(c_child_sampler)
  344. return c_sampler
  345. def is_shuffled(self):
  346. return True
  347. def is_sharded(self):
  348. if self.child_sampler is None:
  349. return False
  350. return self.child_sampler.is_sharded()
  351. class SequentialSampler(BuiltinSampler):
  352. """
  353. Samples the dataset elements sequentially, same as not having a sampler.
  354. Args:
  355. start_index (int, optional): Index to start sampling at. (dafault=None, start at first ID)
  356. num_samples (int, optional): Number of elements to sample (default=None, all elements).
  357. Examples:
  358. >>> import mindspore.dataset as ds
  359. >>>
  360. >>> dataset_dir = "path/to/imagefolder_directory"
  361. >>>
  362. >>> # creates a SequentialSampler
  363. >>> sampler = ds.SequentialSampler()
  364. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  365. """
  366. def __init__(self, start_index=None, num_samples=None):
  367. if num_samples is not None:
  368. if num_samples <= 0:
  369. raise ValueError("num_samples should be a positive integer "
  370. "value, but got num_samples={}".format(num_samples))
  371. if start_index is not None:
  372. if start_index < 0:
  373. raise ValueError("start_index should be a positive integer "
  374. "value or 0, but got start_index={}".format(start_index))
  375. self.start_index = start_index
  376. super().__init__(num_samples)
  377. def create(self):
  378. start_index = self.start_index if self.start_index is not None else 0
  379. num_samples = self.num_samples if self.num_samples is not None else 0
  380. c_sampler = cde.SequentialSampler(num_samples, start_index)
  381. c_child_sampler = self.create_child()
  382. c_sampler.add_child(c_child_sampler)
  383. return c_sampler
  384. def create_for_minddataset(self):
  385. start_index = self.start_index if self.start_index is not None else 0
  386. num_samples = self.num_samples if self.num_samples is not None else 0
  387. c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
  388. c_child_sampler = self.create_child_for_minddataset()
  389. c_sampler.add_child(c_child_sampler)
  390. return c_sampler
  391. def is_shuffled(self):
  392. if self.child_sampler is None:
  393. return False
  394. return self.child_sampler.is_shuffled()
  395. def is_sharded(self):
  396. if self.child_sampler is None:
  397. return False
  398. return self.child_sampler.is_sharded()
  399. class SubsetRandomSampler(BuiltinSampler):
  400. """
  401. Samples the elements randomly from a sequence of indices.
  402. Args:
  403. indices (list[int]): A sequence of indices.
  404. num_samples (int, optional): Number of elements to sample (default=None, all elements).
  405. Examples:
  406. >>> import mindspore.dataset as ds
  407. >>>
  408. >>> dataset_dir = "path/to/imagefolder_directory"
  409. >>>
  410. >>> indices = [0, 1, 2, 3, 7, 88, 119]
  411. >>>
  412. >>> # creates a SubsetRandomSampler, will sample from the provided indices
  413. >>> sampler = ds.SubsetRandomSampler()
  414. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  415. """
  416. def __init__(self, indices, num_samples=None):
  417. if num_samples is not None:
  418. if num_samples <= 0:
  419. raise ValueError("num_samples should be a positive integer "
  420. "value, but got num_samples={}".format(num_samples))
  421. if not isinstance(indices, list):
  422. indices = [indices]
  423. self.indices = indices
  424. super().__init__(num_samples)
  425. def create(self):
  426. num_samples = self.num_samples if self.num_samples is not None else 0
  427. c_sampler = cde.SubsetRandomSampler(num_samples, self.indices)
  428. c_child_sampler = self.create_child()
  429. c_sampler.add_child(c_child_sampler)
  430. return c_sampler
  431. def is_shuffled(self):
  432. return True
  433. def is_sharded(self):
  434. if self.child_sampler is None:
  435. return False
  436. return self.child_sampler.is_sharded()
  437. def create_for_minddataset(self):
  438. c_sampler = cde.MindrecordSubsetRandomSampler(self.indices, ds.config.get_seed())
  439. c_child_sampler = self.create_child_for_minddataset()
  440. c_sampler.add_child(c_child_sampler)
  441. return c_sampler
  442. def get_num_samples(self):
  443. num_samples = super().get_num_samples()
  444. if num_samples is None:
  445. return len(self.indices)
  446. return min(len(self.indices), num_samples)
  447. class WeightedRandomSampler(BuiltinSampler):
  448. """
  449. Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
  450. Args:
  451. weights (list[float]): A sequence of weights, not necessarily summing up to 1.
  452. num_samples (int, optional): Number of elements to sample (default=None, all elements).
  453. replacement (bool): If True, put the sample ID back for the next draw (default=True).
  454. Examples:
  455. >>> import mindspore.dataset as ds
  456. >>>
  457. >>> dataset_dir = "path/to/imagefolder_directory"
  458. >>>
  459. >>> weights = [0.9, 0.01, 0.4, 0.8, 0.1, 0.1, 0.3]
  460. >>>
  461. >>> # creates a WeightedRandomSampler that will sample 4 elements without replacement
  462. >>> sampler = ds.WeightedRandomSampler(weights, 4)
  463. >>> data = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=8, sampler=sampler)
  464. Raises:
  465. ValueError: If num_samples is not positive.
  466. ValueError: If replacement is not boolean.
  467. """
  468. def __init__(self, weights, num_samples=None, replacement=True):
  469. if not isinstance(weights, list):
  470. weights = [weights]
  471. if weights == []:
  472. raise ValueError("weights size should not be 0")
  473. if list(filter(lambda x: x < 0, weights)):
  474. raise ValueError("weights should not contain negative numbers")
  475. if list(filter(lambda x: x == 0, weights)) == weights:
  476. raise ValueError("elements of weights should not be all zero")
  477. if num_samples is not None:
  478. if num_samples <= 0:
  479. raise ValueError("num_samples should be a positive integer "
  480. "value, but got num_samples={}".format(num_samples))
  481. if not isinstance(replacement, bool):
  482. raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement))
  483. self.weights = weights
  484. self.replacement = replacement
  485. super().__init__(num_samples)
  486. def create(self):
  487. num_samples = self.num_samples if self.num_samples is not None else 0
  488. c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement)
  489. c_child_sampler = self.create_child()
  490. c_sampler.add_child(c_child_sampler)
  491. return c_sampler
  492. def is_shuffled(self):
  493. return True
  494. def is_sharded(self):
  495. if self.child_sampler is None:
  496. return False
  497. return self.child_sampler.is_sharded()