Merge pull request !8114 from zhangbuxue/rectify_and_optimize_the_type_checking_functiontags/v1.1.0
| @@ -415,37 +415,20 @@ class Validator: | |||
| break | |||
| if not hit: | |||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | |||
| f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') | |||
| raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass' | |||
| f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.') | |||
| @staticmethod | |||
| def check_const_input(arg_name, arg_value, prim_name): | |||
| """Checks valid value.""" | |||
| if arg_value is None: | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | |||
| raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_type(arg_name, arg_value, valid_types): | |||
| """Type checking.""" | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.') | |||
| if isinstance(arg_value, type(mstype.tensor)): | |||
| arg_value = arg_value.element_type() | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| raise_error_msg() | |||
| if arg_value in valid_types: | |||
| return arg_value | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return arg_value | |||
| raise_error_msg() | |||
| @staticmethod | |||
| def check_type_same(args, valid_values, prim_name): | |||
| """Checks whether the types of inputs are the same.""" | |||
| def _check_tensor_type(arg): | |||
| def check_types_same_and_valid(args, valid_values, prim_name): | |||
| """Checks whether the types of inputs are the same and valid.""" | |||
| def _check_type_valid(arg): | |||
| arg_key, arg_val = arg | |||
| elem_type = arg_val | |||
| Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) | |||
| @@ -455,21 +438,27 @@ class Validator: | |||
| arg1_name, arg1_type = arg1 | |||
| arg2_name, arg2_type = arg2 | |||
| if arg1_type != arg2_type: | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') | |||
| return arg1 | |||
| elem_types = map(_check_tensor_type, args.items()) | |||
| elem_types = map(_check_type_valid, args.items()) | |||
| reduce(_check_types_same, elem_types) | |||
| @staticmethod | |||
| def check_tensor_type_same(args, valid_values, prim_name): | |||
| """Checks whether the element types of input tensors are the same.""" | |||
| tensor_types = [mstype.tensor_type(t) for t in valid_values] | |||
| Validator.check_type_same(args, tensor_types, prim_name) | |||
| def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): | |||
| """Checks whether the element types of input tensors are the same and valid.""" | |||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | |||
| Validator.check_types_same_and_valid(args, tensor_types, prim_name) | |||
| @staticmethod | |||
| def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): | |||
| """Checks whether the element types of input tensors are valid.""" | |||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | |||
| Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) | |||
| @staticmethod | |||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | |||
| def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False): | |||
| """ | |||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | |||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | |||
| @@ -480,7 +469,7 @@ class Validator: | |||
| if isinstance(arg_val, type(mstype.tensor)): | |||
| arg_val = arg_val.element_type() | |||
| if not arg_val in valid_values: | |||
| raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},' | |||
| raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},' | |||
| f' but `{arg_key}` is {arg_val}.') | |||
| return arg | |||
| @@ -512,40 +501,40 @@ class Validator: | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ for t in valid_types] | |||
| type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] | |||
| num_types = len(valid_types) | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||
| msg_prefix = f"For '{prim_name}', the" if prim_name else "The" | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||
| f'{type_names if num_types > 1 else type_names[0]}, ' | |||
| f'but got {arg_value} with type {type(arg_value).__name__}.') | |||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | |||
| # `check_value_type('x', True, [bool, int])` will check pass | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| raise_error_msg() | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return arg_value | |||
| raise_error_msg() | |||
| if not isinstance(arg_value, tuple(valid_types)): | |||
| raise_error_msg() | |||
| return arg_value | |||
| @staticmethod | |||
| def check_type_name(arg_name, arg_type, valid_types, prim_name): | |||
| """Checks whether a type in some specified types""" | |||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | |||
| def get_typename(t): | |||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] | |||
| num_types = len(valid_types) | |||
| msg_prefix = f"For '{prim_name}', the" if prim_name else "The" | |||
| raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" | |||
| f"{type_names if num_types > 1 else type_names[0]}, " | |||
| f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.") | |||
| if isinstance(arg_type, type(mstype.tensor)): | |||
| arg_type = arg_type.element_type() | |||
| if arg_type in valid_types: | |||
| return arg_type | |||
| type_names = [get_typename(t) for t in valid_types] | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||
| if len(valid_types) == 1: | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| if arg_type not in valid_types: | |||
| raise_error_msg() | |||
| return arg_type | |||
| @staticmethod | |||
| def check_reduce_shape(ori_shape, shape, axis, prim_name): | |||
| @@ -611,65 +600,6 @@ def check_output_data(data): | |||
| once = _expand_tuple(1) | |||
| twice = _expand_tuple(2) | |||
| triple = _expand_tuple(3) | |||
| valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64, | |||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | |||
| np.float32, np.float64, bool, np.bool_) | |||
| def check_type(arg_name, arg_value, valid_types): | |||
| """Check value type.""" | |||
| # if input type is Tensor ,get element type | |||
| if isinstance(arg_value, type(mstype.tensor)): | |||
| arg_value = arg_value.element_type() | |||
| # First, check if arg_value has argvalid_types | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return type(arg_value).__name__ | |||
| # Second, wrap arg_value with numpy array so that it can be checked through numpy api | |||
| if isinstance(arg_value, (list, tuple)): | |||
| arg_value = np.array(arg_value) | |||
| # Thirdly, check the data type by numpy's dtype api | |||
| valid = False | |||
| if isinstance(arg_value, np.ndarray): | |||
| valid = arg_value.dtype in valid_data_types | |||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||
| # `check_type('x', True, [bool, int])` will check pass | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| valid = False | |||
| if not valid: | |||
| type_names = [t.__name__ for t in valid_types] | |||
| if len(valid_types) == 1: | |||
| raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' | |||
| f' but got {type(arg_value).__name__}.') | |||
| raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' | |||
| f' but got {type(arg_value).__name__}.') | |||
| return type(arg_value).__name__ | |||
| def check_typename(arg_name, arg_type, valid_types): | |||
| """Check type name.""" | |||
| def get_typename(t): | |||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||
| if isinstance(arg_type, type(mstype.tensor)): | |||
| arg_type = arg_type.element_type() | |||
| if arg_type in valid_types: | |||
| return arg_type | |||
| if isinstance(arg_type, tuple(valid_types)): | |||
| return arg_type | |||
| type_names = [get_typename(t) for t in valid_types] | |||
| if len(valid_types) == 1: | |||
| raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' | |||
| f' but got {get_typename(arg_type)}.') | |||
| def args_type_check(*type_args, **type_kwargs): | |||
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| from .._c_expression import Tensor as Tensor_ | |||
| from .._c_expression import MetaTensor as MetaTensor_ | |||
| from .._checkparam import check_type, check_typename | |||
| from .._checkparam import Validator as validator | |||
| from . import dtype as mstype | |||
| from ._register_for_tensor import tensor_operator_registry | |||
| @@ -64,9 +64,19 @@ class Tensor(Tensor_): | |||
| input_data = np.array(input_data) | |||
| # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. | |||
| check_type('tensor input_data', input_data, (Tensor_, float, int)) | |||
| validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), | |||
| 'Tensor') | |||
| valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, | |||
| np.float16, np.float32, np.float64, np.bool_) | |||
| if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is " | |||
| f"{input_data.dtype} that is not supported to initialize a Tensor.") | |||
| if isinstance(input_data, (tuple, list)): | |||
| if np.array(input_data).dtype not in valid_dtypes: | |||
| raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") | |||
| if dtype is not None: | |||
| check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,)) | |||
| validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor") | |||
| if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | |||
| input_data = np.ascontiguousarray(input_data) | |||
| if dtype is None: | |||
| @@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_): | |||
| Returns: | |||
| Array, an array after being initialized. | |||
| """ | |||
| def __init__(self, dtype, shape, init=None): | |||
| #check param | |||
| # check param | |||
| self.init = init | |||
| MetaTensor_.__init__(self, dtype, shape) | |||
| @@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_): | |||
| msg = "Error shape={}".format(shape) | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| class seed_context: | |||
| '''set and restore seed''' | |||
| def __init__(self, init): | |||
| self.init = init | |||
| from .seed import get_seed | |||
| @@ -482,4 +495,5 @@ def _vm_compare(*args): | |||
| y = args[0] | |||
| return Tensor(np.array(fn(y))) | |||
| tensor_operator_registry.register('vm_compare', _vm_compare) | |||
| @@ -21,7 +21,7 @@ from ...ops import operations as P | |||
| from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ...ops.composite import multitype_ops as C | |||
| from ...ops.operations import _grad_ops as G | |||
| from ..._checkparam import Validator | |||
| from ..._checkparam import Validator as validator | |||
| from ..cell import Cell, GraphKernel | |||
| @@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel): | |||
| use_locking=False, | |||
| gradient_scale=1.0): | |||
| super(ApplyMomentum, self).__init__() | |||
| self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float]) | |||
| self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__) | |||
| self.fake_output_assign_1 = InplaceAssign() | |||
| self.fake_output_assign_1.add_prim_attr("fake_output", True) | |||
| self.fake_output_assign_2 = InplaceAssign() | |||
| @@ -334,7 +334,7 @@ class ReduceMean(GraphKernel): | |||
| def __init__(self, keep_dims=True): | |||
| super(ReduceMean, self).__init__() | |||
| self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool]) | |||
| self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__) | |||
| self.sum = P.ReduceSum(self.keep_dims) | |||
| def construct(self, x, axis): | |||
| @@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel): | |||
| """ Forward function of the LayerNorm operator. """ | |||
| def __init__(self, begin_norm_axis=1, begin_params_axis=1): | |||
| super(LayerNormForward, self).__init__() | |||
| self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||
| self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||
| self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], | |||
| type(self).__name__) | |||
| self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], | |||
| type(self).__name__) | |||
| self.mul = P.Mul() | |||
| self.sum_keep_dims = P.ReduceSum(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| @@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel): | |||
| def __init__(self, axis=-1): | |||
| super(LogSoftmax, self).__init__() | |||
| self.axis = Validator.check_type('axis', axis, [int]) | |||
| self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__) | |||
| self.max_keep_dims = P.ReduceMax(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| self.exp = P.Exp() | |||
| @@ -952,13 +954,13 @@ class Softmax(GraphKernel): | |||
| def __init__(self, axis): | |||
| super(Softmax, self).__init__() | |||
| Validator.check_type("axis", axis, [int, tuple]) | |||
| validator.check_value_type("axis", axis, [int, tuple], type(self).__name__) | |||
| if isinstance(axis, int): | |||
| self.axis = (axis,) | |||
| else: | |||
| self.axis = axis | |||
| for item in self.axis: | |||
| Validator.check_type("item of axis", item, [int]) | |||
| validator.check_value_type("item of axis", item, [int], type(self).__name__) | |||
| self.max = P.ReduceMax(keep_dims=True) | |||
| self.sub = P.Sub() | |||
| self.exp = P.Exp() | |||
| @@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops.primitive import constexpr | |||
| import mindspore.context as context | |||
| from mindspore._checkparam import Validator, check_typename | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._extends import cell_attr_register | |||
| from mindspore.communication.management import get_group_size, get_rank | |||
| from mindspore.communication import management | |||
| @@ -52,7 +52,7 @@ class _BatchNorm(Cell): | |||
| if momentum < 0 or momentum > 1: | |||
| raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) | |||
| self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | |||
| raise ValueError("NHWC format only support in GPU target.") | |||
| self.use_batch_statistics = use_batch_statistics | |||
| @@ -67,7 +67,7 @@ class _BatchNorm(Cell): | |||
| gamma_init, num_features), name="gamma", requires_grad=affine) | |||
| self.beta = Parameter(initializer( | |||
| beta_init, num_features), name="beta", requires_grad=affine) | |||
| self.group = Validator.check_positive_int(device_num_each_group) | |||
| self.group = validator.check_positive_int(device_num_each_group) | |||
| self.is_global = False | |||
| if self.group != 1: | |||
| self.rank_id = get_rank() | |||
| @@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm): | |||
| use_batch_statistics, | |||
| device_num_each_group, | |||
| input_dims='both') | |||
| self.group = Validator.check_positive_int(device_num_each_group) | |||
| self.group = validator.check_positive_int(device_num_each_group) | |||
| if self.group <= 1: | |||
| raise ValueError("the number of group must be greater than 1.") | |||
| @@ -607,12 +607,12 @@ class GroupNorm(Cell): | |||
| def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): | |||
| super(GroupNorm, self).__init__() | |||
| self.num_groups = Validator.check_positive_int(num_groups) | |||
| self.num_channels = Validator.check_positive_int(num_channels) | |||
| self.num_groups = validator.check_positive_int(num_groups) | |||
| self.num_channels = validator.check_positive_int(num_channels) | |||
| if num_channels % num_groups != 0: | |||
| raise ValueError("num_channels should be divided by num_groups") | |||
| self.eps = check_typename('eps', eps, (float,)) | |||
| self.affine = Validator.check_bool(affine) | |||
| self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__) | |||
| self.affine = validator.check_bool(affine) | |||
| gamma = initializer(gamma_init, num_channels) | |||
| beta = initializer(beta_init, num_channels) | |||
| @@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): | |||
| super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, | |||
| symmetric=symmetric, narrow_range=narrow_range, | |||
| num_channels=num_channels) | |||
| Validator.check_type("min_init", min_init, [int, float]) | |||
| Validator.check_type("max_init", max_init, [int, float]) | |||
| Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__) | |||
| Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__) | |||
| Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | |||
| Validator.check_non_negative_int(quant_delay, 'quant_delay') | |||
| self.min_init = min_init | |||
| @@ -68,7 +68,7 @@ class GumbelCDF(Bijector): | |||
| """ | |||
| param = dict(locals()) | |||
| valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | |||
| super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | |||
| @@ -119,7 +119,7 @@ class Bernoulli(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'probs': probs} | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | |||
| self._probs = self._add_parameter(probs, 'probs') | |||
| @@ -109,7 +109,7 @@ class Categorical(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'probs': probs} | |||
| valid_dtype = mstype.int_type | |||
| Validator.check_type("Categorical", dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Categorical, self).__init__(seed, dtype, name, param) | |||
| self._probs = self._add_parameter(probs, 'probs') | |||
| @@ -121,7 +121,7 @@ class Exponential(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'rate': rate} | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Exponential, self).__init__(seed, dtype, name, param) | |||
| self._rate = self._add_parameter(rate, 'rate') | |||
| @@ -122,7 +122,7 @@ class Geometric(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'probs': probs} | |||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Geometric, self).__init__(seed, dtype, name, param) | |||
| self._probs = self._add_parameter(probs, 'probs') | |||
| @@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution): | |||
| Constructor of Gumbel distribution. | |||
| """ | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | |||
| super(Gumbel, self).__init__( | |||
| distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | |||
| @@ -111,7 +111,7 @@ class Logistic(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'loc': loc, 'scale': scale} | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Logistic, self).__init__(seed, dtype, name, param) | |||
| self._loc = self._add_parameter(loc, 'loc') | |||
| @@ -127,7 +127,7 @@ class Normal(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'mean': mean, 'sd': sd} | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Normal, self).__init__(seed, dtype, name, param) | |||
| self._mean_value = self._add_parameter(mean, 'mean') | |||
| @@ -126,7 +126,7 @@ class Uniform(Distribution): | |||
| param = dict(locals()) | |||
| param['param_dict'] = {'low': low, 'high': high} | |||
| valid_dtype = mstype.float_type | |||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||
| super(Uniform, self).__init__(seed, dtype, name, param) | |||
| self._low = self._add_parameter(low, 'low') | |||
| @@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer): | |||
| return [1] | |||
| def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | |||
| args = {"indices": indices_dtype} | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||
| return input_x_dtype | |||
| @@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer): | |||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | |||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||
| out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | |||
| return out_dtype | |||
| @@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||
| return out_shape | |||
| def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | |||
| args = {"miss_emb_idx": miss_emb_idx_dtype} | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||
| out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | |||
| return out_dtype | |||
| @@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer): | |||
| return miss_value_shape | |||
| def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | |||
| args = {"swap_cache_idx": swap_cache_idx_dtype} | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||
| return miss_value_dtype | |||
| @@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer): | |||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | |||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||
| out_dtype = (hashmap_dtype, hashmap_dtype, | |||
| hashmap_dtype, hashmap_dtype) | |||
| return out_dtype | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Operators for gradients.""" | |||
| from functools import partial | |||
| from .. import signature as sig | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| @@ -23,6 +24,7 @@ from ...common import dtype as mstype | |||
| from .. import functional as F | |||
| from ... import context | |||
| class AbsGrad(PrimitiveWithInfer): | |||
| """Computes gradients for abs operation.""" | |||
| @@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x, dout): | |||
| args = {"x": x, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x | |||
| @@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x, dout): | |||
| args = {"x": x, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x | |||
| @@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x, dout): | |||
| args = {"x": x, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x | |||
| @@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x, dout): | |||
| args = {"x": x, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x | |||
| @@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| @@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], | |||
| self.name) | |||
| return x_dtype | |||
| @@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| @@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, dout_dtype): | |||
| args = {"x": x_dtype, "dout": dout_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| @@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, y_type, doutput_type): | |||
| args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type, y_type | |||
| @@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, y_type, doutput_type, weight_type): | |||
| args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| if weight_type: | |||
| validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) | |||
| return x_type | |||
| @@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): | |||
| for i, dim_len in enumerate(w_size_v): | |||
| validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) | |||
| args = {"x": x['dtype'], "doutput": doutput['dtype']} | |||
| validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], | |||
| self.name) | |||
| out = { | |||
| 'value': None, | |||
| 'shape': w_size_v, | |||
| @@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): | |||
| def __infer__(self, x, w_size, dout): | |||
| w_size_v = w_size['value'] | |||
| args = {'x': x['dtype'], 'dout': dout['dtype']} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| out = { | |||
| 'value': None, | |||
| 'shape': w_size_v, | |||
| @@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): | |||
| def __infer__(self, x_size, w, dout): | |||
| args = {'w': w['dtype'], 'dout': dout['dtype']} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| x_size_v = x_size['value'] | |||
| out = { | |||
| 'value': None, | |||
| @@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer): | |||
| return dy_shape | |||
| def infer_dtype(self, dy_dtype, mask_dtype): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name) | |||
| validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name) | |||
| return dy_dtype | |||
| @@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): | |||
| validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("y_backprop", "x", "y"), | |||
| (y_backprop_dtype, x_dtype, y_dtype))) | |||
| return x_dtype | |||
| @@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad): | |||
| def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): | |||
| args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) | |||
| return x1_dtype | |||
| @@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad): | |||
| def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): | |||
| args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) | |||
| return grad_dtype | |||
| @@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, input_x, out, dout): | |||
| args = {'input_x': input_x, 'out': out, 'dout': dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return input_x | |||
| @@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer): | |||
| def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, | |||
| hx_dtype, cx_dtype, reserve_dtype, state_dtype): | |||
| args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) | |||
| return (dy_dtype, dy_dtype, dy_dtype) | |||
| @@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer): | |||
| args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, | |||
| "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, | |||
| "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} | |||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name) | |||
| validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name) | |||
| validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name) | |||
| if seq_dtype is not None: | |||
| validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name) | |||
| if mask_dtype is not None: | |||
| validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name) | |||
| validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name) | |||
| return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype | |||
| @@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer): | |||
| return y_backprop_shape, w_shape | |||
| def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ('y_backprop', "input_x", "weight"), | |||
| (y_backprop_dtype, A_dtype, w_dtype))) | |||
| return y_backprop_dtype, w_dtype | |||
| @@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer): | |||
| return gradients_shape | |||
| def infer_dtype(self, gradients_dtype, mask_dtype): | |||
| validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) | |||
| validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name) | |||
| return gradients_dtype | |||
| @@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| args = {'y_grad': y_grad_dtype, 'x': x_dtype} | |||
| validator.check_tensor_type_same(args, mstype.float_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) | |||
| return x_dtype | |||
| @@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, out, dout): | |||
| args = {'out': out, 'dout': dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return out | |||
| @@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, y_grad_dtype, x_dtype): | |||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, y_dtype, dout_dtype): | |||
| args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return dout_dtype | |||
| @@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, prediction, target, dloss): | |||
| args = {"prediction": prediction, "target": target, 'dloss': dloss} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return dloss | |||
| @@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) | |||
| def __infer__(self, dy, shapex, begin, end, strides): | |||
| args = {"dy": dy['dtype']} | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name) | |||
| for idx, item in enumerate(shapex['value']): | |||
| validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) | |||
| @@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_dtype, x_dtype): | |||
| args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype} | |||
| validator.check_tensor_type_same(args, mstype.float_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) | |||
| return x_dtype | |||
| @@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, out, dout): | |||
| args = {"out": out, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return out | |||
| @@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, x, dout): | |||
| args = {"x": x, "dout": dout} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x | |||
| @@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, grads, x, y): | |||
| args = {"grads": grads, "x": x, "y": y} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name) | |||
| return x | |||
| def infer_shape(self, grads, x, y): | |||
| @@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, ksizes, strides, rates, padding="valid"): | |||
| """init""" | |||
| def _check_tuple_or_list(arg_name, arg_val, prim_name): | |||
| validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) | |||
| if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: | |||
| @@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| def infer_dtype(self, input_x): | |||
| """infer dtype""" | |||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name) | |||
| return input_x | |||
| @@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) | |||
| return x_dtype | |||
| @@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer): | |||
| >>> dequant = P.Dequant(False, False) | |||
| >>> y = dequant(input_x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, sqrt_mode=False, relu_flag=False): | |||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | |||
| @@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer): | |||
| return assist | |||
| def infer_dtype(self, assist, start, stop, num): | |||
| args = {"num": num} | |||
| validator.check_tensor_type_same(args, (mstype.int32,), self.name) | |||
| validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name) | |||
| args = {"assist": assist, "start": start, "stop": stop} | |||
| validator.check_tensor_type_same(args, (mstype.float32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name) | |||
| return assist | |||
| @@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, assist_dtype): | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||
| args = {"x": x_dtype, "assist": assist_dtype} | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, assist_shape): | |||
| validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) | |||
| validator.check('rank of x', len(x_shape)+1, | |||
| validator.check('rank of x', len(x_shape) + 1, | |||
| 'rank of assist', len(assist_shape), Rel.LE, self.name) | |||
| validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', | |||
| assist_shape[-1], Rel.EQ, self.name) | |||
| @@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer): | |||
| while r_idx >= r_end_dim: | |||
| if x_shape[r_idx] != 1: | |||
| validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % | |||
| assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name) | |||
| assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name) | |||
| r_idx = r_idx - 1 | |||
| return assist_shape | |||
| @@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, assist_dtype): | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||
| args = {"x": x_dtype, "assist": assist_dtype} | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, assist_shape): | |||
| @@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | |||
| args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, diagonal_shape, assist_shape): | |||
| @@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer): | |||
| return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape | |||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | |||
| validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) | |||
| validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) | |||
| b_dtype = mstype.float32 | |||
| if binput_dtype is not None: | |||
| validator.check_tensor_type_same({"bias input dtype": binput_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = binput_dtype | |||
| elif bhidden_dtype is not None: | |||
| validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = bhidden_dtype | |||
| elif h_dtype is not None: | |||
| validator.check_tensor_type_same({"init_h dtype": h_dtype}, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("init_h dtype", h_dtype, | |||
| (mstype.float16, mstype.float32), self.name) | |||
| b_dtype = h_dtype | |||
| return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Operators for quantization.""" | |||
| from functools import partial | |||
| import mindspore.context as context | |||
| from ..._checkparam import Validator as validator | |||
| @@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return min_type, max_type | |||
| @@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same( | |||
| {"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return min_type, max_type | |||
| @@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): | |||
| >>> input_tensor, min_tensor, max_tensor) | |||
| >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| num_bits=8, | |||
| @@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return x_type | |||
| @@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): | |||
| >>> min_gradient shape: (1,) data type: mstype.float32 | |||
| >>> max_gradient shape: (1,) data type: mstype.float32 | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| num_bits=8, | |||
| @@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): | |||
| return x_shape, min_shape, max_shape | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ('dout', "x", "min", "max"), | |||
| (dout_type, x_type, min_type, max_type))) | |||
| return x_type, min_type, max_type | |||
| @@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): | |||
| >>> input_tensor, min_tensor, max_tensor) | |||
| >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| num_bits=8, | |||
| @@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return x_type | |||
| @@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): | |||
| >>> min_gradient shape: (4,) data type: mstype.float32 | |||
| >>> max_gradient shape: (4,) data type: mstype.float32 | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| num_bits=8, | |||
| @@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): | |||
| return x_shape, min_shape, max_shape | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("dout", "x", "min", "max"), | |||
| (dout_type, x_type, min_type, max_type))) | |||
| return x_type, min_type, max_type | |||
| @@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| if context.get_context('device_target') == "GPU": | |||
| valid_types = (mstype.float32,) | |||
| valid_dtypes = (mstype.float32,) | |||
| else: | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return x_type | |||
| @@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| if context.get_context('device_target') == "GPU": | |||
| valid_types = (mstype.float32,) | |||
| valid_dtypes = (mstype.float32,) | |||
| else: | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same( | |||
| {"dout": dout_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||
| ("dout", "x", "min", "max"), | |||
| (dout_type, x_type, min_type, max_type))) | |||
| return dout_type | |||
| @@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| if context.get_context('device_target') == "GPU": | |||
| valid_types = (mstype.float32,) | |||
| valid_dtypes = (mstype.float32,) | |||
| else: | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||
| ("x", "min", "max"), | |||
| (x_type, min_type, max_type))) | |||
| return x_type | |||
| @@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | |||
| if context.get_context('device_target') == "GPU": | |||
| valid_types = (mstype.float32,) | |||
| valid_dtypes = (mstype.float32,) | |||
| else: | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same( | |||
| {"dout": dout_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||
| ("dout", "x", "min", "max"), | |||
| (dout_type, x_type, min_type, max_type))) | |||
| return dout_type | |||
| @@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| validator.check("input type", x_type, "mean type", mean_type) | |||
| validator.check("input type", x_type, "variance type", variance_type) | |||
| args = {"x": x_type, "mean": mean_type, "variance": variance_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||
| return x_type, x_type, x_type, x_type | |||
| @@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| global_step_type): | |||
| args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, | |||
| "batch_mean": batch_mean_type, "batch_std": batch_std_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||
| return x_type | |||
| @@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, batch_std_type, running_std_type): | |||
| args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type | |||
| @@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): | |||
| args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| if context.get_context('device_target') == "Ascend": | |||
| return x_type, x_type | |||
| return x_type, gamma_type | |||
| @@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer): | |||
| running_mean_type, global_step_type): | |||
| args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | |||
| "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||
| return x_type | |||
| @@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||
| "dout type", dout_type) | |||
| args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | |||
| "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||
| return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type | |||
| @@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer): | |||
| validator.check("input type", x_type, "mean type", mean_type) | |||
| validator.check("input type", x_type, "variance type", variance_type) | |||
| args = {"x": x_type, "mean": mean_type, "variance": variance_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type, x_type, x_type, x_type, x_type, x_type, x_type | |||
| @@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer): | |||
| validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) | |||
| validator.check("input type", x_type, "batch_mean type", batch_mean_type) | |||
| validator.check("input type", x_type, "batch_std type", batch_std_type) | |||
| args = {"input type": x_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name) | |||
| return x_type | |||
| @@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): | |||
| args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | |||
| "beta": beta_type, "gamma": gamma_type, "x": x_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| return x_type | |||
| @@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer): | |||
| "dout type", dout_type) | |||
| args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | |||
| "running_std": running_std_type, "dout": dout_type} | |||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||
| return gamma_type, gamma_type, gamma_type, gamma_type | |||
| @@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer): | |||
| def infer_shape(self, data1_shape): | |||
| ll = [] | |||
| if len(data1_shape) == 2: | |||
| ll = [1,] | |||
| ll = [1] | |||
| else: | |||
| ll = [32, 64] | |||
| return ll | |||
| @@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer): | |||
| >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) | |||
| >>> output = img2col(input_x) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, | |||
| kernel_size, | |||
| @@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer): | |||
| return out_shape | |||
| def infer_dtype(self, x_dtype): | |||
| args = {'x': x_dtype} | |||
| valid_types = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer): | |||
| return x2_shape | |||
| def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): | |||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, | |||
| [mstype.float32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid( | |||
| {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, | |||
| [mstype.float32], self.name) | |||
| return x2_dtype | |||
| class Cholesky(PrimitiveWithInfer): | |||
| """ | |||
| Inner API for resnet50 THOR GPU backend | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, split_dim=0): | |||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||
| @@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer): | |||
| return out_shape | |||
| def infer_dtype(self, x1_dtype): | |||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) | |||
| return x1_dtype | |||
| class DetTriangle(PrimitiveWithInfer): | |||
| """ | |||
| Calculate the determinant of triangle matrices | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, fill_mode=0): | |||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | |||
| @@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer): | |||
| return out_shape | |||
| def infer_dtype(self, x1_dtype): | |||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) | |||
| return x1_dtype | |||
| @@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "updates": updates_dtype} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp): | |||
| """ | |||
| Defines _ScatterNd operators | |||
| """ | |||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | |||
| validator.check('the dimension of x', len(x_shape), | |||
| 'the dimension of indices', indices_shape[-1], Rel.GE) | |||
| @@ -627,6 +628,7 @@ class Unique(Primitive): | |||
| >>> out = P.Unique()(x) | |||
| (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| @@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck): | |||
| def __init__(self): | |||
| """Initialize index_select""" | |||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | |||
| self.add_prim_attr("dynamic_shape_depends", [2,]) | |||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||
| def __check__(self, params, indices, axis): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | |||
| axis_v = axis['value'] | |||
| params_shp = params['shape'] | |||
| @@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer): | |||
| >>> out = P.Padding(pad_dim_size)(x) | |||
| [[8, 0, 0, 0], [10, 0, 0, 0]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, pad_dim_size=8): | |||
| """Initialize padding""" | |||
| @@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer): | |||
| >>> out = P.UniqueWithPad()(x, pad_num) | |||
| ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """init UniqueWithPad""" | |||
| def __infer__(self, x, pad_num): | |||
| validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name) | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name) | |||
| validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) | |||
| x_shape = list(x['shape']) | |||
| validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) | |||
| @@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer): | |||
| def __init__(self, seed=0, dtype=mstype.float32): | |||
| """Initialize TruncatedNormal""" | |||
| validator.check_value_type('seed', seed, [int], self.name) | |||
| validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name) | |||
| validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name) | |||
| def __infer__(self, shape): | |||
| shape_value = shape['value'] | |||
| @@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer): | |||
| validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) | |||
| for i, item in enumerate(dims['value']): | |||
| validator.check_positive_int(item, f'dims[{i}]', self.name) | |||
| valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||
| mstype.uint8, mstype.uint32, mstype.uint64, | |||
| mstype.float16, mstype.float32, mstype.float64] | |||
| validator.check_type_same({"value": dtype['value']}, valid_types, self.name) | |||
| valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||
| mstype.uint8, mstype.uint32, mstype.uint64, | |||
| mstype.float16, mstype.float32, mstype.float64] | |||
| validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name) | |||
| x_nptype = mstype.dtype_to_nptype(dtype['value']) | |||
| ret = np.full(dims['value'], x['value'], x_nptype) | |||
| out = { | |||
| @@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||
| return x_dtype | |||
| @@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||
| return x_dtype | |||
| @@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer): | |||
| """Initialize Argmax""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| validator.check_value_type("axis", axis, [int], self.name) | |||
| validator.check_type_same({'output': output_type}, [mstype.int32], self.name) | |||
| validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name) | |||
| self.axis = axis | |||
| self.add_prim_attr('output_type', output_type) | |||
| @@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| """Initialize UnsortedSegmentSum""" | |||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | |||
| self.add_prim_attr("dynamic_shape_depends", [2,]) | |||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||
| def __infer__(self, x, segment_ids, num_segments): | |||
| x_type = x['dtype'] | |||
| @@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||
| num_segments_type = num_segments['dtype'] | |||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | |||
| if isinstance(num_segments_type, type(mstype.tensor)): | |||
| validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name) | |||
| shp = [-1] | |||
| else: | |||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | |||
| @@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer): | |||
| x_shape = x['shape'] | |||
| segment_ids_shape = segment_ids['shape'] | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| @@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer): | |||
| x_shape = x['shape'] | |||
| segment_ids_shape = segment_ids['shape'] | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| @@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer): | |||
| validator.check_subclass("input_x", x_type, mstype.tensor, self.name) | |||
| validator.check_value_type("x_shape", x_shape, [list], self.name) | |||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) | |||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | |||
| validator.check(f'first shape of input_x', x_shape[0], | |||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | |||
| @@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||
| validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) | |||
| args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| first_elem = x_shp[0] | |||
| for i, elem in enumerate(x_shp[1:]): | |||
| @@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| @@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer): | |||
| self.add_prim_attr('T', x_type) | |||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | |||
| validator.check_subclass("y_type", y_type, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name) | |||
| validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) | |||
| if x_type != y_type: | |||
| raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) | |||
| return x_type | |||
| @@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer): | |||
| validator.check_positive_int(n, "n", self.name) | |||
| validator.check_positive_int(m, "m", self.name) | |||
| args = {"dtype": t} | |||
| validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| np_type = mstype.dtype_to_nptype(t) | |||
| ret = np.eye(n, m, dtype=np_type) | |||
| return Tensor(ret) | |||
| @@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer): | |||
| def __infer__(self, indices, update, shape): | |||
| shp = shape['value'] | |||
| validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name) | |||
| validator.check_value_type("shape", shp, [tuple], self.name) | |||
| for i, x in enumerate(shp): | |||
| validator.check_positive_int(x, f'shape[{i}]', self.name) | |||
| @@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||
| validator.check_non_negative_int(value, f'{i}th value of size', self.name) | |||
| self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | |||
| def infer_shape(self, x): | |||
| validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) | |||
| return tuple(x)[:-2] + tuple(self.size) | |||
| def infer_shape(self, x_shape): | |||
| validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name) | |||
| return tuple(x_shape)[:-2] + tuple(self.size) | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x", x, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name) | |||
| return x | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| class GatherNd(PrimitiveWithInfer): | |||
| @@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer): | |||
| return indices_shape[:-1] + x_shape[indices_shape[-1]:] | |||
| def infer_dtype(self, x_dtype, indices_dtype): | |||
| validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||
| return x_dtype | |||
| @@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "value": value_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp): | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "value": value_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp): | |||
| [0.4 2.2 -3.2]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, use_locking=True): | |||
| """Initialize ScatterNdUpdate""" | |||
| @@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp): | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "value": value_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp): | |||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | |||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | |||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||
| args = {"x": x_dtype, "updates": updates_dtype} | |||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name) | |||
| return x_dtype | |||
| @@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||
| self.paddings = paddings | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| @@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer): | |||
| self.crops = crops | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| @@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer): | |||
| self.add_prim_attr("paddings", paddings_append) | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| @@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer): | |||
| self.add_prim_attr("crops", crops_append) | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape): | |||
| @@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer): | |||
| out_shape = tuple(tuple(shape_0) for _ in range(n)) | |||
| return out_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) | |||
| n = len(x_type) | |||
| @@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer): | |||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) | |||
| return x_type | |||
| class InplaceUpdate(PrimitiveWithInfer): | |||
| r""" | |||
| Updates specified rows with values in `v`. | |||
| @@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, v_dtype): | |||
| args = {'x': x_dtype, 'v': v_dtype} | |||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, v_shape): | |||
| @@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x, seq_lengths): | |||
| validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) | |||
| validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name) | |||
| return x | |||
| @@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer): | |||
| validator.check_const_input('truth_shape', truth_shape['value'], self.name) | |||
| args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], | |||
| "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} | |||
| validator.check_tensor_type_same(args_int, [mstype.int64], self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name) | |||
| args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] | |||
| validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) | |||
| @@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer): | |||
| Outputs: | |||
| Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.__setattr_flag__ = True | |||
| @@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer): | |||
| def __infer__(self, x, shape): | |||
| shp = shape['value'] | |||
| dtype = x['dtype'] | |||
| validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name) | |||
| self.add_prim_attr('out_shape', tuple(shp)) | |||
| return {'shape': shp, | |||
| 'dtype': dtype, | |||
| @@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer): | |||
| return x_shape, x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name) | |||
| return x_dtype, mstype.tensor_type(mstype.int32) | |||
| @@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) | |||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize index_select""" | |||
| @@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||
| def __infer__(self, params, indices, offset): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | |||
| params_shp = params['shape'] | |||
| if len(params_shp) != 2: | |||
| @@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer): | |||
| >>> out = P.GatherD()(x, dim, index) | |||
| [[1, 1], [4, 3]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize GatherD""" | |||
| @@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer): | |||
| def __infer__(self, x, dim, index): | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name) | |||
| validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name) | |||
| validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) | |||
| x_shp = x['shape'] | |||
| idx_shp = index['shape'] | |||
| @@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer): | |||
| >>> y = P.Identity()(x) | |||
| [1, 2, 3, 4] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize identity""" | |||
| @@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| @@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| @@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| @@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| @@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer): | |||
| if not isinstance(x_dtype, tuple): | |||
| raise TypeError(f"{self.name}'s input should be a tuple!") | |||
| for _ele in x_dtype: | |||
| validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name) | |||
| return x_dtype | |||
| @@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||
| return x_dtype | |||
| def __call__(self, tensor): | |||
| @@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer): | |||
| def infer_dtype(self, data_type, pred_type): | |||
| validator.check_subclass( | |||
| "data", data_type, (mstype.tensor,) + mstype.number_type, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"pred": pred_type}, [mstype.bool_], self.name) | |||
| validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name) | |||
| return (data_type, data_type) | |||
| @@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer): | |||
| for i, item in enumerate(inputs): | |||
| args['inputs[%d]' % i] = item | |||
| validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||
| return (inputs[0], mstype.int32) | |||
| @@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer): | |||
| return [1] | |||
| def infer_dtype(self, condition, inputs): | |||
| validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name) | |||
| validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name) | |||
| for dtype in inputs: | |||
| validator.check_subclass("input", dtype, [mstype.tensor], self.name) | |||
| return mstype.int32 | |||
| @@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer): | |||
| box_index_dtype = box_index['dtype'] | |||
| crop_size_dtype = crop_size['dtype'] | |||
| # check dytpe | |||
| validator.check_tensor_type_same({"x": x_dtype}, | |||
| [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, | |||
| mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) | |||
| validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name) | |||
| validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, | |||
| [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, | |||
| mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) | |||
| validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name) | |||
| validator.check_value_type("crop_size", crop_size_value, [tuple], self.name) | |||
| # check input shape rank | |||
| validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) | |||
| @@ -16,6 +16,8 @@ | |||
| """Operators for math.""" | |||
| import copy | |||
| from functools import partial | |||
| import numpy as np | |||
| from ... import context | |||
| from .. import signature as sig | |||
| @@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp): | |||
| @staticmethod | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): | |||
| args_type = {"x": x_dtype, "y": y_dtype} | |||
| validator.check_tensor_type_same(args_type, valid_dtype, prim_name) | |||
| validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name) | |||
| return x_dtype | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| @@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp): | |||
| @staticmethod | |||
| def _check_bitwise_op_input_type(x1_type, x2_type, prim): | |||
| args = {'x1': x1_type, 'x2': x2_type} | |||
| valid_types = mstype.int_type + mstype.uint_type | |||
| validator.check_tensor_type_same(args, valid_types, prim) | |||
| valid_dtypes = mstype.int_type + mstype.uint_type | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim) | |||
| return x1_type | |||
| def infer_dtype(self, x1_type, x2_type): | |||
| @@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer): | |||
| def infer_dtype(self, variable, value): | |||
| args = {"variable": variable, "value": value} | |||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) | |||
| return value | |||
| @@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer): | |||
| def infer_dtype(self, variable, value): | |||
| args = {"variable": variable, "value": value} | |||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) | |||
| return value | |||
| @@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer): | |||
| axis_v = axis['value'] | |||
| input_shp = input_x['shape'] | |||
| args = {'input_x': input_x['dtype']} | |||
| validator.check_tensor_type_same(args, valid_dtype, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name) | |||
| if axis_v is None: | |||
| raise ValueError(f"For {self.name}, axis must be const.") | |||
| @@ -504,6 +506,7 @@ class ReduceMax(_Reduce): | |||
| def __infer__(self, input_x, axis): | |||
| return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,)) | |||
| class ReduceMin(_Reduce): | |||
| """ | |||
| Reduce a dimension of a tensor by the minimum value in the dimension. | |||
| @@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_type, axis_type): | |||
| cls_name = self.name | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) | |||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name) | |||
| validator.check_subclass("axis", axis_type, mstype.int_, cls_name) | |||
| return x_type | |||
| @@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer): | |||
| def infer_dtype(self, x1, x2): | |||
| args = {"x1": x1, "x2": x2} | |||
| validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name) | |||
| if x1.element_type() == mstype.int8: | |||
| return mstype.tensor_type(mstype.int32) | |||
| return x1 | |||
| @@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer): | |||
| self.axes = axes | |||
| validator.check_value_type('axes', axes, [int, tuple, list], self.name) | |||
| if not isinstance(self.axes, int): | |||
| self.axes = list(self.axes) # to avoid immutability issues | |||
| self.axes = list(self.axes) # to avoid immutability issues | |||
| if len(self.axes) != 2: | |||
| raise ValueError("Require two axes inputs, given less") | |||
| self.int_to_tuple_conv() # convert before length checks | |||
| self.int_to_tuple_conv() # convert before length checks | |||
| if len(self.axes[0]) != len(self.axes[1]): | |||
| raise ValueError("Axes have to be the same size/length") | |||
| if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): | |||
| @@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer): | |||
| if isinstance(self.axes, int): | |||
| if self.axes <= 0: | |||
| # outer product, no input validation required | |||
| self.axes = ([], []) # no axes selected for either | |||
| self.axes = ([], []) # no axes selected for either | |||
| return | |||
| if self.axes > len(x1_shape) or self.axes > len(x2_shape): | |||
| raise ValueError( | |||
| @@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer): | |||
| def infer_dtype(self, x1, x2): | |||
| args = {"x1": x1, "x2": x2} | |||
| valid_types = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||
| return x1 | |||
| @@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer): | |||
| if axis['value'] is None: | |||
| raise ValueError(f"For {self.name}, axis must be const.") | |||
| validator.check_value_type('axis', axis['value'], [int], cls_name) | |||
| valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | |||
| valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name) | |||
| return {'shape': x_shp, | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| @@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer): | |||
| if dtype == mstype.undetermined: | |||
| contains_undetermined = True | |||
| if not contains_undetermined: | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| return inputs[0] | |||
| def infer_value(self, inputs): | |||
| @@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer): | |||
| args = {} | |||
| for i, dtype in enumerate(inputs): | |||
| args[f"inputs[{i}]"] = dtype | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| return inputs[0] | |||
| @@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer): | |||
| """Initialize Neg""" | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| def infer_shape(self, input_x): | |||
| return input_x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, input_x): | |||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | |||
| return input_x | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_value(self, input_x): | |||
| if input_x is not None: | |||
| @@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, v_dtype): | |||
| args = {'x': x_dtype, 'v': v_dtype} | |||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, v_shape): | |||
| @@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, v_dtype): | |||
| args = {'x': x_dtype, 'v': v_dtype} | |||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||
| return x_dtype | |||
| def infer_shape(self, x_shape, v_shape): | |||
| @@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||
| return x_type | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_value(self, x): | |||
| if x is not None: | |||
| @@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||
| return x_type | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| def infer_value(self, x): | |||
| if x is not None: | |||
| @@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | |||
| def check_dtype(self, x_type): | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name) | |||
| def infer_value(self, x): | |||
| if x is not None: | |||
| @@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) | |||
| return x_type | |||
| @@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer): | |||
| return (self.nbins,) | |||
| def infer_dtype(self, x_dtype, range_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||
| valid_types = (mstype.float16, mstype.float32, mstype.int32) | |||
| validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float16, mstype.float32, mstype.int32) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||
| validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name) | |||
| y_dtype = mstype.int32 | |||
| return y_dtype | |||
| @@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | |||
| def infer_shape(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x", x, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name) | |||
| return x | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| class Erf(PrimitiveWithInfer): | |||
| @@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||
| return x_type | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| class Erfc(PrimitiveWithInfer): | |||
| @@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) | |||
| return x_type | |||
| @@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name) | |||
| return x_dtype | |||
| @@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return x_dtype | |||
| @@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp): | |||
| @staticmethod | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): | |||
| args_dtype = {"x": x_dtype, "y": y_dtype} | |||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name) | |||
| validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name) | |||
| return mstype.tensor_type(mstype.bool_) | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| @@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp): | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| args_dtype = {"x": x_dtype, "y": y_dtype} | |||
| valid_type = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name) | |||
| return mstype.tensor_type(mstype.bool_) | |||
| @@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer): | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| args = {'x': x_dtype, 'y': y_dtype} | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||
| return x_dtype | |||
| @@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name) | |||
| return mstype.tensor_type(mstype.bool_) | |||
| @@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||
| return mstype.bool_ | |||
| @@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer): | |||
| return [1] | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name) | |||
| return x_dtype | |||
| @@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer): | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return mstype.float32 | |||
| @@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer): | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return mstype.float32 | |||
| @@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer): | |||
| return (bboxes_shape, (num,), (num,)) | |||
| def infer_dtype(self, bboxes_dtype): | |||
| validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name) | |||
| return (bboxes_dtype, mstype.int32, mstype.bool_) | |||
| @@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||
| return x_type | |||
| def infer_value(self, x): | |||
| @@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| @@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||
| return x_type | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||
| return x_dtype | |||
| class Tan(PrimitiveWithInfer): | |||
| @@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| valid_types = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||
| valid_dtypes = [mstype.float16, mstype.float32, mstype.int32] | |||
| validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name) | |||
| return x_type | |||
| @@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||
| return x_type | |||
| @@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||
| return x_type | |||
| @@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer): | |||
| return [], [] | |||
| def infer_dtype(self, x_type, y_type): | |||
| validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name) | |||
| validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name) | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name) | |||
| validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name) | |||
| return x_type, y_type | |||
| @@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) | |||
| return x | |||
| @@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) | |||
| return x | |||
| @@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32, | |||
| validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32, | |||
| mstype.int32], self.name) | |||
| return x_dtype | |||
| @@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name) | |||
| validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name) | |||
| return x_dtype | |||
| @@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['input_x'], outputs=['y']) | |||
| def __infer__(self, input_x): | |||
| valid_types = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name) | |||
| valid_dtypes = [mstype.float16, mstype.float32] | |||
| validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name) | |||
| x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type()) | |||
| if x_nptype == np.float16: | |||
| @@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer): | |||
| return (1,), (1,) | |||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | |||
| valid_types = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name) | |||
| validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name) | |||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||
| ("input_value", "input_min", "input_max"), | |||
| (data_dtype, data_min_dtype, data_max_dtype))) | |||
| validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) | |||
| return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | |||
| @@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck): | |||
| def check_dtype(self, variable, value): | |||
| if variable != mstype.type_refkey: | |||
| validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) | |||
| validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) | |||
| validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name) | |||
| validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name) | |||
| class BoundingBoxEncode(PrimitiveWithInfer): | |||
| @@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer): | |||
| def infer_dtype(self, anchor_box, groundtruth_box): | |||
| args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return anchor_box | |||
| @@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer): | |||
| def infer_dtype(self, anchor_box, deltas): | |||
| args = {"anchor_box": anchor_box, "deltas": deltas} | |||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||
| return anchor_box | |||
| @@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer): | |||
| def infer_dtype(self, bboxes_type, metas_type): | |||
| valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] | |||
| validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name) | |||
| return mstype.bool_ | |||
| @@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer): | |||
| def infer_dtype(self, anchor_boxes, gt_boxes): | |||
| valid_type = [mstype.float32, mstype.float16] | |||
| validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name) | |||
| validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name) | |||
| validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name) | |||
| return anchor_boxes | |||
| @@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer): | |||
| if weights is not None: | |||
| validator.check_subclass('weights', weights, mstype.tensor, self.name) | |||
| args = {"labels": labels, "predictions": predictions} | |||
| validator.check_tensor_type_same(args, (mstype.number_type), self.name) | |||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name) | |||
| return labels | |||
| @@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| args = {"x": x_dtype} | |||
| validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name) | |||
| validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name) | |||
| return mstype.tensor_type(mstype.uint8) | |||
| class Push(PrimitiveWithInfer): | |||
| @@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer): | |||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| for i, shape_i in enumerate(shape_v): | |||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | |||
| Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) | |||
| Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) | |||
| Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name) | |||
| Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name) | |||
| broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) | |||
| broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) | |||
| out = { | |||
| @@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer): | |||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| for i, shape_i in enumerate(shape_v): | |||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | |||
| Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) | |||
| Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name) | |||
| broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) | |||
| out = { | |||
| 'shape': broadcast_shape, | |||
| @@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer): | |||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | |||
| for i, shape_i in enumerate(shape_v): | |||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | |||
| Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) | |||
| Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) | |||
| Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name) | |||
| Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name) | |||
| minval_shape = minval['shape'] | |||
| maxval_shape = maxval['shape'] | |||
| Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) | |||
| @@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): | |||
| return ([self.count, len(x_shape)], [self.count]) | |||
| def infer_dtype(self, x_dtype): | |||
| Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) | |||
| Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name) | |||
| return (mstype.int32, mstype.bool_) | |||
| @@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer): | |||
| def __infer__(self, logits, num_samples, seed): | |||
| logits_dtype = logits['dtype'] | |||
| valid_types = (mstype.float32, mstype.float16, mstype.float64) | |||
| Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) | |||
| valid_dtypes = (mstype.float32, mstype.float16, mstype.float64) | |||
| Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name) | |||
| num_samples_v = num_samples['value'] | |||
| seed_v = seed['value'] | |||
| Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) | |||
| @@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer): | |||
| input_shape = inputs["shape"] | |||
| if len(input_shape) != 1 and len(input_shape) != 2: | |||
| raise ValueError("input dim must be 1 or 2") | |||
| Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) | |||
| Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name) | |||
| num_samples_value = num_samples["value"] | |||
| if num_samples_value is None: | |||
| raise ValueError(f"For {self.name}, shape nust be const") | |||
| @@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs): | |||
| if quant_mode not in quant_mode_formats: | |||
| raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') | |||
| mean = Validator.check_type("mean", mean, (int, float)) | |||
| std_dev = Validator.check_type("std_dev", std_dev, (int, float)) | |||
| mean = Validator.check_value_type("mean", mean, (int, float)) | |||
| std_dev = Validator.check_value_type("std_dev", std_dev, (int, float)) | |||
| if context.get_context('device_target') not in supported_device: | |||
| raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | |||
| @@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer): | |||
| def __infer__(self, params, indices, axis): | |||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | |||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | |||
| axis_v = axis['value'] | |||
| params_shp = params['shape'] | |||