diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 15ef79cf77..7c408bce01 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -415,37 +415,20 @@ class Validator: break if not hit: type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) - raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' - f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') + raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass' + f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.') @staticmethod def check_const_input(arg_name, arg_value, prim_name): """Checks valid value.""" if arg_value is None: - raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') + raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.') return arg_value @staticmethod - def check_type(arg_name, arg_value, valid_types): - """Type checking.""" - def raise_error_msg(): - """func for raising error message when check failed""" - raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.') - - if isinstance(arg_value, type(mstype.tensor)): - arg_value = arg_value.element_type() - if isinstance(arg_value, bool) and bool not in tuple(valid_types): - raise_error_msg() - if arg_value in valid_types: - return arg_value - if isinstance(arg_value, tuple(valid_types)): - return arg_value - raise_error_msg() - - @staticmethod - def check_type_same(args, valid_values, prim_name): - """Checks whether the types of inputs are the same.""" - def _check_tensor_type(arg): + def check_types_same_and_valid(args, valid_values, prim_name): + """Checks whether the types of inputs are the same and valid.""" + def _check_type_valid(arg): arg_key, arg_val = arg elem_type = arg_val Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) @@ -455,21 +438,27 @@ class Validator: arg1_name, arg1_type = arg1 arg2_name, arg2_type = arg2 if arg1_type != arg2_type: - raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' + raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,' f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') return arg1 - elem_types = map(_check_tensor_type, args.items()) + elem_types = map(_check_type_valid, args.items()) reduce(_check_types_same, elem_types) @staticmethod - def check_tensor_type_same(args, valid_values, prim_name): - """Checks whether the element types of input tensors are the same.""" - tensor_types = [mstype.tensor_type(t) for t in valid_values] - Validator.check_type_same(args, tensor_types, prim_name) + def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): + """Checks whether the element types of input tensors are the same and valid.""" + tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] + Validator.check_types_same_and_valid(args, tensor_types, prim_name) + + @staticmethod + def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): + """Checks whether the element types of input tensors are valid.""" + tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] + Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) @staticmethod - def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): + def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False): """ Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. @@ -480,7 +469,7 @@ class Validator: if isinstance(arg_val, type(mstype.tensor)): arg_val = arg_val.element_type() if not arg_val in valid_values: - raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},' + raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},' f' but `{arg_key}` is {arg_val}.') return arg @@ -512,40 +501,40 @@ class Validator: def raise_error_msg(): """func for raising error message when check failed""" - type_names = [t.__name__ for t in valid_types] + type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] num_types = len(valid_types) - msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' - f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') + f'{type_names if num_types > 1 else type_names[0]}, ' + f'but got {arg_value} with type {type(arg_value).__name__}.') # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and # `check_value_type('x', True, [bool, int])` will check pass if isinstance(arg_value, bool) and bool not in tuple(valid_types): raise_error_msg() - if isinstance(arg_value, tuple(valid_types)): - return arg_value - raise_error_msg() + if not isinstance(arg_value, tuple(valid_types)): + raise_error_msg() + return arg_value @staticmethod def check_type_name(arg_name, arg_type, valid_types, prim_name): """Checks whether a type in some specified types""" valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) - def get_typename(t): - return t.__name__ if hasattr(t, '__name__') else str(t) + def raise_error_msg(): + """func for raising error message when check failed""" + type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] + num_types = len(valid_types) + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" + raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" + f"{type_names if num_types > 1 else type_names[0]}, " + f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.") if isinstance(arg_type, type(mstype.tensor)): arg_type = arg_type.element_type() - - if arg_type in valid_types: - return arg_type - type_names = [get_typename(t) for t in valid_types] - msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' - if len(valid_types) == 1: - raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' - f' but got {get_typename(arg_type)}.') - raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' - f' but got {get_typename(arg_type)}.') + if arg_type not in valid_types: + raise_error_msg() + return arg_type @staticmethod def check_reduce_shape(ori_shape, shape, axis, prim_name): @@ -611,65 +600,6 @@ def check_output_data(data): once = _expand_tuple(1) twice = _expand_tuple(2) triple = _expand_tuple(3) -valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, np.float16, - np.float32, np.float64, bool, np.bool_) - - -def check_type(arg_name, arg_value, valid_types): - """Check value type.""" - # if input type is Tensor ,get element type - if isinstance(arg_value, type(mstype.tensor)): - arg_value = arg_value.element_type() - - # First, check if arg_value has argvalid_types - if isinstance(arg_value, tuple(valid_types)): - return type(arg_value).__name__ - - # Second, wrap arg_value with numpy array so that it can be checked through numpy api - if isinstance(arg_value, (list, tuple)): - arg_value = np.array(arg_value) - - # Thirdly, check the data type by numpy's dtype api - valid = False - if isinstance(arg_value, np.ndarray): - valid = arg_value.dtype in valid_data_types - - # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and - # `check_type('x', True, [bool, int])` will check pass - if isinstance(arg_value, bool) and bool not in tuple(valid_types): - valid = False - - if not valid: - type_names = [t.__name__ for t in valid_types] - if len(valid_types) == 1: - raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' - f' but got {type(arg_value).__name__}.') - raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' - f' but got {type(arg_value).__name__}.') - - return type(arg_value).__name__ - - -def check_typename(arg_name, arg_type, valid_types): - """Check type name.""" - - def get_typename(t): - return t.__name__ if hasattr(t, '__name__') else str(t) - - if isinstance(arg_type, type(mstype.tensor)): - arg_type = arg_type.element_type() - - if arg_type in valid_types: - return arg_type - if isinstance(arg_type, tuple(valid_types)): - return arg_type - type_names = [get_typename(t) for t in valid_types] - if len(valid_types) == 1: - raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' - f' but got {get_typename(arg_type)}.') - raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' - f' but got {get_typename(arg_type)}.') def args_type_check(*type_args, **type_kwargs): diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 7b5ae75f24..73c758eccb 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -19,7 +19,7 @@ from mindspore import log as logger from mindspore.communication.management import get_rank, get_group_size from .._c_expression import Tensor as Tensor_ from .._c_expression import MetaTensor as MetaTensor_ -from .._checkparam import check_type, check_typename +from .._checkparam import Validator as validator from . import dtype as mstype from ._register_for_tensor import tensor_operator_registry @@ -64,9 +64,19 @@ class Tensor(Tensor_): input_data = np.array(input_data) # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. - check_type('tensor input_data', input_data, (Tensor_, float, int)) + validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), + 'Tensor') + valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, np.bool_) + if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes: + raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is " + f"{input_data.dtype} that is not supported to initialize a Tensor.") + if isinstance(input_data, (tuple, list)): + if np.array(input_data).dtype not in valid_dtypes: + raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") if dtype is not None: - check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,)) + validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor") + if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): input_data = np.ascontiguousarray(input_data) if dtype is None: @@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_): Returns: Array, an array after being initialized. """ + def __init__(self, dtype, shape, init=None): - #check param + # check param self.init = init MetaTensor_.__init__(self, dtype, shape) @@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_): msg = "Error shape={}".format(shape) logger.error(msg) raise ValueError(msg) + class seed_context: '''set and restore seed''' + def __init__(self, init): self.init = init from .seed import get_seed @@ -482,4 +495,5 @@ def _vm_compare(*args): y = args[0] return Tensor(np.array(fn(y))) + tensor_operator_registry.register('vm_compare', _vm_compare) diff --git a/mindspore/nn/graph_kernels/graph_kernels.py b/mindspore/nn/graph_kernels/graph_kernels.py index f1e511dd63..9cb06d7e49 100644 --- a/mindspore/nn/graph_kernels/graph_kernels.py +++ b/mindspore/nn/graph_kernels/graph_kernels.py @@ -21,7 +21,7 @@ from ...ops import operations as P from ...ops.primitive import PrimitiveWithInfer, prim_attr_register from ...ops.composite import multitype_ops as C from ...ops.operations import _grad_ops as G -from ..._checkparam import Validator +from ..._checkparam import Validator as validator from ..cell import Cell, GraphKernel @@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel): use_locking=False, gradient_scale=1.0): super(ApplyMomentum, self).__init__() - self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float]) + self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__) self.fake_output_assign_1 = InplaceAssign() self.fake_output_assign_1.add_prim_attr("fake_output", True) self.fake_output_assign_2 = InplaceAssign() @@ -334,7 +334,7 @@ class ReduceMean(GraphKernel): def __init__(self, keep_dims=True): super(ReduceMean, self).__init__() - self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool]) + self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__) self.sum = P.ReduceSum(self.keep_dims) def construct(self, x, axis): @@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel): """ Forward function of the LayerNorm operator. """ def __init__(self, begin_norm_axis=1, begin_params_axis=1): super(LayerNormForward, self).__init__() - self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int]) + self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], + type(self).__name__) + self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], + type(self).__name__) self.mul = P.Mul() self.sum_keep_dims = P.ReduceSum(keep_dims=True) self.sub = P.Sub() @@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel): def __init__(self, axis=-1): super(LogSoftmax, self).__init__() - self.axis = Validator.check_type('axis', axis, [int]) + self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__) self.max_keep_dims = P.ReduceMax(keep_dims=True) self.sub = P.Sub() self.exp = P.Exp() @@ -952,13 +954,13 @@ class Softmax(GraphKernel): def __init__(self, axis): super(Softmax, self).__init__() - Validator.check_type("axis", axis, [int, tuple]) + validator.check_value_type("axis", axis, [int, tuple], type(self).__name__) if isinstance(axis, int): self.axis = (axis,) else: self.axis = axis for item in self.axis: - Validator.check_type("item of axis", item, [int]) + validator.check_value_type("item of axis", item, [int], type(self).__name__) self.max = P.ReduceMax(keep_dims=True) self.sub = P.Sub() self.exp = P.Exp() diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 1bbc2c965b..1da5dc0ddb 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.ops.primitive import constexpr import mindspore.context as context -from mindspore._checkparam import Validator, check_typename +from mindspore._checkparam import Validator as validator from mindspore._extends import cell_attr_register from mindspore.communication.management import get_group_size, get_rank from mindspore.communication import management @@ -52,7 +52,7 @@ class _BatchNorm(Cell): if momentum < 0 or momentum > 1: raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) - self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError("NHWC format only support in GPU target.") self.use_batch_statistics = use_batch_statistics @@ -67,7 +67,7 @@ class _BatchNorm(Cell): gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( beta_init, num_features), name="beta", requires_grad=affine) - self.group = Validator.check_positive_int(device_num_each_group) + self.group = validator.check_positive_int(device_num_each_group) self.is_global = False if self.group != 1: self.rank_id = get_rank() @@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics, device_num_each_group, input_dims='both') - self.group = Validator.check_positive_int(device_num_each_group) + self.group = validator.check_positive_int(device_num_each_group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.") @@ -607,12 +607,12 @@ class GroupNorm(Cell): def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): super(GroupNorm, self).__init__() - self.num_groups = Validator.check_positive_int(num_groups) - self.num_channels = Validator.check_positive_int(num_channels) + self.num_groups = validator.check_positive_int(num_groups) + self.num_channels = validator.check_positive_int(num_channels) if num_channels % num_groups != 0: raise ValueError("num_channels should be divided by num_groups") - self.eps = check_typename('eps', eps, (float,)) - self.affine = Validator.check_bool(affine) + self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__) + self.affine = validator.check_bool(affine) gamma = initializer(gamma_init, num_channels) beta = initializer(beta_init, num_channels) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 7d7090b425..bf1532da6b 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range, num_channels=num_channels) - Validator.check_type("min_init", min_init, [int, float]) - Validator.check_type("max_init", max_init, [int, float]) + Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__) + Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__) Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) Validator.check_non_negative_int(quant_delay, 'quant_delay') self.min_init = min_init diff --git a/mindspore/nn/probability/bijector/gumbel_cdf.py b/mindspore/nn/probability/bijector/gumbel_cdf.py index d3c3308b56..eef7affc6a 100644 --- a/mindspore/nn/probability/bijector/gumbel_cdf.py +++ b/mindspore/nn/probability/bijector/gumbel_cdf.py @@ -68,7 +68,7 @@ class GumbelCDF(Bijector): """ param = dict(locals()) valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) diff --git a/mindspore/nn/probability/distribution/bernoulli.py b/mindspore/nn/probability/distribution/bernoulli.py index 9f20c60af4..d048cdadb9 100644 --- a/mindspore/nn/probability/distribution/bernoulli.py +++ b/mindspore/nn/probability/distribution/bernoulli.py @@ -119,7 +119,7 @@ class Bernoulli(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Bernoulli, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 7546598810..ea98bbaaa0 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -109,7 +109,7 @@ class Categorical(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type - Validator.check_type("Categorical", dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Categorical, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/exponential.py b/mindspore/nn/probability/distribution/exponential.py index 64e3a88363..907888f191 100644 --- a/mindspore/nn/probability/distribution/exponential.py +++ b/mindspore/nn/probability/distribution/exponential.py @@ -121,7 +121,7 @@ class Exponential(Distribution): param = dict(locals()) param['param_dict'] = {'rate': rate} valid_dtype = mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Exponential, self).__init__(seed, dtype, name, param) self._rate = self._add_parameter(rate, 'rate') diff --git a/mindspore/nn/probability/distribution/geometric.py b/mindspore/nn/probability/distribution/geometric.py index a7f087771a..ec1e9c94e3 100644 --- a/mindspore/nn/probability/distribution/geometric.py +++ b/mindspore/nn/probability/distribution/geometric.py @@ -122,7 +122,7 @@ class Geometric(Distribution): param = dict(locals()) param['param_dict'] = {'probs': probs} valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Geometric, self).__init__(seed, dtype, name, param) self._probs = self._add_parameter(probs, 'probs') diff --git a/mindspore/nn/probability/distribution/gumbel.py b/mindspore/nn/probability/distribution/gumbel.py index fca438a777..f598e29d55 100644 --- a/mindspore/nn/probability/distribution/gumbel.py +++ b/mindspore/nn/probability/distribution/gumbel.py @@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution): Constructor of Gumbel distribution. """ valid_dtype = mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) super(Gumbel, self).__init__( distribution=msd.Uniform(0.0, 1.0, dtype=dtype), diff --git a/mindspore/nn/probability/distribution/logistic.py b/mindspore/nn/probability/distribution/logistic.py index 7dedc64515..1033f4de95 100644 --- a/mindspore/nn/probability/distribution/logistic.py +++ b/mindspore/nn/probability/distribution/logistic.py @@ -111,7 +111,7 @@ class Logistic(Distribution): param = dict(locals()) param['param_dict'] = {'loc': loc, 'scale': scale} valid_dtype = mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Logistic, self).__init__(seed, dtype, name, param) self._loc = self._add_parameter(loc, 'loc') diff --git a/mindspore/nn/probability/distribution/normal.py b/mindspore/nn/probability/distribution/normal.py index 189e4f36fc..6a4949084e 100644 --- a/mindspore/nn/probability/distribution/normal.py +++ b/mindspore/nn/probability/distribution/normal.py @@ -127,7 +127,7 @@ class Normal(Distribution): param = dict(locals()) param['param_dict'] = {'mean': mean, 'sd': sd} valid_dtype = mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Normal, self).__init__(seed, dtype, name, param) self._mean_value = self._add_parameter(mean, 'mean') diff --git a/mindspore/nn/probability/distribution/uniform.py b/mindspore/nn/probability/distribution/uniform.py index 1324121fb0..3759349ec1 100644 --- a/mindspore/nn/probability/distribution/uniform.py +++ b/mindspore/nn/probability/distribution/uniform.py @@ -126,7 +126,7 @@ class Uniform(Distribution): param = dict(locals()) param['param_dict'] = {'low': low, 'high': high} valid_dtype = mstype.float_type - Validator.check_type(type(self).__name__, dtype, valid_dtype) + Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) super(Uniform, self).__init__(seed, dtype, name, param) self._low = self._add_parameter(low, 'low') diff --git a/mindspore/ops/operations/_cache_ops.py b/mindspore/ops/operations/_cache_ops.py index f30e9a5a5d..fa9e99c53e 100644 --- a/mindspore/ops/operations/_cache_ops.py +++ b/mindspore/ops/operations/_cache_ops.py @@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer): return [1] def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): - args = {"indices": indices_dtype} - validator.check_tensor_type_same(args, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) return input_x_dtype @@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer): def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): args = {"hashmap": hashmap_dtype, "indices": indices_dtype} - validator.check_tensor_type_same(args, mstype.int_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) return out_dtype @@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer): return out_shape def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): - args = {"miss_emb_idx": miss_emb_idx_dtype} - validator.check_tensor_type_same(args, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) return out_dtype @@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer): return miss_value_shape def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): - args = {"swap_cache_idx": swap_cache_idx_dtype} - validator.check_tensor_type_same(args, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) return miss_value_dtype @@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer): def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): args = {"hashmap": hashmap_dtype, "indices": indices_dtype} - validator.check_tensor_type_same(args, mstype.int_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype, hashmap_dtype) return out_dtype diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index fc58c2c603..291d90713d 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -14,6 +14,7 @@ # ============================================================================ """Operators for gradients.""" +from functools import partial from .. import signature as sig from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -23,6 +24,7 @@ from ...common import dtype as mstype from .. import functional as F from ... import context + class AbsGrad(PrimitiveWithInfer): """Computes gradients for abs operation.""" @@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer): def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x @@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer): def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x @@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer): def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x @@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer): def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x @@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer): def infer_dtype(self, x_dtype, dout_dtype): args = {"x": x_dtype, "dout": dout_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) return x_dtype @@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer): def infer_dtype(self, x_dtype, dout_dtype): args = {"x": x_dtype, "dout": dout_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], + self.name) return x_dtype @@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer): def infer_dtype(self, x_dtype, dout_dtype): args = {"x": x_dtype, "dout": dout_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) return x_dtype @@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer): def infer_dtype(self, x_dtype, dout_dtype): args = {"x": x_dtype, "dout": dout_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) return x_dtype @@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer): def infer_dtype(self, x_type, y_type, doutput_type): args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return x_type, y_type @@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): def infer_dtype(self, x_type, y_type, doutput_type, weight_type): args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) if weight_type: validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) return x_type @@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): for i, dim_len in enumerate(w_size_v): validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) args = {"x": x['dtype'], "doutput": doutput['dtype']} - validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], + self.name) out = { 'value': None, 'shape': w_size_v, @@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): def __infer__(self, x, w_size, dout): w_size_v = w_size['value'] args = {'x': x['dtype'], 'dout': dout['dtype']} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) out = { 'value': None, 'shape': w_size_v, @@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): def __infer__(self, x_size, w, dout): args = {'w': w['dtype'], 'dout': dout['dtype']} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) x_size_v = x_size['value'] out = { 'value': None, @@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer): return dy_shape def infer_dtype(self, dy_dtype, mask_dtype): - valid_types = (mstype.float16, mstype.float32) - validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name) + valid_dtypes = (mstype.float16, mstype.float32) validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name) - validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name) + validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name) return dy_dtype @@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): - validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("y_backprop", "x", "y"), + (y_backprop_dtype, x_dtype, y_dtype))) return x_dtype @@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad): def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) return x1_dtype @@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad): def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) return grad_dtype @@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer): def infer_dtype(self, input_x, out, dout): args = {'input_x': input_x, 'out': out, 'dout': dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return input_x @@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer): def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, hx_dtype, cx_dtype, reserve_dtype, state_dtype): args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} - validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) return (dy_dtype, dy_dtype, dy_dtype) @@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer): args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} - validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name) - validator.check_tensor_type_same(args, valid_types, self.name) + validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name) + validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name) + validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name) if seq_dtype is not None: - validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name) + validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name) if mask_dtype is not None: - validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name) + validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name) return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype @@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer): return y_backprop_shape, w_shape def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ('y_backprop', "input_x", "weight"), + (y_backprop_dtype, A_dtype, w_dtype))) return y_backprop_dtype, w_dtype @@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) return x_dtype @@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer): return gradients_shape def infer_dtype(self, gradients_dtype, mask_dtype): - validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) - validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) + validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name) return gradients_dtype @@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer): def infer_dtype(self, y_grad_dtype, x_dtype): args = {'y_grad': y_grad_dtype, 'x': x_dtype} - validator.check_tensor_type_same(args, mstype.float_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) return x_dtype @@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer): def infer_dtype(self, out, dout): args = {'out': out, 'dout': dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return out @@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) return x_dtype @@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) return x_dtype @@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): def infer_dtype(self, x_dtype, y_dtype, dout_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return dout_dtype @@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer): def infer_dtype(self, prediction, target, dloss): args = {"prediction": prediction, "target": target, 'dloss': dloss} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return dloss @@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer): self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) def __infer__(self, dy, shapex, begin, end, strides): - args = {"dy": dy['dtype']} - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name) for idx, item in enumerate(shapex['value']): validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) @@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer): def infer_dtype(self, dout_dtype, x_dtype): args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype} - validator.check_tensor_type_same(args, mstype.float_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) return x_dtype @@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer): def infer_dtype(self, out, dout): args = {"out": out, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return out @@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer): def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x @@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer): def infer_dtype(self, grads, x, y): args = {"grads": grads, "x": x, "y": y} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name) return x def infer_shape(self, grads, x, y): diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 5a4e36dd23..05989b9cd0 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksizes, strides, rates, padding="valid"): """init""" + def _check_tuple_or_list(arg_name, arg_val, prim_name): validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: @@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer): def infer_dtype(self, input_x): """infer dtype""" - validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name) return input_x @@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) + validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) return x_dtype @@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer): >>> dequant = P.Dequant(False, False) >>> y = dequant(input_x) """ + @prim_attr_register def __init__(self, sqrt_mode=False, relu_flag=False): self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) @@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer): return assist def infer_dtype(self, assist, start, stop, num): - args = {"num": num} - validator.check_tensor_type_same(args, (mstype.int32,), self.name) + validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name) args = {"assist": assist, "start": start, "stop": stop} - validator.check_tensor_type_same(args, (mstype.float32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name) return assist @@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer): def infer_dtype(self, x_dtype, assist_dtype): valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] args = {"x": x_dtype, "assist": assist_dtype} - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, assist_shape): validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) - validator.check('rank of x', len(x_shape)+1, + validator.check('rank of x', len(x_shape) + 1, 'rank of assist', len(assist_shape), Rel.LE, self.name) validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', assist_shape[-1], Rel.EQ, self.name) @@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer): while r_idx >= r_end_dim: if x_shape[r_idx] != 1: validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % - assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name) + assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name) r_idx = r_idx - 1 return assist_shape @@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer): def infer_dtype(self, x_dtype, assist_dtype): valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] args = {"x": x_dtype, "assist": assist_dtype} - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, assist_shape): @@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer): def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, diagonal_shape, assist_shape): @@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer): return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): - validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) - validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) - validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) + validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) + validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) b_dtype = mstype.float32 if binput_dtype is not None: - validator.check_tensor_type_same({"bias input dtype": binput_dtype}, - (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, + (mstype.float16, mstype.float32), self.name) b_dtype = binput_dtype elif bhidden_dtype is not None: - validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, - (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, + (mstype.float16, mstype.float32), self.name) b_dtype = bhidden_dtype elif h_dtype is not None: - validator.check_tensor_type_same({"init_h dtype": h_dtype}, - (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("init_h dtype", h_dtype, + (mstype.float16, mstype.float32), self.name) b_dtype = h_dtype return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index f6df84f18e..c69db7ffaf 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -14,6 +14,7 @@ # ============================================================================ """Operators for quantization.""" +from functools import partial import mindspore.context as context from ..._checkparam import Validator as validator @@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): return min_shape, max_shape def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return min_type, max_type @@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): return min_shape, max_shape def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return min_type, max_type @@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): >>> input_tensor, min_tensor, max_tensor) >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 """ + @prim_attr_register def __init__(self, num_bits=8, @@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) - validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) - validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return x_type @@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): >>> min_gradient shape: (1,) data type: mstype.float32 >>> max_gradient shape: (1,) data type: mstype.float32 """ + @prim_attr_register def __init__(self, num_bits=8, @@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): return x_shape, min_shape, max_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) - validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) - validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) - validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ('dout', "x", "min", "max"), + (dout_type, x_type, min_type, max_type))) return x_type, min_type, max_type @@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): >>> input_tensor, min_tensor, max_tensor) >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 """ + @prim_attr_register def __init__(self, num_bits=8, @@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) - validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) - validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return x_type @@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): >>> min_gradient shape: (4,) data type: mstype.float32 >>> max_gradient shape: (4,) data type: mstype.float32 """ + @prim_attr_register def __init__(self, num_bits=8, @@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): return x_shape, min_shape, max_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) - validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) - validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) - validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("dout", "x", "min", "max"), + (dout_type, x_type, min_type, max_type))) return x_type, min_type, max_type @@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer): def infer_dtype(self, x_type, min_type, max_type): if context.get_context('device_target') == "GPU": - valid_types = (mstype.float32,) + valid_dtypes = (mstype.float32,) else: - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return x_type @@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type, min_type, max_type): if context.get_context('device_target') == "GPU": - valid_types = (mstype.float32,) + valid_dtypes = (mstype.float32,) else: - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"dout": dout_type}, valid_types, self.name) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "min", "max"), + (dout_type, x_type, min_type, max_type))) return dout_type @@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): def infer_dtype(self, x_type, min_type, max_type): if context.get_context('device_target') == "GPU": - valid_types = (mstype.float32,) + valid_dtypes = (mstype.float32,) else: - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("x", "min", "max"), + (x_type, min_type, max_type))) return x_type @@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type, min_type, max_type): if context.get_context('device_target') == "GPU": - valid_types = (mstype.float32,) + valid_dtypes = (mstype.float32,) else: - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same( - {"dout": dout_type}, valid_types, self.name) - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"min": min_type}, valid_types, self.name) - validator.check_tensor_type_same( - {"max": max_type}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), + ("dout", "x", "min", "max"), + (dout_type, x_type, min_type, max_type))) return dout_type @@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer): validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "variance type", variance_type) args = {"x": x_type, "mean": mean_type, "variance": variance_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) return x_type, x_type, x_type, x_type @@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer): global_step_type): args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, "batch_mean": batch_mean_type, "batch_std": batch_std_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) return x_type @@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer): def infer_dtype(self, x_type, batch_std_type, running_std_type): args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return x_type @@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer): def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) if context.get_context('device_target') == "Ascend": return x_type, x_type return x_type, gamma_type @@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer): running_mean_type, global_step_type): args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) return x_type @@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer): "dout type", dout_type) args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type @@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer): validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "variance type", variance_type) args = {"x": x_type, "mean": mean_type, "variance": variance_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return x_type, x_type, x_type, x_type, x_type, x_type, x_type @@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer): validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) validator.check("input type", x_type, "batch_mean type", batch_mean_type) validator.check("input type", x_type, "batch_std type", batch_std_type) - args = {"input type": x_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name) return x_type @@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer): def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, "beta": beta_type, "gamma": gamma_type, "x": x_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return x_type @@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer): "dout type", dout_type) args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, "running_std": running_std_type, "dout": dout_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return gamma_type, gamma_type, gamma_type, gamma_type diff --git a/mindspore/ops/operations/_thor_ops.py b/mindspore/ops/operations/_thor_ops.py index a757a37fca..83e2a8165f 100644 --- a/mindspore/ops/operations/_thor_ops.py +++ b/mindspore/ops/operations/_thor_ops.py @@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer): def infer_shape(self, data1_shape): ll = [] if len(data1_shape) == 2: - ll = [1,] + ll = [1] else: ll = [32, 64] return ll @@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer): >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) >>> output = img2col(input_x) """ + @prim_attr_register def __init__(self, kernel_size, @@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype): - args = {'x': x_dtype} - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name) return x_dtype @@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer): return x2_shape def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): - validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, - [mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid( + {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, + [mstype.float32], self.name) return x2_dtype + class Cholesky(PrimitiveWithInfer): """ Inner API for resnet50 THOR GPU backend """ + @prim_attr_register def __init__(self, split_dim=0): self.init_prim_io_names(inputs=['x1'], outputs=['y']) @@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer): return out_shape def infer_dtype(self, x1_dtype): - validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) + validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) return x1_dtype + class DetTriangle(PrimitiveWithInfer): """ Calculate the determinant of triangle matrices """ + @prim_attr_register def __init__(self, fill_mode=0): self.init_prim_io_names(inputs=['x1'], outputs=['y']) @@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer): return out_shape def infer_dtype(self, x1_dtype): - validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) + validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) return x1_dtype diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index d1ced6afd8..5caee52bb6 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) args = {"x": x_dtype, "updates": updates_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x_dtype @@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp): """ Defines _ScatterNd operators """ + def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): validator.check('the dimension of x', len(x_shape), 'the dimension of indices', indices_shape[-1], Rel.GE) @@ -627,6 +628,7 @@ class Unique(Primitive): >>> out = P.Unique()(x) (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) """ + @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['x'], outputs=['output']) @@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck): def __init__(self): """Initialize index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) - self.add_prim_attr("dynamic_shape_depends", [2,]) + self.add_prim_attr("dynamic_shape_depends", [2]) def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) axis_v = axis['value'] params_shp = params['shape'] @@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer): >>> out = P.Padding(pad_dim_size)(x) [[8, 0, 0, 0], [10, 0, 0, 0]] """ + @prim_attr_register def __init__(self, pad_dim_size=8): """Initialize padding""" @@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer): >>> out = P.UniqueWithPad()(x, pad_num) ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) """ + @prim_attr_register def __init__(self): """init UniqueWithPad""" def __infer__(self, x, pad_num): - validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name) + validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name) validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) x_shape = list(x['shape']) validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) @@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer): def __init__(self, seed=0, dtype=mstype.float32): """Initialize TruncatedNormal""" validator.check_value_type('seed', seed, [int], self.name) - validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name) + validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name) def __infer__(self, shape): shape_value = shape['value'] @@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer): validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) for i, item in enumerate(dims['value']): validator.check_positive_int(item, f'dims[{i}]', self.name) - valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint32, mstype.uint64, - mstype.float16, mstype.float32, mstype.float64] - validator.check_type_same({"value": dtype['value']}, valid_types, self.name) + valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, + mstype.uint8, mstype.uint32, mstype.uint64, + mstype.float16, mstype.float32, mstype.float64] + validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) ret = np.full(dims['value'], x['value'], x_nptype) out = { @@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) return x_dtype @@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) return x_dtype @@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer): """Initialize Argmax""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int], self.name) - validator.check_type_same({'output': output_type}, [mstype.int32], self.name) + validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name) self.axis = axis self.add_prim_attr('output_type', output_type) @@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): def __init__(self): """Initialize UnsortedSegmentSum""" self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) - self.add_prim_attr("dynamic_shape_depends", [2,]) + self.add_prim_attr("dynamic_shape_depends", [2]) def __infer__(self, x, segment_ids, num_segments): x_type = x['dtype'] @@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): num_segments_type = num_segments['dtype'] validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) if isinstance(num_segments_type, type(mstype.tensor)): - validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name) shp = [-1] else: validator.check_value_type('num_segments', num_segments_v, [int], self.name) @@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer): x_shape = x['shape'] segment_ids_shape = segment_ids['shape'] valid_type = [mstype.float16, mstype.float32, mstype.int32] - validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) - validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) + validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check(f'first shape of input_x', x_shape[0], 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) @@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer): x_shape = x['shape'] segment_ids_shape = segment_ids['shape'] valid_type = [mstype.float16, mstype.float32, mstype.int32] - validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) - validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check(f'first shape of input_x', x_shape[0], 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) @@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer): validator.check_subclass("input_x", x_type, mstype.tensor, self.name) validator.check_value_type("x_shape", x_shape, [list], self.name) valid_type = [mstype.float16, mstype.float32, mstype.int32] - validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) - validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) + validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) validator.check(f'first shape of input_x', x_shape[0], 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) @@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer): validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) first_elem = x_shp[0] for i, elem in enumerate(x_shp[1:]): @@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) return x_dtype @@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer): self.add_prim_attr('T', x_type) validator.check_subclass("x_type", x_type, mstype.tensor, self.name) validator.check_subclass("y_type", y_type, mstype.tensor, self.name) - validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name) + validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) if x_type != y_type: raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) return x_type @@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer): validator.check_positive_int(n, "n", self.name) validator.check_positive_int(m, "m", self.name) args = {"dtype": t} - validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) np_type = mstype.dtype_to_nptype(t) ret = np.eye(n, m, dtype=np_type) return Tensor(ret) @@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer): def __infer__(self, indices, update, shape): shp = shape['value'] validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name) validator.check_value_type("shape", shp, [tuple], self.name) for i, x in enumerate(shp): validator.check_positive_int(x, f'shape[{i}]', self.name) @@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): validator.check_non_negative_int(value, f'{i}th value of size', self.name) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) - def infer_shape(self, x): - validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) - return tuple(x)[:-2] + tuple(self.size) + def infer_shape(self, x_shape): + validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name) + return tuple(x_shape)[:-2] + tuple(self.size) - def infer_dtype(self, x): - validator.check_subclass("x", x, mstype.tensor, self.name) - validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name) - return x + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) + return x_dtype class GatherNd(PrimitiveWithInfer): @@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer): return indices_shape[:-1] + x_shape[indices_shape[-1]:] def infer_dtype(self, x_dtype, indices_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) return x_dtype @@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp): self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp): [0.4 2.2 -3.2]] """ - @prim_attr_register def __init__(self, use_locking=True): """Initialize ScatterNdUpdate""" @@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp): self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) args = {"x": x_dtype, "value": value_dtype} - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp): self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) args = {"x": x_dtype, "updates": updates_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name) return x_dtype @@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer): self.paddings = paddings def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) return x_dtype def infer_shape(self, x_shape): @@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer): self.crops = crops def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) return x_dtype def infer_shape(self, x_shape): @@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer): self.add_prim_attr("paddings", paddings_append) def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) return x_dtype def infer_shape(self, x_shape): @@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer): self.add_prim_attr("crops", crops_append) def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) return x_dtype def infer_shape(self, x_shape): @@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer): out_shape = tuple(tuple(shape_0) for _ in range(n)) return out_shape - def infer_dtype(self, x_type): validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) n = len(x_type) @@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer): validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) return x_type + class InplaceUpdate(PrimitiveWithInfer): r""" Updates specified rows with values in `v`. @@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer): def infer_dtype(self, x_dtype, v_dtype): args = {'x': x_dtype, 'v': v_dtype} valid_type = [mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, v_shape): @@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer): return x def infer_dtype(self, x, seq_lengths): - validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) - validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) + validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name) return x @@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer): validator.check_const_input('truth_shape', truth_shape['value'], self.name) args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} - validator.check_tensor_type_same(args_int, [mstype.int64], self.name) + validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name) args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) @@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer): Outputs: Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. """ + @prim_attr_register def __init__(self): self.__setattr_flag__ = True @@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer): def __infer__(self, x, shape): shp = shape['value'] dtype = x['dtype'] - validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name) self.add_prim_attr('out_shape', tuple(shp)) return {'shape': shp, 'dtype': dtype, @@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer): return x_shape, x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name) + validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name) return x_dtype, mstype.tensor_type(mstype.int32) @@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer): >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] """ + @prim_attr_register def __init__(self): """Initialize index_select""" @@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer): def __infer__(self, params, indices, offset): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) params_shp = params['shape'] if len(params_shp) != 2: @@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer): >>> out = P.GatherD()(x, dim, index) [[1, 1], [4, 3]] """ + @prim_attr_register def __init__(self): """Initialize GatherD""" @@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer): def __infer__(self, x, dim, index): validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name) + validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name) validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) x_shp = x['shape'] idx_shp = index['shape'] @@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer): >>> y = P.Identity()(x) [1, 2, 3, 4] """ + @prim_attr_register def __init__(self): """Initialize identity""" diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 4a06ed9b7f..9017c72b9d 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype @@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype def __call__(self, tensor): @@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer): if not isinstance(x_dtype, tuple): raise TypeError(f"{self.name}'s input should be a tuple!") for _ele in x_dtype: - validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name) return x_dtype @@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) return x_dtype def __call__(self, tensor): diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index acdf8ff548..362975958a 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer): def infer_dtype(self, data_type, pred_type): validator.check_subclass( "data", data_type, (mstype.tensor,) + mstype.number_type, self.name) - validator.check_tensor_type_same( - {"pred": pred_type}, [mstype.bool_], self.name) + validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name) return (data_type, data_type) @@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer): for i, item in enumerate(inputs): args['inputs[%d]' % i] = item - validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name) return (inputs[0], mstype.int32) diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 6dcf8dc8a4..cf2c4424ac 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer): return [1] def infer_dtype(self, condition, inputs): - validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name) + validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name) for dtype in inputs: validator.check_subclass("input", dtype, [mstype.tensor], self.name) return mstype.int32 diff --git a/mindspore/ops/operations/image_ops.py b/mindspore/ops/operations/image_ops.py index 3041508daa..3cdacfade1 100644 --- a/mindspore/ops/operations/image_ops.py +++ b/mindspore/ops/operations/image_ops.py @@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer): box_index_dtype = box_index['dtype'] crop_size_dtype = crop_size['dtype'] # check dytpe - validator.check_tensor_type_same({"x": x_dtype}, - [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, - mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) - validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name) - validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("x", x_dtype, + [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, + mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) + validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name) + validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name) validator.check_value_type("crop_size", crop_size_value, [tuple], self.name) # check input shape rank validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 97323cbeb1..78048534de 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -16,6 +16,8 @@ """Operators for math.""" import copy +from functools import partial + import numpy as np from ... import context from .. import signature as sig @@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp): @staticmethod def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): args_type = {"x": x_dtype, "y": y_dtype} - validator.check_tensor_type_same(args_type, valid_dtype, prim_name) + validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name) return x_dtype def infer_dtype(self, x_dtype, y_dtype): @@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp): @staticmethod def _check_bitwise_op_input_type(x1_type, x2_type, prim): args = {'x1': x1_type, 'x2': x2_type} - valid_types = mstype.int_type + mstype.uint_type - validator.check_tensor_type_same(args, valid_types, prim) + valid_dtypes = mstype.int_type + mstype.uint_type + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim) return x1_type def infer_dtype(self, x1_type, x2_type): @@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer): def infer_dtype(self, variable, value): args = {"variable": variable, "value": value} - validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) + validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) return value @@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer): def infer_dtype(self, variable, value): args = {"variable": variable, "value": value} - validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) + validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) return value @@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer): axis_v = axis['value'] input_shp = input_x['shape'] args = {'input_x': input_x['dtype']} - validator.check_tensor_type_same(args, valid_dtype, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name) if axis_v is None: raise ValueError(f"For {self.name}, axis must be const.") @@ -504,6 +506,7 @@ class ReduceMax(_Reduce): def __infer__(self, input_x, axis): return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,)) + class ReduceMin(_Reduce): """ Reduce a dimension of a tensor by the minimum value in the dimension. @@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer): def infer_dtype(self, x_type, axis_type): cls_name = self.name - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) + validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name) validator.check_subclass("axis", axis_type, mstype.int_, cls_name) return x_type @@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer): def infer_dtype(self, x1, x2): args = {"x1": x1, "x2": x2} - validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name) if x1.element_type() == mstype.int8: return mstype.tensor_type(mstype.int32) return x1 @@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer): self.axes = axes validator.check_value_type('axes', axes, [int, tuple, list], self.name) if not isinstance(self.axes, int): - self.axes = list(self.axes) # to avoid immutability issues + self.axes = list(self.axes) # to avoid immutability issues if len(self.axes) != 2: raise ValueError("Require two axes inputs, given less") - self.int_to_tuple_conv() # convert before length checks + self.int_to_tuple_conv() # convert before length checks if len(self.axes[0]) != len(self.axes[1]): raise ValueError("Axes have to be the same size/length") if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): @@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer): if isinstance(self.axes, int): if self.axes <= 0: # outer product, no input validation required - self.axes = ([], []) # no axes selected for either + self.axes = ([], []) # no axes selected for either return if self.axes > len(x1_shape) or self.axes > len(x2_shape): raise ValueError( @@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer): def infer_dtype(self, x1, x2): args = {"x1": x1, "x2": x2} - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) return x1 @@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer): if axis['value'] is None: raise ValueError(f"For {self.name}, axis must be const.") validator.check_value_type('axis', axis['value'], [int], cls_name) - valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) + valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name) return {'shape': x_shp, 'dtype': x['dtype'], 'value': None} @@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer): if dtype == mstype.undetermined: contains_undetermined = True if not contains_undetermined: - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) return inputs[0] def infer_value(self, inputs): @@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer): args = {} for i, dtype in enumerate(inputs): args[f"inputs[{i}]"] = dtype - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) return inputs[0] @@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer): """Initialize Neg""" self.init_prim_io_names(inputs=['x'], outputs=['y']) - def infer_shape(self, input_x): - return input_x + def infer_shape(self, x_shape): + return x_shape - def infer_dtype(self, input_x): - validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) - return input_x + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) + return x_dtype def infer_value(self, input_x): if input_x is not None: @@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer): def infer_dtype(self, x_dtype, v_dtype): args = {'x': x_dtype, 'v': v_dtype} valid_type = [mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, v_shape): @@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer): def infer_dtype(self, x_dtype, v_dtype): args = {'x': x_dtype, 'v': v_dtype} valid_type = [mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) return x_dtype def infer_shape(self, x_shape, v_shape): @@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape - def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) - return x_type + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) + return x_dtype def infer_value(self, x): if x is not None: @@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape - def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) - return x_type + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) + return x_dtype def infer_value(self, x): if x is not None: @@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck): self.init_prim_io_names(inputs=['x'], outputs=['output']) def check_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name) def infer_value(self, x): if x is not None: @@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_subclass("x", x_type, mstype.tensor, self.name) - validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) return x_type @@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer): return (self.nbins,) def infer_dtype(self, x_dtype, range_dtype): - validator.check_subclass("x", x_dtype, mstype.tensor, self.name) - valid_types = (mstype.float16, mstype.float32, mstype.int32) - validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32, mstype.int32) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name) y_dtype = mstype.int32 return y_dtype @@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer): def __init__(self): self.init_prim_io_names(inputs=['x'], outputs=['y']) - def infer_shape(self, x): - return x + def infer_shape(self, x_shape): + return x_shape - def infer_dtype(self, x): - validator.check_subclass("x", x, mstype.tensor, self.name) - validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name) - return x + def infer_dtype(self, x_dtype): + validator.check_subclass("x", x_dtype, mstype.tensor, self.name) + validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) + return x_dtype class Erf(PrimitiveWithInfer): @@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape - def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) - return x_type + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) + return x_dtype class Erfc(PrimitiveWithInfer): @@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) return x_type @@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name) + validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name) return x_dtype @@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) return x_dtype @@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp): @staticmethod def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): args_dtype = {"x": x_dtype, "y": y_dtype} - validator.check_tensor_type_same(args_dtype, valid_type, prim_name) + validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name) return mstype.tensor_type(mstype.bool_) def infer_dtype(self, x_dtype, y_dtype): @@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp): def infer_dtype(self, x_dtype, y_dtype): args_dtype = {"x": x_dtype, "y": y_dtype} valid_type = [mstype.float32, mstype.float16] - validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name) + validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name) return mstype.tensor_type(mstype.bool_) @@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer): def infer_dtype(self, x_dtype, y_dtype): args = {'x': x_dtype, 'y': y_dtype} - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) return x_dtype @@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name) + validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name) return mstype.tensor_type(mstype.bool_) @@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_subclass("x", x_dtype, mstype.tensor, self.name) - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) return mstype.bool_ @@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer): return [1] def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name) + validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name) return x_dtype @@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer): return [8] def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) return mstype.float32 @@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer): return [8] def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) return mstype.float32 @@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer): return (bboxes_shape, (num,), (num,)) def infer_dtype(self, bboxes_dtype): - validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name) return (bboxes_dtype, mstype.int32, mstype.bool_) @@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) return x_type def infer_value(self, x): @@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) return x_dtype @@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape - def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) - return x_type + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) + return x_dtype class Tan(PrimitiveWithInfer): @@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - valid_types = [mstype.float16, mstype.float32, mstype.int32] - validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name) return x_type @@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) return x_type @@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) return x_type @@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer): return [], [] def infer_dtype(self, x_type, y_type): - validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name) + validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name) return x_type, y_type @@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) return x @@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer): return x def infer_dtype(self, x): - validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) return x @@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32, + validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32, mstype.int32], self.name) return x_dtype @@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name) + validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name) return x_dtype @@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer): self.init_prim_io_names(inputs=['input_x'], outputs=['y']) def __infer__(self, input_x): - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name) x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type()) if x_nptype == np.float16: @@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer): return (1,), (1,) def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): - valid_types = [mstype.float32, mstype.float16] - validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("input_value", "input_min", "input_max"), + (data_dtype, data_min_dtype, data_max_dtype))) + validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 68d9154a7f..fb34e62dbf 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -17,7 +17,7 @@ import math import operator -from functools import reduce +from functools import reduce, partial import numpy as np from ... import context from .. import signature as sig @@ -153,8 +153,7 @@ class Softmax(PrimitiveWithInfer): return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor, self.name) - validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name) + validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name) return logits @@ -197,8 +196,7 @@ class LogSoftmax(PrimitiveWithInfer): return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor, self.name) - validator.check_tensor_type_same({"logits": logits}, mstype.float_type, self.name) + validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name) return logits @@ -230,12 +228,12 @@ class Softplus(PrimitiveWithInfer): """Initialize Softplus""" self.init_prim_io_names(inputs=['x'], outputs=['output']) - def infer_shape(self, input_x): - return input_x + def infer_shape(self, x_shape): + return x_shape - def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) - return input_x + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid('x', x_dtype, mstype.float_type, self.name) + return x_dtype class Softsign(PrimitiveWithInfer): @@ -269,7 +267,7 @@ class Softsign(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('input_x', input_x, [mstype.float16, mstype.float32], self.name) return input_x @@ -301,7 +299,7 @@ class ReLU(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name) return input_x @@ -332,7 +330,7 @@ class ReLU6(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name) return input_x @@ -384,7 +382,7 @@ class ReLUV2(PrimitiveWithInfer): output_shape = (input_x['shape'], mask_shape) validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name) - validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid('input_x', input_dtype, mstype.number_type, self.name) mask_dtype = mstype.uint8 output_dtype = (input_dtype, mask_dtype) @@ -426,7 +424,7 @@ class Elu(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) + validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name) return input_x @@ -463,7 +461,7 @@ class HSwish(PrimitiveWithInfer): return xshape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -499,7 +497,7 @@ class Sigmoid(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) return input_x @@ -536,7 +534,7 @@ class HSigmoid(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -733,12 +731,12 @@ class FusedBatchNormEx(PrimitiveWithInfer): return (input_x, scale, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): - validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) args = {"scale": scale, "bias": bias} - validator.check_tensor_type_same(args, [mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) args_moving = {"mean": mean, "variance": variance} - valid_types = [mstype.tensor_type(mstype.float32)] - validator.check_type_same(args_moving, valid_types, self.name) + valid_dtypes = [mstype.tensor_type(mstype.float32)] + validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name) return (input_x, scale, scale, scale, scale, scale) @@ -769,7 +767,7 @@ class BNTrainingReduce(PrimitiveWithInfer): return ([x_shape[1]], [x_shape[1]]) def infer_dtype(self, x_type): - validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) return (x_type, x_type) @@ -819,6 +817,7 @@ class BNTrainingUpdate(PrimitiveWithInfer): >>> bn_training_update = P.BNTrainingUpdate() >>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance) """ + @prim_attr_register def __init__(self, isRef=True, epsilon=1e-5, factor=0.1): self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'], @@ -846,13 +845,10 @@ class BNTrainingUpdate(PrimitiveWithInfer): return (x, variance, variance, variance, variance) def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance): - validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x", "sum", "square_sum", "scale", "b", "mean", "variance"), + (x, sum, square_sum, scale, b, mean, variance))) return (x, variance, variance, variance, variance) @@ -928,16 +924,16 @@ class BatchNorm(PrimitiveWithInfer): return (input_x, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): - validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) args = {"scale": scale, "bias": bias} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) args_moving = {"mean": mean, "variance": variance} if self.is_training: - valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] - validator.check_type_same(args_moving, valid_types, self.name) + valid_dtypes = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] + validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name) else: args_moving = {"mean": mean, "variance": variance} - validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) return (input_x, scale, bias, input_x, input_x) @@ -1053,7 +1049,7 @@ class Conv2D(PrimitiveWithInfer): validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name) validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \ - Rel.EQ, self.name) + Rel.EQ, self.name) validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name) validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name) @@ -1084,24 +1080,24 @@ class Conv2D(PrimitiveWithInfer): pad_top, pad_bottom, pad_left, pad_right = self.padding h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \ - * (dilation_h - 1)) / stride_h + * (dilation_h - 1)) / stride_h w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \ - * (dilation_w - 1)) / stride_w + * (dilation_w - 1)) / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) out_channel = self.out_channel - out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else\ + out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \ [x_shape_norm[0], h_out, w_out, out_channel] _check_shape('output', out_shape, self.name) return out_shape def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): args = {'x': x_dtype, 'w': w_dtype} - valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) if x_dtype.element_type() == mstype.int8: return mstype.tensor_type(mstype.int32) return x_dtype @@ -1220,9 +1216,9 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): pad_top, pad_bottom, pad_left, pad_right = self.padding h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ - / stride_h + / stride_h w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ - / stride_w + / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) @@ -1235,7 +1231,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): args = {'x': x_dtype, 'w': w_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) if x_dtype.element_type() == mstype.int8: return mstype.tensor_type(mstype.int32) return x_dtype @@ -1436,7 +1432,7 @@ class MaxPoolWithArgmax(_Pool): def infer_dtype(self, x_dtype): out_dtype = x_dtype - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.uint16 if self.is_gpu: argmax_dtype = mstype.int32 @@ -1604,12 +1600,12 @@ class Conv2DBackpropInput(PrimitiveWithInfer): for i, dim_len in enumerate(x_size_v): validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) args = {'doutput': doutput['dtype'], 'w': w['dtype']} - valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) # infer shape dout_shape = doutput['shape'] - dout_shape_norm = dout_shape if self.format == "NCHW" else\ + dout_shape_norm = dout_shape if self.format == "NCHW" else \ [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]] kernel_h = self.kernel_size[0] kernel_w = self.kernel_size[1] @@ -1682,7 +1678,7 @@ class BiasAdd(PrimitiveWithInfer): def infer_dtype(self, x_type, b_type): args = {"input_x": x_type, "bias": b_type} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x_type @@ -1721,8 +1717,8 @@ class TopK(PrimitiveWithInfer): def __infer__(self, input_x, k): x_dtype = input_x['dtype'] - valid_types = (mstype.int32, mstype.float16, mstype.float32) - validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name) + valid_dtypes = (mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name) k_v = k['value'] validator.check_value_type('k', k_v, (int,), self.name) x_shape = list(input_x['shape']) @@ -1774,7 +1770,7 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): def infer_dtype(self, logits_type, labels_type): args = {"logits": logits_type, "labels": labels_type} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return (logits_type, logits_type) @@ -1825,8 +1821,9 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): return loss_shape def infer_dtype(self, logits_type, labels_type): - validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name) - validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name) + validator.check_tensor_dtype_valid("logits", logits_type, (mstype.float16, mstype.float32), + self.name) + validator.check_tensor_dtype_valid("labels", labels_type, (mstype.int32, mstype.int64), self.name) return logits_type @@ -1886,13 +1883,13 @@ class ApplyMomentum(PrimitiveWithInfer): return v_shape def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): - valid_types = [mstype.float16, mstype.float32, mstype.float64] + valid_dtypes = [mstype.float16, mstype.float32, mstype.float64] if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: - validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) + validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name) if not self.is_ge and self.is_tbe: return g_dtype, g_dtype return g_dtype @@ -1944,7 +1941,7 @@ class SmoothL1Loss(PrimitiveWithInfer): def infer_dtype(self, prediction, target): args = {"prediction": prediction, "target": target} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) return prediction @@ -1981,9 +1978,8 @@ class L2Loss(PrimitiveWithInfer): return loss_shape def infer_dtype(self, x_type): - validator.check_subclass("x_type", x_type, mstype.tensor, self.name) - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensor_dtype_valid('x_type', x_type, valid_dtypes, self.name) return x_type @@ -2019,11 +2015,10 @@ class DataFormatDimMap(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape - def infer_dtype(self, x_type): - validator.check_subclass("x", x_type, mstype.tensor, self.name) - valid_types = [mstype.int32] - validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) - return x_type + def infer_dtype(self, x_dtype): + valid_dtypes = [mstype.int32] + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) + return x_dtype class RNNTLoss(PrimitiveWithInfer): @@ -2065,21 +2060,18 @@ class RNNTLoss(PrimitiveWithInfer): validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name) validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name) validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) - validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) + validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2] - 1, Rel.EQ, self.name) validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) costs_shape = (acts_shape[0],) return (costs_shape, acts_shape) def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): - validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) - validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) - validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) - validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) - validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32, mstype.float16], self.name) - validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name) - validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name) - validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.int32,), prim_name=self.name), + ("labels", "input_length", "label_length"), + (labels_type, input_length_type, label_length_type))) return (acts_type, acts_type) @@ -2143,13 +2135,10 @@ class SGD(PrimitiveWithInfer): def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype): - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"), + (parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype))) return parameters_dtype @@ -2229,13 +2218,13 @@ class ApplyRMSProp(PrimitiveWithInfer): def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype, momentum_dtype, epsilon_dtype): args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype} - validator.check_type_same(args_decay, valid_types, self.name) + validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name) args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype} - validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) + validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True) if not self.is_ge and self.is_d: return var_dtype, var_dtype, var_dtype return var_dtype @@ -2332,13 +2321,13 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype} - validator.check_type_same(args_rho, valid_types, self.name) + validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name) args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} - validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) + validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True) if self.is_ascend: return var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype return var_dtype @@ -2440,8 +2429,7 @@ class L2Normalize(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("x", input_x, mstype.tensor, self.name) - validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) return input_x @@ -2527,9 +2515,9 @@ class DropoutDoMask(PrimitiveWithInfer): raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:" "{input_x_shape}, mask shape: {mask_shape}.") - validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32], - self.name) - validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name) + validator.check_tensor_dtype_valid("input_x", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32], + self.name) + validator.check_tensor_dtype_valid("input_mask", mask['dtype'], [mstype.uint8], self.name) keep_prob_v = keep_prob['value'] if keep_prob_v is not None: @@ -2587,7 +2575,8 @@ class ResizeBilinear(PrimitiveWithInfer): return out_shape def infer_dtype(self, input_dtype): - validator.check_tensor_type_same({'input_dtype': input_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32], + self.name) return mstype.tensor_type(mstype.float32) @@ -2631,10 +2620,10 @@ class OneHot(PrimitiveWithInfer): def __infer__(self, indices, depth, on_value, off_value): # check type - validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], (mstype.int32,), self.name) validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} - validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) # check shape indices_shp = indices['shape'] @@ -2685,7 +2674,7 @@ class Gelu(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) return input_x @@ -2804,9 +2793,9 @@ class PReLU(PrimitiveWithInfer): return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"input_x": input_x_dtype}, valid_types, self.name) - validator.check_tensor_type_same({"weight": weight_dtype}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name) return input_x_dtype @@ -2877,7 +2866,7 @@ class LSTM(PrimitiveWithInfer): def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} - validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) def rnd_up(self, current_offset, page_size): @@ -2930,7 +2919,7 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer): def infer_dtype(self, x_dtype, y_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return x_dtype @@ -3123,9 +3112,9 @@ class ROIAlign(PrimitiveWithInfer): return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width] def infer_dtype(self, inputs_type, rois_type): - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same({"inputs_type": inputs_type}, valid_types, self.name) - validator.check_tensor_type_same({"rois_type": rois_type}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("inputs_type", inputs_type, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("rois_type", rois_type, valid_dtypes, self.name) return inputs_type @@ -3199,6 +3188,7 @@ class Adam(PrimitiveWithInfer): >>> gradient = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) >>> result = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient) """ + @prim_attr_register def __init__(self, use_locking=False, use_nesterov=False): validator.check_value_type("use_locking", use_locking, [bool], self.name) @@ -3214,11 +3204,11 @@ class Adam(PrimitiveWithInfer): def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} - validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) + validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) return var_dtype, m_dtype, v_dtype @@ -3345,12 +3335,12 @@ class FusedSparseAdam(PrimitiveWithInfer): def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype): args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} - validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) + validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) + validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name) return var_dtype, m_dtype, v_dtype @@ -3478,13 +3468,13 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype): args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} - validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) + validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name) return var_dtype, m_dtype, v_dtype @@ -3578,8 +3568,8 @@ class FusedSparseFtrl(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} - validator.check_tensor_type_same(args, [mstype.float32], self.name) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) + validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name) return var_dtype, accum_dtype, linear_dtype @@ -3665,13 +3655,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name) - valid_types = [mstype.int16, mstype.int32, mstype.int64, - mstype.uint16, mstype.uint32, mstype.uint64] - validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float32], self.name) + valid_dtypes = [mstype.int16, mstype.int32, mstype.int64, + mstype.uint16, mstype.uint32, mstype.uint64] + validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name) return var_dtype, accum_dtype @@ -3742,8 +3732,8 @@ class KLDivLoss(PrimitiveWithInfer): def infer_dtype(self, x_type, y_type): args = {'x': x_type, 'y': y_type} - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) return x_type @@ -3820,10 +3810,10 @@ class BinaryCrossEntropy(PrimitiveWithInfer): def infer_dtype(self, x_type, y_type, weight_type): args = {'x': x_type, 'y': y_type} - valid_types = (mstype.float16, mstype.float32) - validator.check_tensor_type_same(args, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) if weight_type: - validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid({'x': x_type, 'weight': weight_type}, valid_dtypes, self.name) return x_type @@ -3950,14 +3940,14 @@ class ApplyAdaMax(PrimitiveWithInfer): def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"beta1_power": beta1_power_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"beta1": beta1_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"beta2": beta2_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name) return var_dtype, m_dtype, v_dtype @@ -4058,12 +4048,12 @@ class ApplyAdadelta(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype, epsilon_dtype, grad_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"rho": rho_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"epsilon": epsilon_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"rho": rho_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name) return var_dtype, accum_dtype, accum_update_dtype @@ -4142,9 +4132,9 @@ class ApplyAdagrad(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - valid_types = [mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, valid_dtypes, self.name) return var_dtype, accum_dtype @@ -4226,8 +4216,8 @@ class ApplyAdagradV2(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name) return var_dtype, accum_dtype @@ -4313,8 +4303,8 @@ class SparseApplyAdagrad(PrimitiveWithInfer): def infer_dtype(self, var_type, accum_type, grad_type, indices_type): args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name) return var_type, accum_type @@ -4402,8 +4392,8 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): def infer_dtype(self, var_type, accum_type, grad_type, indices_type): args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name) return var_type, accum_type @@ -4500,12 +4490,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): return var_shape, accum_shape def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name) return var_dtype, accum_dtype @@ -4594,13 +4584,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name) - valid_types = [mstype.int16, mstype.int32, mstype.int64, - mstype.uint16, mstype.uint32, mstype.uint64] - validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name) + validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name) + valid_dtypes = [mstype.int16, mstype.int32, mstype.int64, + mstype.uint16, mstype.uint32, mstype.uint64] + validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name) class ApplyAddSign(PrimitiveWithInfer): @@ -4699,13 +4689,13 @@ class ApplyAddSign(PrimitiveWithInfer): return var_shape, m_shape def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name) return var_dtype, m_dtype @@ -4808,13 +4798,13 @@ class ApplyPowerSign(PrimitiveWithInfer): return var_shape, m_shape def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"logbase": logbase_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"sign_decay": sign_decay_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"beta": beta_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name) return var_dtype, m_dtype @@ -4876,10 +4866,10 @@ class ApplyGradientDescent(PrimitiveWithInfer): return var_shape def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'delta': delta_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name) return var_dtype @@ -4959,12 +4949,12 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): return var_shape def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'delta': delta_dtype} - validator.check_tensor_type_same(args, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"alpha": alpha_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name) return var_dtype @@ -5036,11 +5026,13 @@ class LARSUpdate(PrimitiveWithInfer): weight_decay_dtype, learning_rate_dtype): args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype, "norm gradient dtype": norm_gradient_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name) - validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype}, - [mstype.float16, mstype.float32, mstype.float64], self.name) - validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype}, - [mstype.float16, mstype.float32, mstype.float64], self.name) + validator.check_tensors_dtypes_same_and_valid(args, + [mstype.float16, mstype.float32, mstype.int16, mstype.int32], + self.name) + validator.check_scalar_or_tensor_types_same({"weight_decay": weight_decay_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) + validator.check_scalar_or_tensor_types_same({"learning_rate": learning_rate_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) return weight_dtype @@ -5117,14 +5109,14 @@ class ApplyFtrl(PrimitiveWithInfer): return var_shape def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): - valid_types = [mstype.float16, mstype.float32] + valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type} - validator.check_tensor_type_same(args, valid_types, self.name) + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) - validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name) - validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name) + validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name) + validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name) if self.is_tbe: return var_type, var_type, var_type return var_type @@ -5219,8 +5211,8 @@ class SparseApplyFtrl(PrimitiveWithCheck): def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32, mstype.int64], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32, mstype.int64], self.name) class SparseApplyFtrlV2(PrimitiveWithInfer): @@ -5316,8 +5308,8 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype): args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, "linear_dtype": linear_dtype, "grad_dtype": grad_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensor_dtype_valid("indicese", indices_dtype, [mstype.int32], self.name) return var_dtype, accum_dtype, linear_dtype @@ -5351,9 +5343,8 @@ class Dropout(PrimitiveWithInfer): return x_shape, mask_shape def infer_dtype(self, x_dtype): - valid_types = (mstype.float16, mstype.float32) - validator.check_subclass("x", x_dtype, mstype.tensor, self.name) - validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) + valid_dtypes = (mstype.float16, mstype.float32) + validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) return x_dtype, x_dtype @@ -5425,10 +5416,10 @@ class CTCLoss(PrimitiveWithInfer): def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length): valid_dtype = [mstype.float16, mstype.float32, mstype.double] - validator.check_tensor_type_same({"inputs_dtype": inputs}, valid_dtype, self.name) - validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name) - validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name) - validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("inputs", inputs, valid_dtype, self.name) + validator.check_tensor_dtype_valid("labels_indices", labels_indices, [mstype.int64], self.name) + validator.check_tensor_dtype_valid("labels_values", labels_values, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("sequence_length", sequence_length, [mstype.int32], self.name) return inputs, inputs @@ -5492,8 +5483,8 @@ class CTCGreedyDecoder(PrimitiveWithInfer): return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape def infer_dtype(self, inputs_dtype, sequence_length_dtype): - validator.check_tensor_type_same({"inputs_dtype": inputs_dtype}, [mstype.float32, mstype.double], self.name) - validator.check_tensor_type_same({"sequence_length_dtype": sequence_length_dtype}, [mstype.int32], self.name) + validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name) + validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name) decoded_type = mstype.tensor_type(mstype.int64) return decoded_type, decoded_type, decoded_type, inputs_dtype @@ -5597,12 +5588,12 @@ class BasicLSTMCell(PrimitiveWithInfer): return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape) def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype): - validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name) - validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name) - + tuple(map(partial(validator.check_tensor_dtype_valid, + valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), + ("x_dtype", "h_dtype", "w_dtype"), + (x_dtype, h_dtype, w_dtype))) args = {"c_dtype": c_dtype, "b_dtype": b_dtype} - validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype) @@ -5725,11 +5716,10 @@ class DynamicRNN(PrimitiveWithInfer): return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype): - validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) - validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name) - validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name) - validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name) - validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name) + tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=mstype.float16, prim_name=self.name), + ("x", "w", "h", "c"), + (x_dtype, w_dtype, h_dtype, c_dtype))) + validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name) return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype @@ -5765,8 +5755,8 @@ class InTopK(PrimitiveWithInfer): validator.check_value_type("k", k, [int], self.name) def infer_dtype(self, x1_dtype, x2_dtype): - validator.check_tensor_type_same({"x1": x1_dtype}, (mstype.float16, mstype.float32,), self.name) - validator.check_tensor_type_same({"x2": x2_dtype}, (mstype.int32,), self.name) + validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name) + validator.check_tensor_dtype_valid("x2", x2_dtype, (mstype.int32,), self.name) return mstype.tensor_type(mstype.bool_) @@ -5803,6 +5793,7 @@ class LRN(PrimitiveWithInfer): [[0.6258911 0.4964315 ] [0.3141494 0.43636137]]]] """ + @prim_attr_register def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"): """Initialize LRN""" @@ -5816,7 +5807,7 @@ class LRN(PrimitiveWithInfer): validator.check_non_negative_int(depth_radius, "depth_radius", self.name) def infer_dtype(self, x_dtype): - validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) + validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32,), self.name) return x_dtype def infer_shape(self, x_shape): @@ -5857,6 +5848,7 @@ class UniformSampler(PrimitiveWithInfer): [3]], dtype=np.int32))) [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] """ + @prim_attr_register def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): """Initialize UniformSampler""" diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index c7fba63528..15fdc6d088 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck): def check_dtype(self, variable, value): if variable != mstype.type_refkey: - validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) - validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) + validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name) + validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name) class BoundingBoxEncode(PrimitiveWithInfer): @@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer): def infer_dtype(self, anchor_box, groundtruth_box): args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return anchor_box @@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer): def infer_dtype(self, anchor_box, deltas): args = {"anchor_box": anchor_box, "deltas": deltas} - validator.check_tensor_type_same(args, mstype.number_type, self.name) + validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) return anchor_box @@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer): def infer_dtype(self, bboxes_type, metas_type): valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] - validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name) - validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name) + validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name) + validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name) return mstype.bool_ @@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer): def infer_dtype(self, anchor_boxes, gt_boxes): valid_type = [mstype.float32, mstype.float16] - validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name) - validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name) + validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name) + validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name) return anchor_boxes @@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer): if weights is not None: validator.check_subclass('weights', weights, mstype.tensor, self.name) args = {"labels": labels, "predictions": predictions} - validator.check_tensor_type_same(args, (mstype.number_type), self.name) + validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name) return labels @@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - args = {"x": x_dtype} - validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name) + validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name) return mstype.tensor_type(mstype.uint8) class Push(PrimitiveWithInfer): diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index e05335800d..bd7670cde5 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer): Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) - Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) - Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) + Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name) + Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name) broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) out = { @@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer): Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) - Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) + Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name) broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) out = { 'shape': broadcast_shape, @@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer): Validator.check_value_type("shape", shape_v, [tuple], self.name) for i, shape_i in enumerate(shape_v): Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) - Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) - Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) + Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name) + Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name) minval_shape = minval['shape'] maxval_shape = maxval['shape'] Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) @@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): - Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) + Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) @@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer): def __infer__(self, logits, num_samples, seed): logits_dtype = logits['dtype'] - valid_types = (mstype.float32, mstype.float16, mstype.float64) - Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) + valid_dtypes = (mstype.float32, mstype.float16, mstype.float64) + Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name) num_samples_v = num_samples['value'] seed_v = seed['value'] Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) @@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer): input_shape = inputs["shape"] if len(input_shape) != 1 and len(input_shape) != 2: raise ValueError("input dim must be 1 or 2") - Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) + Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name) num_samples_value = num_samples["value"] if num_samples_value is None: raise ValueError(f"For {self.name}, shape nust be const") diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 78dd215fce..b9bd0f0259 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs): if quant_mode not in quant_mode_formats: raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') - mean = Validator.check_type("mean", mean, (int, float)) - std_dev = Validator.check_type("std_dev", std_dev, (int, float)) + mean = Validator.check_value_type("mean", mean, (int, float)) + std_dev = Validator.check_value_type("std_dev", std_dev, (int, float)) if context.get_context('device_target') not in supported_device: raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) diff --git a/tests/ut/python/ir/test_row_tensor.py b/tests/ut/python/ir/test_row_tensor.py index 1f36817077..b83f985ea8 100644 --- a/tests/ut/python/ir/test_row_tensor.py +++ b/tests/ut/python/ir/test_row_tensor.py @@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer): def __infer__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) axis_v = axis['value'] params_shp = params['shape']