@@ -95,7 +95,7 @@ def check_uint32(value, arg_name=""):

 def check_pos_int32(value, arg_name=""):
     type_check(value, (int,), arg_name)
-    check_value(value, [POS_INT_MIN, INT32_MAX])
+    check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)


 def check_uint64(value, arg_name=""):
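Note: forwarding arg_name into check_value means a failed range check can now name the offending parameter in its message. A minimal sketch of the effect, assuming a check_value along these lines (the real helper in validator_helpers.py may word its error differently):

    POS_INT_MIN = 1
    INT32_MAX = 2147483647

    def check_value(value, valid_range, arg_name=""):
        # Hypothetical stand-in for the real helper; only the arg_name plumbing matters here.
        low, high = valid_range
        if not low <= value <= high:
            raise ValueError("{0} is not within the required interval [{1}, {2}].".format(
                arg_name or "value", low, high))

    check_value(10, [POS_INT_MIN, INT32_MAX], "prefetch_size")    # passes silently
    # check_value(0, [POS_INT_MIN, INT32_MAX], "prefetch_size")   # ValueError naming "prefetch_size"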
@@ -143,6 +143,8 @@ def check_columns(columns, name):
     col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
     type_check_list(columns, (str,), col_names)
+    if len(set(columns)) != len(columns):
+        raise ValueError("Every column name should not be same with others in column_names.")


 def parse_user_args(method, *args, **kwargs):
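Note: the new guard in check_columns rejects duplicate column names up front instead of letting them fail later in the pipeline. A self-contained sketch of the behaviour (the standalone function name is illustrative, not part of the patch):

    def reject_duplicate_columns(columns):
        # Mirrors the added check: every entry of column_names must be unique.
        if len(set(columns)) != len(columns):
            raise ValueError("Every column name should not be same with others in column_names.")

    reject_duplicate_columns(["image", "label"])        # unique names pass
    try:
        reject_duplicate_columns(["image", "image"])    # duplicate name
    except ValueError as err:
        print(err)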
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_split, check_bucket_batch_by_length, check_cluedataset
+    check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

 try:
@@ -939,6 +939,7 @@ class Dataset:
             raise TypeError("apply_func must return a dataset.")
         return dataset

+    @check_positive_int32
     def device_que(self, prefetch_size=None):
         """
         Return a transferredDataset that transfer data through device.
@@ -956,6 +957,7 @@ class Dataset:
         """
         return self.to_device()

+    @check_positive_int32
     def to_device(self, num_batch=None):
         """
         Transfer data through CPU, GPU or Ascend devices.
@@ -973,7 +975,7 @@ class Dataset:
         Raises:
             TypeError: If device_type is empty.
             ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
-            ValueError: If num_batch is None or 0 or larger than int_max.
+            ValueError: If num_batch is negative or larger than int_max.
             RuntimeError: If dataset is unknown.
             RuntimeError: If distribution file path is given but failed to read.
         """
@@ -25,7 +25,7 @@ from mindspore._c_expression import typing
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
     INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
     validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
-    check_columns, check_positive
+    check_columns, check_positive, check_pos_int32

 from . import datasets
 from . import samplers
@@ -593,6 +593,25 @@ def check_take(method):
     return new_method


+def check_positive_int32(method):
+    """check whether the input argument is positive and int, only works for functions with one input."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [count], param_dict = parse_user_args(method, *args, **kwargs)
+        para_name = None
+        for key in list(param_dict.keys()):
+            if key not in ['self', 'cls']:
+                para_name = key
+        # Need to get default value of param
+        if count is not None:
+            check_pos_int32(count, para_name)
+
+        return method(self, *args, **kwargs)
+
+    return new_method
+
+
 def check_zip(method):
     """check the input arguments of zip."""