|
|
@@ -25,7 +25,6 @@ import mindspore._c_dataengine as cde |
|
|
import mindspore.dataset as ds |
|
|
import mindspore.dataset as ds |
|
|
from ..core import validator_helpers as validator |
|
|
from ..core import validator_helpers as validator |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): |
|
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): |
|
|
""" |
|
|
""" |
|
|
Create sampler based on user input. |
|
|
Create sampler based on user input. |
|
|
@@ -57,8 +56,14 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): |
|
|
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) |
|
|
' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) |
|
|
if isinstance(input_sampler, BuiltinSampler): |
|
|
if isinstance(input_sampler, BuiltinSampler): |
|
|
return input_sampler |
|
|
return input_sampler |
|
|
return SubsetSampler(input_sampler, num_samples) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(input_sampler, str) and isinstance(input_sampler, (np.ndarray, list)): |
|
|
|
|
|
return SubsetSampler(input_sampler, num_samples) |
|
|
|
|
|
if not isinstance(input_sampler, str) and validator.is_iterable(input_sampler): |
|
|
|
|
|
# in this case, the user passed in their own sampler object that's not of type BuiltinSampler |
|
|
|
|
|
return IterSampler(input_sampler, num_samples) |
|
|
|
|
|
if isinstance(input_sampler, int): |
|
|
|
|
|
return SubsetSampler([input_sampler]) |
|
|
|
|
|
raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) |
|
|
if shuffle is None: |
|
|
if shuffle is None: |
|
|
if num_shards is not None: |
|
|
if num_shards is not None: |
|
|
# If shuffle is not specified, sharding enabled, use distributed random sampler |
|
|
# If shuffle is not specified, sharding enabled, use distributed random sampler |
|
|
@@ -621,13 +626,6 @@ class SubsetSampler(BuiltinSampler): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, indices, num_samples=None): |
|
|
def __init__(self, indices, num_samples=None): |
|
|
def _is_iterable(obj): |
|
|
|
|
|
try: |
|
|
|
|
|
iter(obj) |
|
|
|
|
|
except TypeError: |
|
|
|
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def _get_sample_ids_as_list(sampler, number_of_samples=None): |
|
|
def _get_sample_ids_as_list(sampler, number_of_samples=None): |
|
|
if number_of_samples is None: |
|
|
if number_of_samples is None: |
|
|
return list(sampler) |
|
|
return list(sampler) |
|
|
@@ -637,7 +635,7 @@ class SubsetSampler(BuiltinSampler): |
|
|
|
|
|
|
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] |
|
|
return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] |
|
|
|
|
|
|
|
|
if not isinstance(indices, str) and _is_iterable(indices): |
|
|
|
|
|
|
|
|
if not isinstance(indices, str) and validator.is_iterable(indices): |
|
|
indices = _get_sample_ids_as_list(indices, num_samples) |
|
|
indices = _get_sample_ids_as_list(indices, num_samples) |
|
|
elif isinstance(indices, int): |
|
|
elif isinstance(indices, int): |
|
|
indices = [indices] |
|
|
indices = [indices] |
|
|
@@ -731,6 +729,42 @@ class SubsetRandomSampler(SubsetSampler): |
|
|
return c_sampler |
|
|
return c_sampler |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IterSampler(Sampler): |
|
|
|
|
|
""" |
|
|
|
|
|
User provided an iterable object without inheriting from our Sampler class. |
|
|
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
|
This class exists to allow handshake logic between dataset operators and user defined samplers. |
|
|
|
|
|
By constructing this object we avoid the user having to inherit from our Sampler class. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
sampler (iterable object): an user defined iterable object. |
|
|
|
|
|
num_samples (int, optional): Number of elements to sample (default=None, all elements). |
|
|
|
|
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> class MySampler(): |
|
|
|
|
|
>>> def __iter__(self): |
|
|
|
|
|
>>> for i in range(99, -1, -1): |
|
|
|
|
|
>>> yield i |
|
|
|
|
|
|
|
|
|
|
|
>>> # creates an IterSampler |
|
|
|
|
|
>>> sampler = ds.IterSampler(sampler=MySampler()) |
|
|
|
|
|
>>> dataset = ds.ImageFolderDataset(image_folder_dataset_dir, |
|
|
|
|
|
... num_parallel_workers=8, |
|
|
|
|
|
... sampler=sampler) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, sampler, num_samples=None): |
|
|
|
|
|
if num_samples is None: |
|
|
|
|
|
num_samples = len(list(sampler)) |
|
|
|
|
|
super().__init__(num_samples=num_samples) |
|
|
|
|
|
self.sampler = sampler |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
return iter(self.sampler) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WeightedRandomSampler(BuiltinSampler): |
|
|
class WeightedRandomSampler(BuiltinSampler): |
|
|
""" |
|
|
""" |
|
|
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). |
|
|
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). |
|
|
|