| @@ -23,6 +23,7 @@ import numpy as np | |||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from ..engine import samplers | from ..engine import samplers | ||||
| # POS_INT_MIN is used to limit values from starting from 0 | # POS_INT_MIN is used to limit values from starting from 0 | ||||
| POS_INT_MIN = 1 | POS_INT_MIN = 1 | ||||
| UINT8_MAX = 255 | UINT8_MAX = 255 | ||||
| @@ -289,7 +290,6 @@ def check_sampler_shuffle_shard_options(param_dict): | |||||
| shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') | shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') | ||||
| num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') | num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') | ||||
| num_samples = param_dict.get('num_samples') | num_samples = param_dict.get('num_samples') | ||||
| check_sampler(sampler) | |||||
| if sampler is not None: | if sampler is not None: | ||||
| if shuffle is not None: | if shuffle is not None: | ||||
| @@ -348,6 +348,7 @@ def check_num_samples(value): | |||||
| raise ValueError( | raise ValueError( | ||||
| "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX)) | "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX)) | ||||
| def validate_dataset_param_value(param_list, param_dict, param_type): | def validate_dataset_param_value(param_list, param_dict, param_type): | ||||
| for param_name in param_list: | for param_name in param_list: | ||||
| if param_dict.get(param_name) is not None: | if param_dict.get(param_name) is not None: | ||||
| @@ -387,6 +388,7 @@ def check_tensor_op(param, param_name): | |||||
| if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None): | if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None): | ||||
| raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name)) | ||||
| def check_sampler(sampler): | def check_sampler(sampler): | ||||
| """ | """ | ||||
| Check if the sampler is of valid input. | Check if the sampler is of valid input. | ||||
| @@ -419,5 +421,6 @@ def check_sampler(sampler): | |||||
| if not (builtin or base_sampler or list_num): | if not (builtin or base_sampler or list_num): | ||||
| raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers") | raise TypeError("Argument sampler is not of type Sampler, BuiltinSamplers, or list of numbers") | ||||
| def replace_none(value, default): | def replace_none(value, default): | ||||
| return value if value is not None else default | return value if value is not None else default | ||||
| @@ -73,11 +73,11 @@ def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||||
| ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) | ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle)) | ||||
| if isinstance(input_sampler, BuiltinSampler): | if isinstance(input_sampler, BuiltinSampler): | ||||
| return input_sampler | return input_sampler | ||||
| if _is_iterable(input_sampler): | |||||
| if not isinstance(input_sampler, str) and _is_iterable(input_sampler): | |||||
| return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) | return SubsetSampler(_get_sample_ids_as_list(input_sampler, num_samples)) | ||||
| if isinstance(input_sampler, int): | if isinstance(input_sampler, int): | ||||
| return [input_sampler] | |||||
| raise ValueError('Unsupported sampler object ({})'.format(input_sampler)) | |||||
| return SubsetSampler([input_sampler]) | |||||
| raise TypeError('Unsupported sampler object of type ({})'.format(type(input_sampler))) | |||||
| if shuffle is None: | if shuffle is None: | ||||
| if num_shards is not None: | if num_shards is not None: | ||||
| # If shuffle is not specified, sharding enabled, use distributed random sampler | # If shuffle is not specified, sharding enabled, use distributed random sampler | ||||
| @@ -644,9 +644,9 @@ class SubsetSampler(BuiltinSampler): | |||||
| indices = [indices] | indices = [indices] | ||||
| for i, item in enumerate(indices): | for i, item in enumerate(indices): | ||||
| if not isinstance(item, numbers.Number): | |||||
| raise TypeError("type of indices element must be number, " | |||||
| "but got w[{}]: {}, type: {}.".format(i, item, type(item))) | |||||
| if not isinstance(item, int): | |||||
| raise TypeError("SubsetSampler: Type of indices element must be int, " | |||||
| "but got list[{}]: {}, type: {}.".format(i, item, type(item))) | |||||
| if num_samples is not None: | if num_samples is not None: | ||||
| if not isinstance(num_samples, int): | if not isinstance(num_samples, int): | ||||
| @@ -179,7 +179,7 @@ def test_celeba_sampler_exception(): | |||||
| pass | pass | ||||
| assert False | assert False | ||||
| except TypeError as e: | except TypeError as e: | ||||
| assert "Argument" in str(e) | |||||
| assert "Unsupported sampler object of type (<class 'str'>)" in str(e) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -274,6 +274,26 @@ def test_sampler_list(): | |||||
| dataset_equal(data1, data21 + data22 + data23, 0) | dataset_equal(data1, data21 + data22 + data23, 0) | ||||
| data3 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=1) | |||||
| dataset_equal(data3, data21, 0) | |||||
| def bad_pipeline(sampler, msg): | |||||
| with pytest.raises(Exception) as info: | |||||
| data1 = ds.ImageFolderDataset("../data/dataset/testPK/data", sampler=sampler) | |||||
| for _ in data1: | |||||
| pass | |||||
| assert msg in str(info.value) | |||||
| bad_pipeline(sampler=[1.5, 7], | |||||
| msg="Type of indices element must be int, but got list[0]: 1.5, type: <class 'float'>") | |||||
| bad_pipeline(sampler=["a", "b"], | |||||
| msg="Type of indices element must be int, but got list[0]: a, type: <class 'str'>.") | |||||
| bad_pipeline(sampler="a", msg="Unsupported sampler object of type (<class 'str'>)") | |||||
| bad_pipeline(sampler="", msg="Unsupported sampler object of type (<class 'str'>)") | |||||
| bad_pipeline(sampler=np.array([1, 2]), | |||||
| msg="Type of indices element must be int, but got list[0]: 1, type: <class 'numpy.int64'>.") | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_sequential_sampler(True) | test_sequential_sampler(True) | ||||