You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.DistributedSampler.rst 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. mindspore.dataset.DistributedSampler
  2. ====================================
  3. .. py:class:: mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1)
  4. 分布式采样器,将数据集进行分片用于分布式训练。
  5. **参数:**
  6. - **num_shards** (int):数据集分片数量。
  7. - **shard_id** (int):当前分片的分片ID,应在[0, num_shards-1]范围内。
  8. - **shuffle** (bool, optional):如果为True,则索引将被打乱(默认为True)。
  9. - **num_samples** (int, optional):要采样的样本数(默认为None,对所有元素进行采样)。
  10. - **offset** (int, optional):将数据集中的元素发送到的起始分片ID,不应超过 `num_shards` 。仅当ConcatDataset以DistributedSampler为采样器时,此参数才有效。此参数影响每个分片的样本数(默认为-1,每个分片具有相同的样本数)。
  11. **样例:**
  12. >>> # 创建一个分布式采样器,共10个分片。当前分片为分片5。
  13. >>> sampler = ds.DistributedSampler(10, 5)
  14. >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir,
  15. ... num_parallel_workers=8,
  16. ... sampler=sampler)
  17. **异常:**
  18. - **TypeError**:`num_shards` 不是整数值。
  19. - **TypeError**:`shard_id` 不是整数值。
  20. - **TypeError**:`shuffle` 不是Boolean值。
  21. - **TypeError**:`num_samples` 不是整数值。
  22. - **TypeError**:`offset` 不是整数值。
  23. - **ValueError**:`num_samples` 为负值。
  24. - **RuntimeError**:`num_shards` 不是正值。
  25. - **RuntimeError**:`shard_id` 小于0或大于等于 `num_shards` 。
  26. - **RuntimeError**:`offset` 大于 `num_shards` 。
  27. .. include:: mindspore.dataset.BuiltinSampler.rst