|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Yolo dataset distributed sampler."""
- from __future__ import division
- import math
- import numpy as np
-
-
- class DistributedSampler:
- """Distributed sampler."""
- def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
- if num_replicas is None:
- print("***********Setting world_size to 1 since it is not passed in ******************")
- num_replicas = 1
- if rank is None:
- print("***********Setting rank to 0 since it is not passed in ******************")
- rank = 0
- self.dataset_size = dataset_size
- self.num_replicas = num_replicas
- self.rank = rank
- self.epoch = 0
- self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
- self.total_size = self.num_samples * self.num_replicas
- self.shuffle = shuffle
-
- def __iter__(self):
- # deterministically shuffle based on epoch
- if self.shuffle:
- indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
- # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
- indices = indices.tolist()
- self.epoch += 1
- # change to list type
- else:
- indices = list(range(self.dataset_size))
-
- # add extra samples to make it evenly divisible
- indices += indices[:(self.total_size - len(indices))]
- assert len(indices) == self.total_size
-
- # subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
- assert len(indices) == self.num_samples
-
- return iter(indices)
-
- def __len__(self):
- return self.num_samples
|