mindspore.dataset.DistributedSampler ==================================== .. py:class:: mindspore.dataset.DistributedSampler(num_shards, shard_id, shuffle=True, num_samples=None, offset=-1) 分布式采样器,将数据集进行分片用于分布式训练。 **参数:** - **num_shards** (int) - 数据集分片数量。 - **shard_id** (int) - 当前分片的分片ID,应在[0, num_shards-1]范围内。 - **shuffle** (bool, optional) - 如果为True,则索引将被打乱(默认为True)。 - **num_samples** (int, optional) - 要采样的样本数(默认为None,对所有元素进行采样)。 - **offset** (int, optional) - 将数据集中的元素发送到的起始分片ID,不应超过 `num_shards` 。仅当ConcatDataset以DistributedSampler为采样器时,此参数才有效。此参数影响每个分片的样本数(默认为-1,每个分片具有相同的样本数)。 **样例:** >>> # 创建一个分布式采样器,共10个分片。当前分片为分片5。 >>> sampler = ds.DistributedSampler(10, 5) >>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, ... num_parallel_workers=8, ... sampler=sampler) **异常:** - **TypeError** - `num_shards` 不是整数值。 - **TypeError** - `shard_id` 不是整数值。 - **TypeError** - `shuffle` 不是Boolean值。 - **TypeError** - `num_samples` 不是整数值。 - **TypeError** - `offset` 不是整数值。 - **ValueError** - `num_samples` 为负值。 - **RuntimeError** - `num_shards` 不是正值。 - **RuntimeError** - `shard_id` 小于0或大于等于 `num_shards` 。 - **RuntimeError** - `offset` 大于 `num_shards` 。 .. include:: mindspore.dataset.BuiltinSampler.rst