From: @hfarahat Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -16,13 +16,11 @@ | |||
| General Validators. | |||
| """ | |||
| import inspect | |||
| import numbers | |||
| from multiprocessing import cpu_count | |||
| import os | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| from ..engine import samplers | |||
| # POS_INT_MIN is the lower bound for positive-integer checks, so values cannot start from 0 | |||
| POS_INT_MIN = 1 | |||
| @@ -389,38 +387,5 @@ def check_tensor_op(param, param_name): | |||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | |||
def check_sampler(sampler):
    """
    Validate the type of a user-supplied sampler argument.

    None is accepted without further checks. Any other value must be a samplers.BuiltinSampler,
    a samplers.Sampler, a single number, or a list in which every item is a number
    (an empty list therefore passes).

    Args:
        sampler (Union[list, numbers.Number, samplers.Sampler, samplers.BuiltinSampler, None]): Object to validate.

    Raises:
        TypeError: If sampler is none of the accepted types.
    """
    if sampler is None:
        return

    if isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)):
        return

    # A bare value is handled like a one-element list so that both forms share the same check.
    candidates = sampler if isinstance(sampler, list) else [sampler]
    if not all(isinstance(entry, numbers.Number) for entry in candidates):
        raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers")
def replace_none(value, default):
    """Return default when value is None; any other value (including falsy ones) is returned unchanged."""
    if value is None:
        return default
    return value
| @@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| Sampler, sampler selected based on user input. | |||
| """ | |||
| def _is_iterable(obj): | |||
| try: | |||
| iter(obj) | |||
| except TypeError: | |||
| return False | |||
| return True | |||
| def _get_sample_ids_as_list(sampler, number_of_samples=None): | |||
| if number_of_samples is None: | |||
| return list(sampler) | |||
| if isinstance(sampler, list): | |||
| return sampler[:number_of_samples] | |||
| return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] | |||
| if input_sampler is not None: | |||
| # If the user provided a sampler, then it doesn't matter what the other args are because | |||
| # we are being asked specifically to use the given sampler. | |||
| @@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) | |||
| if isinstance(input_sampler, BuiltinSampler): | |||
| return input_sampler | |||
| if not isinstance(input_sampler, str) and _is_iterable(input_sampler): | |||
| return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) | |||
| if isinstance(input_sampler, int): | |||
| return SubsetSampler([input_sampler]) | |||
| raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) | |||
| return SubsetSampler(input_sampler, num_samples) | |||
| if shuffle is None: | |||
| if num_shards is not None: | |||
| # If shuffle is not specified, sharding enabled, use distributed random sampler | |||
| @@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler): | |||
| """ | |||
| def __init__(self, indices, num_samples=None): | |||
| if not isinstance(indices, list): | |||
| def _is_iterable(obj): | |||
| try: | |||
| iter(obj) | |||
| except TypeError: | |||
| return False | |||
| return True | |||
| def _get_sample_ids_as_list(sampler, number_of_samples=None): | |||
| if number_of_samples is None: | |||
| return list(sampler) | |||
| if isinstance(sampler, list): | |||
| return sampler[:number_of_samples] | |||
| return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] | |||
| if not isinstance(indices, str) and _is_iterable(indices): | |||
| indices = _get_sample_ids_as_list(indices, num_samples) | |||
| elif isinstance(indices, int): | |||
| indices = [indices] | |||
| else: | |||
| raise TypeError('Unsupported sampler object of type ({})'.format(type(indices))) | |||
| for i, item in enumerate(indices): | |||
| if not isinstance(item, int): | |||
| if not isinstance(item, (int, np.integer)): | |||
| raise TypeError("SubsetSampler: Type of indices element must be int, " | |||
| "but got list[{}]: {}, type: {}.".format(i, item, type(item))) | |||
| @@ -177,13 +177,23 @@ def test_subset_sampler(): | |||
| def pipeline(): | |||
| sampler = ds.SubsetSampler(indices, num_samples) | |||
| data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) | |||
| data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples) | |||
| dataset_size = data.get_dataset_size() | |||
| return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size | |||
| dataset_size2 = data.get_dataset_size() | |||
| res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size | |||
| res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2 | |||
| return res1, res2 | |||
| if exception_msg is None: | |||
| res, size = pipeline() | |||
| res, res2 = pipeline() | |||
| res, size = res | |||
| res2, size2 = res2 | |||
| if not isinstance(indices, list): | |||
| indices = list(indices) | |||
| assert indices[:num_samples] == res | |||
| assert len(indices[:num_samples]) == size | |||
| assert indices[:num_samples] == res2 | |||
| assert len(indices[:num_samples]) == size2 | |||
| else: | |||
| with pytest.raises(Exception) as error_info: | |||
| pipeline() | |||
| @@ -205,6 +215,8 @@ def test_subset_sampler(): | |||
| test_config([0, 9, 3, 2], num_samples=2) | |||
| test_config([0, 9, 3, 2], num_samples=5) | |||
| test_config(np.array([1, 2, 3])) | |||
| test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") | |||
| test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") | |||
| test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") | |||
| @@ -212,6 +224,9 @@ def test_subset_sampler(): | |||
| # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset | |||
| test_config([0, 9, 3, 2], num_samples=-1, | |||
| exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") | |||
| test_config(np.array([[1], [5]]), num_samples=10, | |||
| exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1]," | |||
| " type: <class 'numpy.ndarray'>.") | |||
| def test_sampler_chain(): | |||
| @@ -291,8 +306,8 @@ def test_sampler_list(): | |||
| msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") | |||
| bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") | |||
| bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") | |||
| bad_pipeline(sampler=np.array([1, 2]), | |||
| msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.") | |||
| bad_pipeline(sampler=np.array([[1, 2]]), | |||
| msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.") | |||
| if __name__ == '__main__': | |||