Browse Source

!11854 Support list of IDs as a sampler

From: @hfarahat
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
112b5829e7
4 changed files with 105 additions and 162 deletions
  1. +1
    -4
      mindspore/dataset/core/validator_helpers.py
  2. +17
    -158
      mindspore/dataset/engine/datasets.py
  3. +76
    -0
      mindspore/dataset/engine/samplers.py
  4. +11
    -0
      tests/ut/python/dataset/test_sampler.py

+ 1
- 4
mindspore/dataset/core/validator_helpers.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@ import os
import numpy as np import numpy as np


import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from ..engine import samplers


# POS_INT_MIN is used to limit values from starting from 0 # POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1 POS_INT_MIN = 1
@@ -290,8 +289,6 @@ def check_sampler_shuffle_shard_options(param_dict):
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
num_samples = param_dict.get('num_samples') num_samples = param_dict.get('num_samples')


type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler")

if sampler is not None: if sampler is not None:
if shuffle is not None: if shuffle is not None:
raise RuntimeError("sampler and shuffle cannot be specified at the same time.") raise RuntimeError("sampler and shuffle cannot be specified at the same time.")


+ 17
- 158
mindspore/dataset/engine/datasets.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -2708,7 +2708,7 @@ class ConcatDataset(Dataset):


self.dataset_size = None self.dataset_size = None


self._sampler = _select_sampler(None, sampler, None, None, None)
self._sampler = samplers.select_sampler(None, sampler, None, None, None)
cumulative_samples_nums = 0 cumulative_samples_nums = 0
for index, child in enumerate(self.children): for index, child in enumerate(self.children):
if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
@@ -2990,65 +2990,6 @@ class RangeDataset(MappableDataset):
return self.dataset_size return self.dataset_size




def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
"""
Create sampler based on user input.

Args:
num_samples (int): Number of samples.
input_sampler (Union[Iterable, Sampler]): Sampler from user.
shuffle (bool): Shuffle.
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.
non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).

Returns:
Sampler, sampler selected based on user input.
"""
if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
return None

if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
# That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
# be None. Consider this example:
# sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
# data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
# In this case, the user has given different sample-related arguments that contradict each other.
# To prevent this, only allow the user to manually specify the sampler if those arguments are all None
if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
samplers.RandomSampler, samplers.SubsetRandomSampler,
samplers.WeightedRandomSampler, samplers.Sampler)) and
(any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
raise ValueError(
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
return input_sampler
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
shuffle = True
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle is not specified, sharding disabled, use random sampler
if num_samples is not None:
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
return samplers.RandomSampler(num_samples=num_samples)
if shuffle is True:
if num_shards is not None:
# If shuffle enabled, sharding enabled, use distributed random sampler
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle enabled, sharding disabled, use random sampler
if num_samples is not None:
return samplers.RandomSampler(replacement=True, num_samples=num_samples)
return samplers.RandomSampler(num_samples=num_samples)
if num_shards is not None:
# If shuffle disabled, sharding enabled, use distributed sequential sampler
return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle disabled, sharding disabled, use sequential sampler
return samplers.SequentialSampler(num_samples=num_samples)


class ImageFolderDataset(MappableDataset): class ImageFolderDataset(MappableDataset):
""" """
A source dataset that reads images from a tree of directories. A source dataset that reads images from a tree of directories.
@@ -3144,7 +3085,7 @@ class ImageFolderDataset(MappableDataset):
super().__init__(num_parallel_workers=num_parallel_workers) super().__init__(num_parallel_workers=num_parallel_workers)


self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.shuffle_level = shuffle self.shuffle_level = shuffle
self.extensions = replace_none(extensions, []) self.extensions = replace_none(extensions, [])
@@ -3293,7 +3234,7 @@ class MnistDataset(MappableDataset):


self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all") self.usage = replace_none(usage, "all")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.shuffle_level = shuffle self.shuffle_level = shuffle
self.num_shards = num_shards self.num_shards = num_shards
@@ -3386,7 +3327,7 @@ class MindDataset(MappableDataset):
samplers.SequentialSampler)) is False: samplers.SequentialSampler)) is False:
raise ValueError("The sampler is not supported yet.") raise ValueError("The sampler is not supported yet.")


self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples


self.padded_sample = padded_sample self.padded_sample = padded_sample
@@ -3470,27 +3411,6 @@ def _generator_fn(generator, num_samples):
yield val yield val




def _py_sampler_fn(sampler, num_samples, dataset):
"""
Generator function wrapper for mappable dataset with Python sampler.
"""
if num_samples is not None:
sampler_iter = iter(sampler)
for _ in range(num_samples):
try:
idx = next(sampler_iter)
except StopIteration:
return
val = dataset[idx]
# convert output tensors to ndarrays
yield tuple([np.array(x, copy=False) for x in val])
else:
for i in sampler:
val = dataset[i]
# convert output tensors to ndarrays
yield tuple([np.array(x, copy=False) for x in val])


def _cpp_sampler_fn(sample_ids, dataset): def _cpp_sampler_fn(sample_ids, dataset):
""" """
Generator function wrapper for mappable dataset with cpp sampler. Generator function wrapper for mappable dataset with cpp sampler.
@@ -3518,31 +3438,6 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
return sample_fn.process(sample_ids) return sample_fn.process(sample_ids)




def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
"""
Multiprocessing generator function wrapper for mappable dataset with Python sampler.
"""
indices = _fetch_py_sampler_indices(sampler, num_samples)
return sample_fn.process(indices)


def _fetch_py_sampler_indices(sampler, num_samples):
"""
Indice fetcher for Python sampler.
"""
if num_samples is not None:
sampler_iter = iter(sampler)
ret = []
for _ in range(num_samples):
try:
val = next(sampler_iter)
ret.append(val)
except StopIteration:
break
return ret
return [i for i in sampler]


def _fill_worker_indices(workers, indices, idx): def _fill_worker_indices(workers, indices, idx):
""" """
Worker index queue filler, fill worker index queue in round robin order. Worker index queue filler, fill worker index queue in round robin order.
@@ -3865,7 +3760,7 @@ class GeneratorDataset(MappableDataset):
python_multiprocessing=True): python_multiprocessing=True):
super().__init__(num_parallel_workers=num_parallel_workers) super().__init__(num_parallel_workers=num_parallel_workers)
self.source = source self.source = source
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.num_shards = num_shards self.num_shards = num_shards
self.python_multiprocessing = python_multiprocessing self.python_multiprocessing = python_multiprocessing
@@ -3912,26 +3807,11 @@ class GeneratorDataset(MappableDataset):
if hasattr(self, "__total_batch__"): if hasattr(self, "__total_batch__"):
new_op.__total_batch__ = self.__total_batch__ new_op.__total_batch__ = self.__total_batch__
if new_op.sampler is not None and hasattr(self.source, "__getitem__"): if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
if isinstance(new_op.sampler, samplers.BuiltinSampler):
if new_op.num_parallel_workers > 1:
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else:
new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
if new_op.num_parallel_workers > 1:
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else: else:
# the sampler provided is not a built-in sampler, it is a list of sample_ids
new_op.sample_ids = new_op.sampler
# since list of sample_ids are not passed to c++, we need to find the proper len here
new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len(
new_op.sample_ids)
new_op.source_len = min(self.source_len,
new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
new_op.sampler = None
if new_op.num_parallel_workers > 1:
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
else:
new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
new_op.sample_fn = sample_fn new_op.sample_fn = sample_fn
else: else:
try: try:
@@ -4089,13 +3969,6 @@ class TFRecordDataset(SourceDataset):
self.shuffle_level = shuffle self.shuffle_level = shuffle
self.shuffle_files = True self.shuffle_files = True


# The TF record dataset does not directly support a sampler. It has provided sampling arguments
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
# the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)
self.shard_equal_rows = replace_none(shard_equal_rows, False) self.shard_equal_rows = replace_none(shard_equal_rows, False)


def get_args(self): def get_args(self):
@@ -4231,7 +4104,7 @@ class ManifestDataset(MappableDataset):
super().__init__(num_parallel_workers=num_parallel_workers) super().__init__(num_parallel_workers=num_parallel_workers)


self.dataset_file = dataset_file self.dataset_file = dataset_file
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)


if class_indexing is not None and not isinstance(class_indexing, dict): if class_indexing is not None and not isinstance(class_indexing, dict):
raise RuntimeError("class_indexing must be a dictionary.") raise RuntimeError("class_indexing must be a dictionary.")
@@ -4396,7 +4269,7 @@ class Cifar10Dataset(MappableDataset):


self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all") self.usage = replace_none(usage, "all")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
@@ -4535,7 +4408,7 @@ class Cifar100Dataset(MappableDataset):


self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.usage = replace_none(usage, "all") self.usage = replace_none(usage, "all")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
@@ -4607,8 +4480,6 @@ class RandomDataset(SourceDataset):
super().__init__(num_parallel_workers=num_parallel_workers) super().__init__(num_parallel_workers=num_parallel_workers)
self.schema = schema self.schema = schema
self.columns_list = replace_none(columns_list, []) self.columns_list = replace_none(columns_list, [])
sampler = None
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)


self.num_samples = num_samples self.num_samples = num_samples
self.total_rows = total_rows self.total_rows = total_rows
@@ -4900,7 +4771,7 @@ class VOCDataset(MappableDataset):
self.task = replace_none(task, "Segmentation") self.task = replace_none(task, "Segmentation")
self.usage = replace_none(usage, "train") self.usage = replace_none(usage, "train")
self.class_indexing = class_indexing self.class_indexing = class_indexing
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.decode = replace_none(decode, False) self.decode = replace_none(decode, False)
self.shuffle_level = shuffle self.shuffle_level = shuffle
@@ -5092,7 +4963,7 @@ class CocoDataset(MappableDataset):
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.annotation_file = annotation_file self.annotation_file = annotation_file
self.task = replace_none(task, "Detection") self.task = replace_none(task, "Detection")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.decode = replace_none(decode, False) self.decode = replace_none(decode, False)
self.shuffle_level = shuffle self.shuffle_level = shuffle
@@ -5224,7 +5095,7 @@ class CelebADataset(MappableDataset):
extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None): extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers) super().__init__(num_parallel_workers=num_parallel_workers)
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_parallel_workers = num_parallel_workers self.num_parallel_workers = num_parallel_workers
self.decode = replace_none(decode, False) self.decode = replace_none(decode, False)
self.extensions = replace_none(extensions, []) self.extensions = replace_none(extensions, [])
@@ -5596,12 +5467,7 @@ class CSVDataset(SourceDataset):
self.shuffle_files = True self.shuffle_files = True


self.cache = cache self.cache = cache
# The CSV dataset does not directly support a sampler. It has provided sampling arguments
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
# the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
sampler = None
self.sampler = _select_sampler(num_samples, sampler, self.shuffle_files, num_shards, shard_id,
non_mappable=True)

self.num_shards = replace_none(num_shards, 1) self.num_shards = replace_none(num_shards, 1)
self.shard_id = replace_none(shard_id, 0) self.shard_id = replace_none(shard_id, 0)
self.num_samples = replace_none(num_samples, 0) self.num_samples = replace_none(num_samples, 0)
@@ -5715,13 +5581,6 @@ class TextFileDataset(SourceDataset):
self.shard_id = replace_none(shard_id, 0) self.shard_id = replace_none(shard_id, 0)


self.cache = cache self.cache = cache
# The text file dataset does not directly support a sampler. It has provided sampling arguments
# (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
# the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
sampler_shuffle = self.shuffle_files
sampler = None
self.sampler = _select_sampler(num_samples, sampler, sampler_shuffle, num_shards, shard_id,
non_mappable=True)


def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()


+ 76
- 0
mindspore/dataset/engine/samplers.py View File

@@ -25,6 +25,82 @@ import mindspore._c_dataengine as cde
import mindspore.dataset as ds import mindspore.dataset as ds




def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
"""
Create sampler based on user input.

Args:
num_samples (int): Number of samples.
input_sampler (Union[Iterable, Sampler]): Sampler from user.
shuffle (bool): Shuffle.
num_shards (int): Number of shard for sharding.
shard_id (int): Shard ID.

Returns:
Sampler, sampler selected based on user input.
"""

def _is_iterable(obj):
try:
iter(obj)
except TypeError:
return False
return True

def _get_sample_ids_as_list(sampler, number_of_samples=None):
if number_of_samples is None:
return list(sampler)

if isinstance(sampler, list):
return sampler[:number_of_samples]

return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))]

if input_sampler is not None:
# If the user provided a sampler, then it doesn't matter what the other args are because
# we are being asked specifically to use the given sampler.
# That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
# be None. Consider this example:
# sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
# data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
# In this case, the user has given different sample-related arguments that contradict each other.
# To prevent this, only allow the user to manually specify the sampler if those arguments are all None
if (isinstance(input_sampler, BuiltinSampler) and
(any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
raise ValueError(
'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
if isinstance(input_sampler, BuiltinSampler):
return input_sampler
if _is_iterable(input_sampler):
return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples))
if isinstance(input_sampler, int):
return [input_sampler]
raise ValueError('Unsupported sampler object ({})'.format(input_sampler))
if shuffle is None:
if num_shards is not None:
# If shuffle is not specified, sharding enabled, use distributed random sampler
shuffle = True
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle is not specified, sharding disabled, use random sampler
if num_samples is not None:
return RandomSampler(replacement=True, num_samples=num_samples)
return RandomSampler(num_samples=num_samples)
if shuffle is True:
if num_shards is not None:
# If shuffle enabled, sharding enabled, use distributed random sampler
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle enabled, sharding disabled, use random sampler
if num_samples is not None:
return RandomSampler(replacement=True, num_samples=num_samples)
return RandomSampler(num_samples=num_samples)
if num_shards is not None:
# If shuffle disabled, sharding enabled, use distributed sequential sampler
return DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
# If shuffle disabled, sharding disabled, use sequential sampler
return SequentialSampler(num_samples=num_samples)


class BuiltinSampler: class BuiltinSampler:
""" """
Base class for BuiltinSampler. Base class for BuiltinSampler.


+ 11
- 0
tests/ut/python/dataset/test_sampler.py View File

@@ -17,6 +17,7 @@ import pytest


import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from util import dataset_equal




# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
@@ -265,6 +266,15 @@ def test_distributed_sampler_invalid_offset():
assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value) assert "DistributedSampler: invalid offset: 5, which should be no more than num_shards: 4" in str(info.value)




def test_sampler_list():
data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=[1, 3, 5])
data21 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(2).skip(1)
data22 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(4).skip(3)
data23 = ds.ImageFolderDataset("../data/dataset/testPK/data", shuffle=False).take(6).skip(5)

dataset_equal(data1, data21 + data22 + data23, 0)


if __name__ == '__main__': if __name__ == '__main__':
test_sequential_sampler(True) test_sequential_sampler(True)
test_random_sampler(True) test_random_sampler(True)
@@ -276,3 +286,4 @@ if __name__ == '__main__':
test_sampler_chain() test_sampler_chain()
test_add_sampler_invalid_input() test_add_sampler_invalid_input()
test_distributed_sampler_invalid_offset() test_distributed_sampler_invalid_offset()
test_sampler_list()

Loading…
Cancel
Save