# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Yolo dataset distributed sampler."""
from __future__ import division

import math

import numpy as np


class DistributedSampler:
    """Distributed sampler.

    Splits the index range ``[0, dataset_size)`` into ``num_replicas``
    equally sized shards and yields the shard belonging to ``rank``.

    If ``dataset_size`` is not a multiple of ``num_replicas``, the index list
    is padded by wrapping around to its start. Every shard therefore has
    exactly ``ceil(dataset_size / num_replicas)`` entries, and some indices
    appear more than once across shards.

    Args:
        dataset_size: Number of samples in the dataset; must be >= 0.
        num_replicas: Number of shards. Defaults to 1 (with a printed notice)
            when ``None``. Must be >= 1.
        rank: Index of the shard this sampler yields. Defaults to 0 (with a
            printed notice) when ``None``. Must satisfy
            ``0 <= rank < num_replicas``.
        shuffle: If True, indices are permuted with
            ``np.random.RandomState(seed=self.epoch)``, and ``self.epoch`` is
            incremented after every iteration.

    Raises:
        ValueError: If ``dataset_size``, ``num_replicas`` or ``rank`` is out
            of range.
    """

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        # Fail fast here instead of with a ZeroDivisionError or a confusing
        # assertion failure later, during iteration.
        if dataset_size < 0:
            raise ValueError("dataset_size must be non-negative, got {}".format(dataset_size))
        if num_replicas < 1:
            raise ValueError("num_replicas must be at least 1, got {}".format(num_replicas))
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, {}), got {}".format(num_replicas, rank))
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-shard length: every shard gets the same number of samples.
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        # Length of the padded index list, before it is split into shards.
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over this rank's shard of dataset indices.

        When shuffling, the permutation is seeded only by ``self.epoch``. All
        ranks therefore produce the same permutation, and so disjoint shards,
        as long as their samplers have been iterated the same number of times.
        """
        if self.shuffle:
            # Deterministic shuffle: the seed depends only on the epoch.
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # Convert the numpy array of indices in [0, dataset_size) to a list of Python ints.
            indices = indices.tolist()
            # Advance the epoch so the next iteration uses a different order.
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # Pad by wrapping around so the list splits evenly across replicas.
        # The padding can be longer than the list itself when
        # dataset_size is small relative to num_replicas. In that case the
        # list is repeated as many times as needed.
        padding = self.total_size - len(indices)
        if padding > 0:
            # Because padding > 0, total_size > 0 and dataset_size >= 1, so
            # indices is non-empty here.
            repeats = -(-padding // len(indices))  # ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # Strided subsample: rank r takes positions r, r + num_replicas, ...
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices yielded per iteration for this rank."""
        return self.num_samples