From: @hfarahat Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -16,13 +16,11 @@ | |||||
| General Validators. | General Validators. | ||||
| """ | """ | ||||
| import inspect | import inspect | ||||
| import numbers | |||||
| from multiprocessing import cpu_count | from multiprocessing import cpu_count | ||||
| import os | import os | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from ..engine import samplers | |||||
| # POS_INT_MIN is used to limit values from starting from 0 | # POS_INT_MIN is used to limit values from starting from 0 | ||||
| POS_INT_MIN = 1 | POS_INT_MIN = 1 | ||||
| @@ -389,38 +387,5 @@ def check_tensor_op(param, param_name): | |||||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | ||||
| def check_sampler(sampler): | |||||
| """ | |||||
| Check if the sampler is of valid input. | |||||
| Args: | |||||
| param(Union[list, samplers.Sampler, samplers.BuiltinSampler, None]): sampler | |||||
| Returns: | |||||
| Exception: TypeError if error | |||||
| """ | |||||
| builtin = False | |||||
| base_sampler = False | |||||
| list_num = False | |||||
| if sampler is not None: | |||||
| if isinstance(sampler, samplers.BuiltinSampler): | |||||
| builtin = True | |||||
| elif isinstance(sampler, samplers.Sampler): | |||||
| base_sampler = True | |||||
| else: | |||||
| # check for list of numbers | |||||
| list_num = True | |||||
| # subset sampler check | |||||
| subset_sampler = sampler | |||||
| if not isinstance(sampler, list): | |||||
| subset_sampler = [sampler] | |||||
| for _, item in enumerate(subset_sampler): | |||||
| if not isinstance(item, numbers.Number): | |||||
| list_num = False | |||||
| if not (builtin or base_sampler or list_num): | |||||
| raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers") | |||||
| def replace_none(value, default): | def replace_none(value, default): | ||||
| return value if value is not None else default | return value if value is not None else default | ||||
| @@ -41,22 +41,6 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||||
| Sampler, sampler selected based on user input. | Sampler, sampler selected based on user input. | ||||
| """ | """ | ||||
| def _is_iterable(obj): | |||||
| try: | |||||
| iter(obj) | |||||
| except TypeError: | |||||
| return False | |||||
| return True | |||||
| def _get_sample_ids_as_list(sampler, number_of_samples=None): | |||||
| if number_of_samples is None: | |||||
| return list(sampler) | |||||
| if isinstance(sampler, list): | |||||
| return sampler[:number_of_samples] | |||||
| return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] | |||||
| if input_sampler is not None: | if input_sampler is not None: | ||||
| # If the user provided a sampler, then it doesn't matter what the other args are because | # If the user provided a sampler, then it doesn't matter what the other args are because | ||||
| # we are being asked specifically to use the given sampler. | # we are being asked specifically to use the given sampler. | ||||
| @@ -73,11 +57,8 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||||
| ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) | ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) | ||||
| if isinstance(input_sampler, BuiltinSampler): | if isinstance(input_sampler, BuiltinSampler): | ||||
| return input_sampler | return input_sampler | ||||
| if not isinstance(input_sampler, str) and _is_iterable(input_sampler): | |||||
| return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) | |||||
| if isinstance(input_sampler, int): | |||||
| return SubsetSampler([input_sampler]) | |||||
| raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) | |||||
| return SubsetSampler(input_sampler, num_samples) | |||||
| if shuffle is None: | if shuffle is None: | ||||
| if num_shards is not None: | if num_shards is not None: | ||||
| # If shuffle is not specified, sharding enabled, use distributed random sampler | # If shuffle is not specified, sharding enabled, use distributed random sampler | ||||
| @@ -640,11 +621,31 @@ class SubsetSampler(BuiltinSampler): | |||||
| """ | """ | ||||
| def __init__(self, indices, num_samples=None): | def __init__(self, indices, num_samples=None): | ||||
| if not isinstance(indices, list): | |||||
| def _is_iterable(obj): | |||||
| try: | |||||
| iter(obj) | |||||
| except TypeError: | |||||
| return False | |||||
| return True | |||||
| def _get_sample_ids_as_list(sampler, number_of_samples=None): | |||||
| if number_of_samples is None: | |||||
| return list(sampler) | |||||
| if isinstance(sampler, list): | |||||
| return sampler[:number_of_samples] | |||||
| return [sample_id for sample_id, _ in zip(sampler, range(number_of_samples))] | |||||
| if not isinstance(indices, str) and _is_iterable(indices): | |||||
| indices = _get_sample_ids_as_list(indices, num_samples) | |||||
| elif isinstance(indices, int): | |||||
| indices = [indices] | indices = [indices] | ||||
| else: | |||||
| raise TypeError('Unsupported sampler object of type ({})'.format(type(indices))) | |||||
| for i, item in enumerate(indices): | for i, item in enumerate(indices): | ||||
| if not isinstance(item, int): | |||||
| if not isinstance(item, (int, np.integer)): | |||||
| raise TypeError("SubsetSampler: Type of indices element must be int, " | raise TypeError("SubsetSampler: Type of indices element must be int, " | ||||
| "but got list[{}]: {}, type: {}.".format(i, item, type(item))) | "but got list[{}]: {}, type: {}.".format(i, item, type(item))) | ||||
| @@ -177,13 +177,23 @@ def test_subset_sampler(): | |||||
| def pipeline(): | def pipeline(): | ||||
| sampler = ds.SubsetSampler(indices, num_samples) | sampler = ds.SubsetSampler(indices, num_samples) | ||||
| data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) | data = ds.NumpySlicesDataset(list(range(0, 10)), sampler=sampler) | ||||
| data2 = ds.NumpySlicesDataset(list(range(0, 10)), sampler=indices, num_samples=num_samples) | |||||
| dataset_size = data.get_dataset_size() | dataset_size = data.get_dataset_size() | ||||
| return [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size | |||||
| dataset_size2 = data.get_dataset_size() | |||||
| res1 = [d[0] for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size | |||||
| res2 = [d[0] for d in data2.create_tuple_iterator(num_epochs=1, output_numpy=True)], dataset_size2 | |||||
| return res1, res2 | |||||
| if exception_msg is None: | if exception_msg is None: | ||||
| res, size = pipeline() | |||||
| res, res2 = pipeline() | |||||
| res, size = res | |||||
| res2, size2 = res2 | |||||
| if not isinstance(indices, list): | |||||
| indices = list(indices) | |||||
| assert indices[:num_samples] == res | assert indices[:num_samples] == res | ||||
| assert len(indices[:num_samples]) == size | assert len(indices[:num_samples]) == size | ||||
| assert indices[:num_samples] == res2 | |||||
| assert len(indices[:num_samples]) == size2 | |||||
| else: | else: | ||||
| with pytest.raises(Exception) as error_info: | with pytest.raises(Exception) as error_info: | ||||
| pipeline() | pipeline() | ||||
| @@ -205,6 +215,8 @@ def test_subset_sampler(): | |||||
| test_config([0, 9, 3, 2], num_samples=2) | test_config([0, 9, 3, 2], num_samples=2) | ||||
| test_config([0, 9, 3, 2], num_samples=5) | test_config([0, 9, 3, 2], num_samples=5) | ||||
| test_config(np.array([1, 2, 3])) | |||||
| test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") | test_config([20], exception_msg="Sample ID (20) is out of bound, expected range [0, 9]") | ||||
| test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") | test_config([10], exception_msg="Sample ID (10) is out of bound, expected range [0, 9]") | ||||
| test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") | test_config([0, 9, 0, 500], exception_msg="Sample ID (500) is out of bound, expected range [0, 9]") | ||||
| @@ -212,6 +224,9 @@ def test_subset_sampler(): | |||||
| # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset | # test_config([], exception_msg="Indices list is empty") # temporary until we check with MindDataset | ||||
| test_config([0, 9, 3, 2], num_samples=-1, | test_config([0, 9, 3, 2], num_samples=-1, | ||||
| exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") | exception_msg="num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)") | ||||
| test_config(np.array([[1], [5]]), num_samples=10, | |||||
| exception_msg="SubsetSampler: Type of indices element must be int, but got list[0]: [1]," | |||||
| " type: <class 'numpy.ndarray'>.") | |||||
| def test_sampler_chain(): | def test_sampler_chain(): | ||||
| @@ -291,8 +306,8 @@ def test_sampler_list(): | |||||
| msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") | msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") | ||||
| bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") | bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") | ||||
| bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") | bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") | ||||
| bad_pipeline(sampler=np.array([1, 2]), | |||||
| msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.") | |||||
| bad_pipeline(sampler=np.array([[1, 2]]), | |||||
| msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.") | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||