Merge pull request !8114 from zhangbuxue/rectify_and_optimize_the_type_checking_functiontags/v1.1.0
| @@ -415,37 +415,20 @@ class Validator: | |||||
| break | break | ||||
| if not hit: | if not hit: | ||||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | ||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | |||||
| f' of {",".join((str(x) for x in template_types))}, but got {type_str}.') | |||||
| raise TypeError(f'For \'{prim_name}\', the type of `{arg_name}` should be subclass' | |||||
| f' of {", ".join((str(x) for x in template_types))}, but got {type_str}.') | |||||
| @staticmethod | @staticmethod | ||||
| def check_const_input(arg_name, arg_value, prim_name): | def check_const_input(arg_name, arg_value, prim_name): | ||||
| """Checks valid value.""" | """Checks valid value.""" | ||||
| if arg_value is None: | if arg_value is None: | ||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | |||||
| raise ValueError(f'For \'{prim_name}\', the `{arg_name}` must be a const input, but got {arg_value}.') | |||||
| return arg_value | return arg_value | ||||
| @staticmethod | @staticmethod | ||||
| def check_type(arg_name, arg_value, valid_types): | |||||
| """Type checking.""" | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| raise TypeError(f'The type of `{arg_name}` should be in {valid_types}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| raise_error_msg() | |||||
| if arg_value in valid_types: | |||||
| return arg_value | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| @staticmethod | |||||
| def check_type_same(args, valid_values, prim_name): | |||||
| """Checks whether the types of inputs are the same.""" | |||||
| def _check_tensor_type(arg): | |||||
| def check_types_same_and_valid(args, valid_values, prim_name): | |||||
| """Checks whether the types of inputs are the same and valid.""" | |||||
| def _check_type_valid(arg): | |||||
| arg_key, arg_val = arg | arg_key, arg_val = arg | ||||
| elem_type = arg_val | elem_type = arg_val | ||||
| Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) | Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) | ||||
| @@ -455,21 +438,27 @@ class Validator: | |||||
| arg1_name, arg1_type = arg1 | arg1_name, arg1_type = arg1 | ||||
| arg2_name, arg2_type = arg2 | arg2_name, arg2_type = arg2 | ||||
| if arg1_type != arg2_type: | if arg1_type != arg2_type: | ||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||||
| raise TypeError(f'For \'{prim_name}\', type of `{arg2_name}` should be same as `{arg1_name}`,' | |||||
| f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') | f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') | ||||
| return arg1 | return arg1 | ||||
| elem_types = map(_check_tensor_type, args.items()) | |||||
| elem_types = map(_check_type_valid, args.items()) | |||||
| reduce(_check_types_same, elem_types) | reduce(_check_types_same, elem_types) | ||||
| @staticmethod | @staticmethod | ||||
| def check_tensor_type_same(args, valid_values, prim_name): | |||||
| """Checks whether the element types of input tensors are the same.""" | |||||
| tensor_types = [mstype.tensor_type(t) for t in valid_values] | |||||
| Validator.check_type_same(args, tensor_types, prim_name) | |||||
| def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): | |||||
| """Checks whether the element types of input tensors are the same and valid.""" | |||||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | |||||
| Validator.check_types_same_and_valid(args, tensor_types, prim_name) | |||||
| @staticmethod | |||||
| def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): | |||||
| """Checks whether the element types of input tensors are valid.""" | |||||
| tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] | |||||
| Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | |||||
| def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False): | |||||
| """ | """ | ||||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | ||||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | ||||
| @@ -480,7 +469,7 @@ class Validator: | |||||
| if isinstance(arg_val, type(mstype.tensor)): | if isinstance(arg_val, type(mstype.tensor)): | ||||
| arg_val = arg_val.element_type() | arg_val = arg_val.element_type() | ||||
| if not arg_val in valid_values: | if not arg_val in valid_values: | ||||
| raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},' | |||||
| raise TypeError(f'For \'{prim_name}\', the `{arg_key}` should be in {valid_values},' | |||||
| f' but `{arg_key}` is {arg_val}.') | f' but `{arg_key}` is {arg_val}.') | ||||
| return arg | return arg | ||||
| @@ -512,40 +501,40 @@ class Validator: | |||||
| def raise_error_msg(): | def raise_error_msg(): | ||||
| """func for raising error message when check failed""" | """func for raising error message when check failed""" | ||||
| type_names = [t.__name__ for t in valid_types] | |||||
| type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] | |||||
| num_types = len(valid_types) | num_types = len(valid_types) | ||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| msg_prefix = f"For '{prim_name}', the" if prim_name else "The" | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | ||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| f'{type_names if num_types > 1 else type_names[0]}, ' | |||||
| f'but got {arg_value} with type {type(arg_value).__name__}.') | |||||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | ||||
| # `check_value_type('x', True, [bool, int])` will check pass | # `check_value_type('x', True, [bool, int])` will check pass | ||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | if isinstance(arg_value, bool) and bool not in tuple(valid_types): | ||||
| raise_error_msg() | raise_error_msg() | ||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| if not isinstance(arg_value, tuple(valid_types)): | |||||
| raise_error_msg() | |||||
| return arg_value | |||||
| @staticmethod | @staticmethod | ||||
| def check_type_name(arg_name, arg_type, valid_types, prim_name): | def check_type_name(arg_name, arg_type, valid_types, prim_name): | ||||
| """Checks whether a type in some specified types""" | """Checks whether a type in some specified types""" | ||||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | ||||
| def get_typename(t): | |||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| msg_prefix = f"For '{prim_name}', the" if prim_name else "The" | |||||
| raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" | |||||
| f"{type_names if num_types > 1 else type_names[0]}, " | |||||
| f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.") | |||||
| if isinstance(arg_type, type(mstype.tensor)): | if isinstance(arg_type, type(mstype.tensor)): | ||||
| arg_type = arg_type.element_type() | arg_type = arg_type.element_type() | ||||
| if arg_type in valid_types: | |||||
| return arg_type | |||||
| type_names = [get_typename(t) for t in valid_types] | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| if len(valid_types) == 1: | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| if arg_type not in valid_types: | |||||
| raise_error_msg() | |||||
| return arg_type | |||||
| @staticmethod | @staticmethod | ||||
| def check_reduce_shape(ori_shape, shape, axis, prim_name): | def check_reduce_shape(ori_shape, shape, axis, prim_name): | ||||
| @@ -611,65 +600,6 @@ def check_output_data(data): | |||||
| once = _expand_tuple(1) | once = _expand_tuple(1) | ||||
| twice = _expand_tuple(2) | twice = _expand_tuple(2) | ||||
| triple = _expand_tuple(3) | triple = _expand_tuple(3) | ||||
| valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64, | |||||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | |||||
| np.float32, np.float64, bool, np.bool_) | |||||
| def check_type(arg_name, arg_value, valid_types): | |||||
| """Check value type.""" | |||||
| # if input type is Tensor ,get element type | |||||
| if isinstance(arg_value, type(mstype.tensor)): | |||||
| arg_value = arg_value.element_type() | |||||
| # First, check if arg_value has argvalid_types | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return type(arg_value).__name__ | |||||
| # Second, wrap arg_value with numpy array so that it can be checked through numpy api | |||||
| if isinstance(arg_value, (list, tuple)): | |||||
| arg_value = np.array(arg_value) | |||||
| # Thirdly, check the data type by numpy's dtype api | |||||
| valid = False | |||||
| if isinstance(arg_value, np.ndarray): | |||||
| valid = arg_value.dtype in valid_data_types | |||||
| # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and | |||||
| # `check_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| valid = False | |||||
| if not valid: | |||||
| type_names = [t.__name__ for t in valid_types] | |||||
| if len(valid_types) == 1: | |||||
| raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {type(arg_value).__name__}.') | |||||
| raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {type(arg_value).__name__}.') | |||||
| return type(arg_value).__name__ | |||||
| def check_typename(arg_name, arg_type, valid_types): | |||||
| """Check type name.""" | |||||
| def get_typename(t): | |||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||||
| if isinstance(arg_type, type(mstype.tensor)): | |||||
| arg_type = arg_type.element_type() | |||||
| if arg_type in valid_types: | |||||
| return arg_type | |||||
| if isinstance(arg_type, tuple(valid_types)): | |||||
| return arg_type | |||||
| type_names = [get_typename(t) for t in valid_types] | |||||
| if len(valid_types) == 1: | |||||
| raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise TypeError(f'The type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| def args_type_check(*type_args, **type_kwargs): | def args_type_check(*type_args, **type_kwargs): | ||||
| @@ -19,7 +19,7 @@ from mindspore import log as logger | |||||
| from mindspore.communication.management import get_rank, get_group_size | from mindspore.communication.management import get_rank, get_group_size | ||||
| from .._c_expression import Tensor as Tensor_ | from .._c_expression import Tensor as Tensor_ | ||||
| from .._c_expression import MetaTensor as MetaTensor_ | from .._c_expression import MetaTensor as MetaTensor_ | ||||
| from .._checkparam import check_type, check_typename | |||||
| from .._checkparam import Validator as validator | |||||
| from . import dtype as mstype | from . import dtype as mstype | ||||
| from ._register_for_tensor import tensor_operator_registry | from ._register_for_tensor import tensor_operator_registry | ||||
| @@ -64,9 +64,19 @@ class Tensor(Tensor_): | |||||
| input_data = np.array(input_data) | input_data = np.array(input_data) | ||||
| # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. | # If input_data is tuple/list/numpy.ndarray, it's support in check_type method. | ||||
| check_type('tensor input_data', input_data, (Tensor_, float, int)) | |||||
| validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), | |||||
| 'Tensor') | |||||
| valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, | |||||
| np.float16, np.float32, np.float64, np.bool_) | |||||
| if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes: | |||||
| raise TypeError(f"For Tensor, the input_data is a numpy array whose data type is " | |||||
| f"{input_data.dtype} that is not supported to initialize a Tensor.") | |||||
| if isinstance(input_data, (tuple, list)): | |||||
| if np.array(input_data).dtype not in valid_dtypes: | |||||
| raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.") | |||||
| if dtype is not None: | if dtype is not None: | ||||
| check_typename('dtype', dtype, mstype.number_type + (mstype.bool_,)) | |||||
| validator.check_type_name('dtype', dtype, mstype.number_type + (mstype.bool_,), "Tensor") | |||||
| if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | if isinstance(input_data, np.ndarray) and (not input_data.flags['FORC']): | ||||
| input_data = np.ascontiguousarray(input_data) | input_data = np.ascontiguousarray(input_data) | ||||
| if dtype is None: | if dtype is None: | ||||
| @@ -405,8 +415,9 @@ class MetaTensor(MetaTensor_): | |||||
| Returns: | Returns: | ||||
| Array, an array after being initialized. | Array, an array after being initialized. | ||||
| """ | """ | ||||
| def __init__(self, dtype, shape, init=None): | def __init__(self, dtype, shape, init=None): | ||||
| #check param | |||||
| # check param | |||||
| self.init = init | self.init = init | ||||
| MetaTensor_.__init__(self, dtype, shape) | MetaTensor_.__init__(self, dtype, shape) | ||||
| @@ -434,8 +445,10 @@ class MetaTensor(MetaTensor_): | |||||
| msg = "Error shape={}".format(shape) | msg = "Error shape={}".format(shape) | ||||
| logger.error(msg) | logger.error(msg) | ||||
| raise ValueError(msg) | raise ValueError(msg) | ||||
| class seed_context: | class seed_context: | ||||
| '''set and restore seed''' | '''set and restore seed''' | ||||
| def __init__(self, init): | def __init__(self, init): | ||||
| self.init = init | self.init = init | ||||
| from .seed import get_seed | from .seed import get_seed | ||||
| @@ -482,4 +495,5 @@ def _vm_compare(*args): | |||||
| y = args[0] | y = args[0] | ||||
| return Tensor(np.array(fn(y))) | return Tensor(np.array(fn(y))) | ||||
| tensor_operator_registry.register('vm_compare', _vm_compare) | tensor_operator_registry.register('vm_compare', _vm_compare) | ||||
| @@ -21,7 +21,7 @@ from ...ops import operations as P | |||||
| from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | from ...ops.primitive import PrimitiveWithInfer, prim_attr_register | ||||
| from ...ops.composite import multitype_ops as C | from ...ops.composite import multitype_ops as C | ||||
| from ...ops.operations import _grad_ops as G | from ...ops.operations import _grad_ops as G | ||||
| from ..._checkparam import Validator | |||||
| from ..._checkparam import Validator as validator | |||||
| from ..cell import Cell, GraphKernel | from ..cell import Cell, GraphKernel | ||||
| @@ -194,7 +194,7 @@ class ApplyMomentum(GraphKernel): | |||||
| use_locking=False, | use_locking=False, | ||||
| gradient_scale=1.0): | gradient_scale=1.0): | ||||
| super(ApplyMomentum, self).__init__() | super(ApplyMomentum, self).__init__() | ||||
| self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float]) | |||||
| self.gradient_scale = validator.check_value_type('gradient_scale', gradient_scale, [float], type(self).__name__) | |||||
| self.fake_output_assign_1 = InplaceAssign() | self.fake_output_assign_1 = InplaceAssign() | ||||
| self.fake_output_assign_1.add_prim_attr("fake_output", True) | self.fake_output_assign_1.add_prim_attr("fake_output", True) | ||||
| self.fake_output_assign_2 = InplaceAssign() | self.fake_output_assign_2 = InplaceAssign() | ||||
| @@ -334,7 +334,7 @@ class ReduceMean(GraphKernel): | |||||
| def __init__(self, keep_dims=True): | def __init__(self, keep_dims=True): | ||||
| super(ReduceMean, self).__init__() | super(ReduceMean, self).__init__() | ||||
| self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], type(self).__name__) | |||||
| self.sum = P.ReduceSum(self.keep_dims) | self.sum = P.ReduceSum(self.keep_dims) | ||||
| def construct(self, x, axis): | def construct(self, x, axis): | ||||
| @@ -431,8 +431,10 @@ class LayerNormForward(GraphKernel): | |||||
| """ Forward function of the LayerNorm operator. """ | """ Forward function of the LayerNorm operator. """ | ||||
| def __init__(self, begin_norm_axis=1, begin_params_axis=1): | def __init__(self, begin_norm_axis=1, begin_params_axis=1): | ||||
| super(LayerNormForward, self).__init__() | super(LayerNormForward, self).__init__() | ||||
| self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int]) | |||||
| self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int]) | |||||
| self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], | |||||
| type(self).__name__) | |||||
| self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], | |||||
| type(self).__name__) | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.sum_keep_dims = P.ReduceSum(keep_dims=True) | self.sum_keep_dims = P.ReduceSum(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| @@ -686,7 +688,7 @@ class LogSoftmax(GraphKernel): | |||||
| def __init__(self, axis=-1): | def __init__(self, axis=-1): | ||||
| super(LogSoftmax, self).__init__() | super(LogSoftmax, self).__init__() | ||||
| self.axis = Validator.check_type('axis', axis, [int]) | |||||
| self.axis = validator.check_value_type('axis', axis, [int], type(self).__name__) | |||||
| self.max_keep_dims = P.ReduceMax(keep_dims=True) | self.max_keep_dims = P.ReduceMax(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| @@ -952,13 +954,13 @@ class Softmax(GraphKernel): | |||||
| def __init__(self, axis): | def __init__(self, axis): | ||||
| super(Softmax, self).__init__() | super(Softmax, self).__init__() | ||||
| Validator.check_type("axis", axis, [int, tuple]) | |||||
| validator.check_value_type("axis", axis, [int, tuple], type(self).__name__) | |||||
| if isinstance(axis, int): | if isinstance(axis, int): | ||||
| self.axis = (axis,) | self.axis = (axis,) | ||||
| else: | else: | ||||
| self.axis = axis | self.axis = axis | ||||
| for item in self.axis: | for item in self.axis: | ||||
| Validator.check_type("item of axis", item, [int]) | |||||
| validator.check_value_type("item of axis", item, [int], type(self).__name__) | |||||
| self.max = P.ReduceMax(keep_dims=True) | self.max = P.ReduceMax(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.exp = P.Exp() | self.exp = P.Exp() | ||||
| @@ -19,7 +19,7 @@ from mindspore.common.parameter import Parameter | |||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore._checkparam import Validator, check_typename | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._extends import cell_attr_register | from mindspore._extends import cell_attr_register | ||||
| from mindspore.communication.management import get_group_size, get_rank | from mindspore.communication.management import get_group_size, get_rank | ||||
| from mindspore.communication import management | from mindspore.communication import management | ||||
| @@ -52,7 +52,7 @@ class _BatchNorm(Cell): | |||||
| if momentum < 0 or momentum > 1: | if momentum < 0 or momentum > 1: | ||||
| raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) | raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) | ||||
| self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) | |||||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | if context.get_context("device_target") != "GPU" and self.format == "NHWC": | ||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||
| self.use_batch_statistics = use_batch_statistics | self.use_batch_statistics = use_batch_statistics | ||||
| @@ -67,7 +67,7 @@ class _BatchNorm(Cell): | |||||
| gamma_init, num_features), name="gamma", requires_grad=affine) | gamma_init, num_features), name="gamma", requires_grad=affine) | ||||
| self.beta = Parameter(initializer( | self.beta = Parameter(initializer( | ||||
| beta_init, num_features), name="beta", requires_grad=affine) | beta_init, num_features), name="beta", requires_grad=affine) | ||||
| self.group = Validator.check_positive_int(device_num_each_group) | |||||
| self.group = validator.check_positive_int(device_num_each_group) | |||||
| self.is_global = False | self.is_global = False | ||||
| if self.group != 1: | if self.group != 1: | ||||
| self.rank_id = get_rank() | self.rank_id = get_rank() | ||||
| @@ -472,7 +472,7 @@ class GlobalBatchNorm(_BatchNorm): | |||||
| use_batch_statistics, | use_batch_statistics, | ||||
| device_num_each_group, | device_num_each_group, | ||||
| input_dims='both') | input_dims='both') | ||||
| self.group = Validator.check_positive_int(device_num_each_group) | |||||
| self.group = validator.check_positive_int(device_num_each_group) | |||||
| if self.group <= 1: | if self.group <= 1: | ||||
| raise ValueError("the number of group must be greater than 1.") | raise ValueError("the number of group must be greater than 1.") | ||||
| @@ -607,12 +607,12 @@ class GroupNorm(Cell): | |||||
| def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): | def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): | ||||
| super(GroupNorm, self).__init__() | super(GroupNorm, self).__init__() | ||||
| self.num_groups = Validator.check_positive_int(num_groups) | |||||
| self.num_channels = Validator.check_positive_int(num_channels) | |||||
| self.num_groups = validator.check_positive_int(num_groups) | |||||
| self.num_channels = validator.check_positive_int(num_channels) | |||||
| if num_channels % num_groups != 0: | if num_channels % num_groups != 0: | ||||
| raise ValueError("num_channels should be divided by num_groups") | raise ValueError("num_channels should be divided by num_groups") | ||||
| self.eps = check_typename('eps', eps, (float,)) | |||||
| self.affine = Validator.check_bool(affine) | |||||
| self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__) | |||||
| self.affine = validator.check_bool(affine) | |||||
| gamma = initializer(gamma_init, num_channels) | gamma = initializer(gamma_init, num_channels) | ||||
| beta = initializer(beta_init, num_channels) | beta = initializer(beta_init, num_channels) | ||||
| @@ -442,8 +442,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): | |||||
| super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, | super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, | ||||
| symmetric=symmetric, narrow_range=narrow_range, | symmetric=symmetric, narrow_range=narrow_range, | ||||
| num_channels=num_channels) | num_channels=num_channels) | ||||
| Validator.check_type("min_init", min_init, [int, float]) | |||||
| Validator.check_type("max_init", max_init, [int, float]) | |||||
| Validator.check_value_type("min_init", min_init, [int, float], type(self).__name__) | |||||
| Validator.check_value_type("max_init", max_init, [int, float], type(self).__name__) | |||||
| Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) | ||||
| Validator.check_non_negative_int(quant_delay, 'quant_delay') | Validator.check_non_negative_int(quant_delay, 'quant_delay') | ||||
| self.min_init = min_init | self.min_init = min_init | ||||
| @@ -68,7 +68,7 @@ class GumbelCDF(Bijector): | |||||
| """ | """ | ||||
| param = dict(locals()) | param = dict(locals()) | ||||
| valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | valid_dtype = mstype.float_type + mstype.int_type + mstype.uint_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | parameter_type = set_param_type({'loc': loc, "scale": scale}, dtype) | ||||
| super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | super(GumbelCDF, self).__init__(name=name, dtype=dtype, param=param) | ||||
| @@ -119,7 +119,7 @@ class Bernoulli(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Bernoulli, self).__init__(seed, dtype, name, param) | super(Bernoulli, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -109,7 +109,7 @@ class Categorical(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type | valid_dtype = mstype.int_type | ||||
| Validator.check_type("Categorical", dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Categorical, self).__init__(seed, dtype, name, param) | super(Categorical, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -121,7 +121,7 @@ class Exponential(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'rate': rate} | param['param_dict'] = {'rate': rate} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Exponential, self).__init__(seed, dtype, name, param) | super(Exponential, self).__init__(seed, dtype, name, param) | ||||
| self._rate = self._add_parameter(rate, 'rate') | self._rate = self._add_parameter(rate, 'rate') | ||||
| @@ -122,7 +122,7 @@ class Geometric(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'probs': probs} | param['param_dict'] = {'probs': probs} | ||||
| valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | valid_dtype = mstype.int_type + mstype.uint_type + mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Geometric, self).__init__(seed, dtype, name, param) | super(Geometric, self).__init__(seed, dtype, name, param) | ||||
| self._probs = self._add_parameter(probs, 'probs') | self._probs = self._add_parameter(probs, 'probs') | ||||
| @@ -102,7 +102,7 @@ class Gumbel(TransformedDistribution): | |||||
| Constructor of Gumbel distribution. | Constructor of Gumbel distribution. | ||||
| """ | """ | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | gumbel_cdf = msb.GumbelCDF(loc, scale, dtype) | ||||
| super(Gumbel, self).__init__( | super(Gumbel, self).__init__( | ||||
| distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | distribution=msd.Uniform(0.0, 1.0, dtype=dtype), | ||||
| @@ -111,7 +111,7 @@ class Logistic(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'loc': loc, 'scale': scale} | param['param_dict'] = {'loc': loc, 'scale': scale} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Logistic, self).__init__(seed, dtype, name, param) | super(Logistic, self).__init__(seed, dtype, name, param) | ||||
| self._loc = self._add_parameter(loc, 'loc') | self._loc = self._add_parameter(loc, 'loc') | ||||
| @@ -127,7 +127,7 @@ class Normal(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'mean': mean, 'sd': sd} | param['param_dict'] = {'mean': mean, 'sd': sd} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Normal, self).__init__(seed, dtype, name, param) | super(Normal, self).__init__(seed, dtype, name, param) | ||||
| self._mean_value = self._add_parameter(mean, 'mean') | self._mean_value = self._add_parameter(mean, 'mean') | ||||
| @@ -126,7 +126,7 @@ class Uniform(Distribution): | |||||
| param = dict(locals()) | param = dict(locals()) | ||||
| param['param_dict'] = {'low': low, 'high': high} | param['param_dict'] = {'low': low, 'high': high} | ||||
| valid_dtype = mstype.float_type | valid_dtype = mstype.float_type | ||||
| Validator.check_type(type(self).__name__, dtype, valid_dtype) | |||||
| Validator.check_type_name("dtype", dtype, valid_dtype, type(self).__name__) | |||||
| super(Uniform, self).__init__(seed, dtype, name, param) | super(Uniform, self).__init__(seed, dtype, name, param) | ||||
| self._low = self._add_parameter(low, 'low') | self._low = self._add_parameter(low, 'low') | ||||
| @@ -55,8 +55,7 @@ class UpdateCache(PrimitiveWithInfer): | |||||
| return [1] | return [1] | ||||
| def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | ||||
| args = {"indices": indices_dtype} | |||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||||
| return input_x_dtype | return input_x_dtype | ||||
| @@ -140,7 +139,7 @@ class SearchCacheIdx(PrimitiveWithInfer): | |||||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | ||||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||||
| out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | ||||
| return out_dtype | return out_dtype | ||||
| @@ -182,8 +181,7 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | ||||
| args = {"miss_emb_idx": miss_emb_idx_dtype} | |||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||||
| out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | ||||
| return out_dtype | return out_dtype | ||||
| @@ -224,8 +222,7 @@ class CacheSwapTable(PrimitiveWithInfer): | |||||
| return miss_value_shape | return miss_value_shape | ||||
| def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | ||||
| args = {"swap_cache_idx": swap_cache_idx_dtype} | |||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||||
| return miss_value_dtype | return miss_value_dtype | ||||
| @@ -261,7 +258,7 @@ class MapCacheIdx(PrimitiveWithInfer): | |||||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | ||||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||||
| out_dtype = (hashmap_dtype, hashmap_dtype, | out_dtype = (hashmap_dtype, hashmap_dtype, | ||||
| hashmap_dtype, hashmap_dtype) | hashmap_dtype, hashmap_dtype) | ||||
| return out_dtype | return out_dtype | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Operators for gradients.""" | """Operators for gradients.""" | ||||
| from functools import partial | |||||
| from .. import signature as sig | from .. import signature as sig | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| @@ -23,6 +24,7 @@ from ...common import dtype as mstype | |||||
| from .. import functional as F | from .. import functional as F | ||||
| from ... import context | from ... import context | ||||
| class AbsGrad(PrimitiveWithInfer): | class AbsGrad(PrimitiveWithInfer): | ||||
| """Computes gradients for abs operation.""" | """Computes gradients for abs operation.""" | ||||
| @@ -55,7 +57,7 @@ class ACosGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x, dout): | def infer_dtype(self, x, dout): | ||||
| args = {"x": x, "dout": dout} | args = {"x": x, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -72,7 +74,7 @@ class AcoshGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x, dout): | def infer_dtype(self, x, dout): | ||||
| args = {"x": x, "dout": dout} | args = {"x": x, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -94,7 +96,7 @@ class AsinGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x, dout): | def infer_dtype(self, x, dout): | ||||
| args = {"x": x, "dout": dout} | args = {"x": x, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -111,7 +113,7 @@ class AsinhGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x, dout): | def infer_dtype(self, x, dout): | ||||
| args = {"x": x, "dout": dout} | args = {"x": x, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -128,7 +130,7 @@ class ReciprocalGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | def infer_dtype(self, x_dtype, dout_dtype): | ||||
| args = {"x": x_dtype, "dout": dout_dtype} | args = {"x": x_dtype, "dout": dout_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -145,7 +147,8 @@ class RsqrtGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | def infer_dtype(self, x_dtype, dout_dtype): | ||||
| args = {"x": x_dtype, "dout": dout_dtype} | args = {"x": x_dtype, "dout": dout_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], | |||||
| self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -162,7 +165,7 @@ class SoftmaxGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | def infer_dtype(self, x_dtype, dout_dtype): | ||||
| args = {"x": x_dtype, "dout": dout_dtype} | args = {"x": x_dtype, "dout": dout_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -179,7 +182,7 @@ class SqrtGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, dout_dtype): | def infer_dtype(self, x_dtype, dout_dtype): | ||||
| args = {"x": x_dtype, "dout": dout_dtype} | args = {"x": x_dtype, "dout": dout_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -232,7 +235,7 @@ class KLDivLossGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, y_type, doutput_type): | def infer_dtype(self, x_type, y_type, doutput_type): | ||||
| args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| return x_type, y_type | return x_type, y_type | ||||
| @@ -251,7 +254,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, y_type, doutput_type, weight_type): | def infer_dtype(self, x_type, y_type, doutput_type, weight_type): | ||||
| args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| if weight_type: | if weight_type: | ||||
| validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) | validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) | ||||
| return x_type | return x_type | ||||
| @@ -343,7 +346,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): | |||||
| for i, dim_len in enumerate(w_size_v): | for i, dim_len in enumerate(w_size_v): | ||||
| validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) | validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) | ||||
| args = {"x": x['dtype'], "doutput": doutput['dtype']} | args = {"x": x['dtype'], "doutput": doutput['dtype']} | ||||
| validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], | |||||
| self.name) | |||||
| out = { | out = { | ||||
| 'value': None, | 'value': None, | ||||
| 'shape': w_size_v, | 'shape': w_size_v, | ||||
| @@ -406,7 +410,7 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): | |||||
| def __infer__(self, x, w_size, dout): | def __infer__(self, x, w_size, dout): | ||||
| w_size_v = w_size['value'] | w_size_v = w_size['value'] | ||||
| args = {'x': x['dtype'], 'dout': dout['dtype']} | args = {'x': x['dtype'], 'dout': dout['dtype']} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| out = { | out = { | ||||
| 'value': None, | 'value': None, | ||||
| 'shape': w_size_v, | 'shape': w_size_v, | ||||
| @@ -466,7 +470,7 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): | |||||
| def __infer__(self, x_size, w, dout): | def __infer__(self, x_size, w, dout): | ||||
| args = {'w': w['dtype'], 'dout': dout['dtype']} | args = {'w': w['dtype'], 'dout': dout['dtype']} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| x_size_v = x_size['value'] | x_size_v = x_size['value'] | ||||
| out = { | out = { | ||||
| 'value': None, | 'value': None, | ||||
| @@ -505,10 +509,9 @@ class DropoutGrad(PrimitiveWithInfer): | |||||
| return dy_shape | return dy_shape | ||||
| def infer_dtype(self, dy_dtype, mask_dtype): | def infer_dtype(self, dy_dtype, mask_dtype): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_subclass("dy", dy_dtype, mstype.tensor, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name) | validator.check_subclass("mask", mask_dtype, mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"dy_dtype": dy_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid("dy", dy_dtype, valid_dtypes, self.name) | |||||
| return dy_dtype | return dy_dtype | ||||
| @@ -627,9 +630,10 @@ class GeluGrad(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): | def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): | ||||
| validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("y_backprop", "x", "y"), | |||||
| (y_backprop_dtype, x_dtype, y_dtype))) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -782,7 +786,7 @@ class MaxPoolGradGrad(_PoolGrad): | |||||
| def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): | def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype): | ||||
| args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} | args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) | |||||
| return x1_dtype | return x1_dtype | ||||
| @@ -858,7 +862,7 @@ class MaxPoolGradGradWithArgmax(_PoolGrad): | |||||
| def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): | def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype): | ||||
| args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype} | args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name) | |||||
| return grad_dtype | return grad_dtype | ||||
| @@ -902,7 +906,7 @@ class L2NormalizeGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, input_x, out, dout): | def infer_dtype(self, input_x, out, dout): | ||||
| args = {'input_x': input_x, 'out': out, 'dout': dout} | args = {'input_x': input_x, 'out': out, 'dout': dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return input_x | return input_x | ||||
| @@ -993,7 +997,7 @@ class LSTMGradData(PrimitiveWithInfer): | |||||
| def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, | def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, | ||||
| hx_dtype, cx_dtype, reserve_dtype, state_dtype): | hx_dtype, cx_dtype, reserve_dtype, state_dtype): | ||||
| args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} | args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) | |||||
| return (dy_dtype, dy_dtype, dy_dtype) | return (dy_dtype, dy_dtype, dy_dtype) | ||||
| @@ -1265,14 +1269,14 @@ class DynamicGRUV2Grad(PrimitiveWithInfer): | |||||
| args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, | args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype, | ||||
| "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, | "dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype, | ||||
| "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} | "reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype} | ||||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name) | |||||
| if seq_dtype is not None: | if seq_dtype is not None: | ||||
| validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name) | |||||
| validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name) | |||||
| if mask_dtype is not None: | if mask_dtype is not None: | ||||
| validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name) | |||||
| validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name) | |||||
| return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype | return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype | ||||
| @@ -1302,10 +1306,10 @@ class PReLUGrad(PrimitiveWithInfer): | |||||
| return y_backprop_shape, w_shape | return y_backprop_shape, w_shape | ||||
| def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): | def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ('y_backprop', "input_x", "weight"), | |||||
| (y_backprop_dtype, A_dtype, w_dtype))) | |||||
| return y_backprop_dtype, w_dtype | return y_backprop_dtype, w_dtype | ||||
| @@ -1335,8 +1339,9 @@ class ReLU6Grad(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, y_grad_dtype, x_dtype): | def infer_dtype(self, y_grad_dtype, x_dtype): | ||||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1354,8 +1359,8 @@ class ReluGradV2(PrimitiveWithInfer): | |||||
| return gradients_shape | return gradients_shape | ||||
| def infer_dtype(self, gradients_dtype, mask_dtype): | def infer_dtype(self, gradients_dtype, mask_dtype): | ||||
| validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) | |||||
| validator.check_tensor_dtype_valid('gradients', gradients_dtype, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('mask', mask_dtype, (mstype.uint8,), self.name) | |||||
| return gradients_dtype | return gradients_dtype | ||||
| @@ -1371,7 +1376,7 @@ class EluGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, y_grad_dtype, x_dtype): | def infer_dtype(self, y_grad_dtype, x_dtype): | ||||
| args = {'y_grad': y_grad_dtype, 'x': x_dtype} | args = {'y_grad': y_grad_dtype, 'x': x_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.float_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1474,7 +1479,7 @@ class SigmoidGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, out, dout): | def infer_dtype(self, out, dout): | ||||
| args = {'out': out, 'dout': dout} | args = {'out': out, 'dout': dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return out | return out | ||||
| @@ -1489,8 +1494,9 @@ class HSigmoidGrad(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, y_grad_dtype, x_dtype): | def infer_dtype(self, y_grad_dtype, x_dtype): | ||||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1505,8 +1511,9 @@ class HSwishGrad(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, y_grad_dtype, x_dtype): | def infer_dtype(self, y_grad_dtype, x_dtype): | ||||
| validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1525,7 +1532,7 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, y_dtype, dout_dtype): | def infer_dtype(self, x_dtype, y_dtype, dout_dtype): | ||||
| args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} | args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return dout_dtype | return dout_dtype | ||||
| @@ -1562,7 +1569,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, prediction, target, dloss): | def infer_dtype(self, prediction, target, dloss): | ||||
| args = {"prediction": prediction, "target": target, 'dloss': dloss} | args = {"prediction": prediction, "target": target, 'dloss': dloss} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return dloss | return dloss | ||||
| @@ -1597,8 +1604,7 @@ class StridedSliceGrad(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) | self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) | ||||
| def __infer__(self, dy, shapex, begin, end, strides): | def __infer__(self, dy, shapex, begin, end, strides): | ||||
| args = {"dy": dy['dtype']} | |||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name) | |||||
| for idx, item in enumerate(shapex['value']): | for idx, item in enumerate(shapex['value']): | ||||
| validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) | validator.check_value_type("shapex[%d]" % idx, item, [int], self.name) | ||||
| @@ -1627,7 +1633,7 @@ class SoftplusGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, dout_dtype, x_dtype): | def infer_dtype(self, dout_dtype, x_dtype): | ||||
| args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype} | args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.float_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1643,7 +1649,7 @@ class TanhGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, out, dout): | def infer_dtype(self, out, dout): | ||||
| args = {"out": out, "dout": dout} | args = {"out": out, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return out | return out | ||||
| @@ -1756,7 +1762,7 @@ class AtanGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x, dout): | def infer_dtype(self, x, dout): | ||||
| args = {"x": x, "dout": dout} | args = {"x": x, "dout": dout} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -1900,7 +1906,7 @@ class LRNGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, grads, x, y): | def infer_dtype(self, grads, x, y): | ||||
| args = {"grads": grads, "x": x, "y": y} | args = {"grads": grads, "x": x, "y": y} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name) | |||||
| return x | return x | ||||
| def infer_shape(self, grads, x, y): | def infer_shape(self, grads, x, y): | ||||
| @@ -54,6 +54,7 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, ksizes, strides, rates, padding="valid"): | def __init__(self, ksizes, strides, rates, padding="valid"): | ||||
| """init""" | """init""" | ||||
| def _check_tuple_or_list(arg_name, arg_val, prim_name): | def _check_tuple_or_list(arg_name, arg_val, prim_name): | ||||
| validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) | validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) | ||||
| if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: | if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: | ||||
| @@ -103,7 +104,7 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||||
| def infer_dtype(self, input_x): | def infer_dtype(self, input_x): | ||||
| """infer dtype""" | """infer dtype""" | ||||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name) | |||||
| return input_x | return input_x | ||||
| @@ -161,7 +162,7 @@ class Range(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -254,6 +255,7 @@ class Dequant(PrimitiveWithInfer): | |||||
| >>> dequant = P.Dequant(False, False) | >>> dequant = P.Dequant(False, False) | ||||
| >>> y = dequant(input_x) | >>> y = dequant(input_x) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, sqrt_mode=False, relu_flag=False): | def __init__(self, sqrt_mode=False, relu_flag=False): | ||||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | ||||
| @@ -303,10 +305,9 @@ class LinSpace(PrimitiveWithInfer): | |||||
| return assist | return assist | ||||
| def infer_dtype(self, assist, start, stop, num): | def infer_dtype(self, assist, start, stop, num): | ||||
| args = {"num": num} | |||||
| validator.check_tensor_type_same(args, (mstype.int32,), self.name) | |||||
| validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name) | |||||
| args = {"assist": assist, "start": start, "stop": stop} | args = {"assist": assist, "start": start, "stop": stop} | ||||
| validator.check_tensor_type_same(args, (mstype.float32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name) | |||||
| return assist | return assist | ||||
| @@ -343,12 +344,12 @@ class MatrixDiag(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, assist_dtype): | def infer_dtype(self, x_dtype, assist_dtype): | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | ||||
| args = {"x": x_dtype, "assist": assist_dtype} | args = {"x": x_dtype, "assist": assist_dtype} | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, assist_shape): | def infer_shape(self, x_shape, assist_shape): | ||||
| validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) | validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name) | ||||
| validator.check('rank of x', len(x_shape)+1, | |||||
| validator.check('rank of x', len(x_shape) + 1, | |||||
| 'rank of assist', len(assist_shape), Rel.LE, self.name) | 'rank of assist', len(assist_shape), Rel.LE, self.name) | ||||
| validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', | validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension', | ||||
| assist_shape[-1], Rel.EQ, self.name) | assist_shape[-1], Rel.EQ, self.name) | ||||
| @@ -358,7 +359,7 @@ class MatrixDiag(PrimitiveWithInfer): | |||||
| while r_idx >= r_end_dim: | while r_idx >= r_end_dim: | ||||
| if x_shape[r_idx] != 1: | if x_shape[r_idx] != 1: | ||||
| validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % | validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" % | ||||
| assist_shape[r_idx-1], assist_shape[r_idx-1], Rel.EQ, self.name) | |||||
| assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name) | |||||
| r_idx = r_idx - 1 | r_idx = r_idx - 1 | ||||
| return assist_shape | return assist_shape | ||||
| @@ -391,7 +392,7 @@ class MatrixDiagPart(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, assist_dtype): | def infer_dtype(self, x_dtype, assist_dtype): | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | ||||
| args = {"x": x_dtype, "assist": assist_dtype} | args = {"x": x_dtype, "assist": assist_dtype} | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, assist_shape): | def infer_shape(self, x_shape, assist_shape): | ||||
| @@ -434,7 +435,7 @@ class MatrixSetDiag(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): | def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype): | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8] | ||||
| args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} | args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype} | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, diagonal_shape, assist_shape): | def infer_shape(self, x_shape, diagonal_shape, assist_shape): | ||||
| @@ -583,21 +584,21 @@ class DynamicGRUV2(PrimitiveWithInfer): | |||||
| return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape | return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape | ||||
| def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype): | ||||
| validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name) | |||||
| validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name) | |||||
| validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name) | |||||
| validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) | |||||
| validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) | |||||
| validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) | |||||
| b_dtype = mstype.float32 | b_dtype = mstype.float32 | ||||
| if binput_dtype is not None: | if binput_dtype is not None: | ||||
| validator.check_tensor_type_same({"bias input dtype": binput_dtype}, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| b_dtype = binput_dtype | b_dtype = binput_dtype | ||||
| elif bhidden_dtype is not None: | elif bhidden_dtype is not None: | ||||
| validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype}, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| b_dtype = bhidden_dtype | b_dtype = bhidden_dtype | ||||
| elif h_dtype is not None: | elif h_dtype is not None: | ||||
| validator.check_tensor_type_same({"init_h dtype": h_dtype}, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("init_h dtype", h_dtype, | |||||
| (mstype.float16, mstype.float32), self.name) | |||||
| b_dtype = h_dtype | b_dtype = h_dtype | ||||
| return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Operators for quantization.""" | """Operators for quantization.""" | ||||
| from functools import partial | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| @@ -92,12 +93,10 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||||
| return min_shape, max_shape | return min_shape, max_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return min_type, max_type | return min_type, max_type | ||||
| @@ -157,13 +156,10 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||||
| return min_shape, max_shape | return min_shape, max_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same( | |||||
| {"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return min_type, max_type | return min_type, max_type | ||||
| @@ -193,6 +189,7 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): | |||||
| >>> input_tensor, min_tensor, max_tensor) | >>> input_tensor, min_tensor, max_tensor) | ||||
| >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 | >>> output_tensor shape: (3, 16, 5, 5) data type: mstype.float32 | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, | def __init__(self, | ||||
| num_bits=8, | num_bits=8, | ||||
| @@ -217,10 +214,10 @@ class FakeQuantWithMinMaxVars(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return x_type | return x_type | ||||
| @@ -256,6 +253,7 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): | |||||
| >>> min_gradient shape: (1,) data type: mstype.float32 | >>> min_gradient shape: (1,) data type: mstype.float32 | ||||
| >>> max_gradient shape: (1,) data type: mstype.float32 | >>> max_gradient shape: (1,) data type: mstype.float32 | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, | def __init__(self, | ||||
| num_bits=8, | num_bits=8, | ||||
| @@ -281,11 +279,10 @@ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer): | |||||
| return x_shape, min_shape, max_shape | return x_shape, min_shape, max_shape | ||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ('dout', "x", "min", "max"), | |||||
| (dout_type, x_type, min_type, max_type))) | |||||
| return x_type, min_type, max_type | return x_type, min_type, max_type | ||||
| @@ -315,6 +312,7 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): | |||||
| >>> input_tensor, min_tensor, max_tensor) | >>> input_tensor, min_tensor, max_tensor) | ||||
| >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 | >>> output_tensor shape: (3, 16, 3, 4) data type: mstype.float32 | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, | def __init__(self, | ||||
| num_bits=8, | num_bits=8, | ||||
| @@ -332,10 +330,10 @@ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return x_type | return x_type | ||||
| @@ -372,6 +370,7 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): | |||||
| >>> min_gradient shape: (4,) data type: mstype.float32 | >>> min_gradient shape: (4,) data type: mstype.float32 | ||||
| >>> max_gradient shape: (4,) data type: mstype.float32 | >>> max_gradient shape: (4,) data type: mstype.float32 | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, | def __init__(self, | ||||
| num_bits=8, | num_bits=8, | ||||
| @@ -390,11 +389,10 @@ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer): | |||||
| return x_shape, min_shape, max_shape | return x_shape, min_shape, max_shape | ||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({'dout': dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'min': min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({'max': max_type}, valid_types, self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("dout", "x", "min", "max"), | |||||
| (dout_type, x_type, min_type, max_type))) | |||||
| return x_type, min_type, max_type | return x_type, min_type, max_type | ||||
| @@ -468,14 +466,12 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| if context.get_context('device_target') == "GPU": | if context.get_context('device_target') == "GPU": | ||||
| valid_types = (mstype.float32,) | |||||
| valid_dtypes = (mstype.float32,) | |||||
| else: | else: | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return x_type | return x_type | ||||
| @@ -525,16 +521,12 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| if context.get_context('device_target') == "GPU": | if context.get_context('device_target') == "GPU": | ||||
| valid_types = (mstype.float32,) | |||||
| valid_dtypes = (mstype.float32,) | |||||
| else: | else: | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same( | |||||
| {"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||||
| ("dout", "x", "min", "max"), | |||||
| (dout_type, x_type, min_type, max_type))) | |||||
| return dout_type | return dout_type | ||||
| @@ -623,14 +615,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| if context.get_context('device_target') == "GPU": | if context.get_context('device_target') == "GPU": | ||||
| valid_types = (mstype.float32,) | |||||
| valid_dtypes = (mstype.float32,) | |||||
| else: | else: | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||||
| ("x", "min", "max"), | |||||
| (x_type, min_type, max_type))) | |||||
| return x_type | return x_type | ||||
| @@ -680,16 +670,12 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| if context.get_context('device_target') == "GPU": | if context.get_context('device_target') == "GPU": | ||||
| valid_types = (mstype.float32,) | |||||
| valid_dtypes = (mstype.float32,) | |||||
| else: | else: | ||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same( | |||||
| {"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name), | |||||
| ("dout", "x", "min", "max"), | |||||
| (dout_type, x_type, min_type, max_type))) | |||||
| return dout_type | return dout_type | ||||
| @@ -750,8 +736,8 @@ class BatchNormFold(PrimitiveWithInfer): | |||||
| validator.check("input type", x_type, "mean type", mean_type) | validator.check("input type", x_type, "mean type", mean_type) | ||||
| validator.check("input type", x_type, "variance type", variance_type) | validator.check("input type", x_type, "variance type", variance_type) | ||||
| args = {"x": x_type, "mean": mean_type, "variance": variance_type} | args = {"x": x_type, "mean": mean_type, "variance": variance_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||||
| return x_type, x_type, x_type, x_type | return x_type, x_type, x_type, x_type | ||||
| @@ -797,8 +783,8 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||||
| global_step_type): | global_step_type): | ||||
| args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, | args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, | ||||
| "batch_mean": batch_mean_type, "batch_std": batch_std_type} | "batch_mean": batch_mean_type, "batch_std": batch_std_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||||
| return x_type | return x_type | ||||
| @@ -841,7 +827,7 @@ class CorrectionMul(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, batch_std_type, running_std_type): | def infer_dtype(self, x_type, batch_std_type, running_std_type): | ||||
| args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} | args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| return x_type | return x_type | ||||
| @@ -879,7 +865,7 @@ class CorrectionMulGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): | def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): | ||||
| args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} | args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| if context.get_context('device_target') == "Ascend": | if context.get_context('device_target') == "Ascend": | ||||
| return x_type, x_type | return x_type, x_type | ||||
| return x_type, gamma_type | return x_type, gamma_type | ||||
| @@ -972,8 +958,8 @@ class BatchNormFold2(PrimitiveWithInfer): | |||||
| running_mean_type, global_step_type): | running_mean_type, global_step_type): | ||||
| args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | ||||
| "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} | "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||||
| return x_type | return x_type | ||||
| @@ -1031,8 +1017,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer): | |||||
| "dout type", dout_type) | "dout type", dout_type) | ||||
| args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | ||||
| "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} | "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name) | |||||
| return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type | return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type | ||||
| @@ -1061,7 +1047,7 @@ class BatchNormFoldD(PrimitiveWithInfer): | |||||
| validator.check("input type", x_type, "mean type", mean_type) | validator.check("input type", x_type, "mean type", mean_type) | ||||
| validator.check("input type", x_type, "variance type", variance_type) | validator.check("input type", x_type, "variance type", variance_type) | ||||
| args = {"x": x_type, "mean": mean_type, "variance": variance_type} | args = {"x": x_type, "mean": mean_type, "variance": variance_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| return x_type, x_type, x_type, x_type, x_type, x_type, x_type | return x_type, x_type, x_type, x_type, x_type, x_type, x_type | ||||
| @@ -1090,8 +1076,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer): | |||||
| validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) | validator.check("input type", x_type, "d_batch_std type", d_batch_std_type) | ||||
| validator.check("input type", x_type, "batch_mean type", batch_mean_type) | validator.check("input type", x_type, "batch_mean type", batch_mean_type) | ||||
| validator.check("input type", x_type, "batch_std type", batch_std_type) | validator.check("input type", x_type, "batch_std type", batch_std_type) | ||||
| args = {"input type": x_type} | |||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name) | |||||
| return x_type | return x_type | ||||
| @@ -1136,7 +1121,7 @@ class BatchNormFold2_D(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): | def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type): | ||||
| args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, | ||||
| "beta": beta_type, "gamma": gamma_type, "x": x_type} | "beta": beta_type, "gamma": gamma_type, "x": x_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| return x_type | return x_type | ||||
| @@ -1174,7 +1159,7 @@ class BatchNormFold2GradD(PrimitiveWithInfer): | |||||
| "dout type", dout_type) | "dout type", dout_type) | ||||
| args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, | ||||
| "running_std": running_std_type, "dout": dout_type} | "running_std": running_std_type, "dout": dout_type} | ||||
| validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) | |||||
| return gamma_type, gamma_type, gamma_type, gamma_type | return gamma_type, gamma_type, gamma_type, gamma_type | ||||
| @@ -165,7 +165,7 @@ class CusFusedAbsMax1(PrimitiveWithInfer): | |||||
| def infer_shape(self, data1_shape): | def infer_shape(self, data1_shape): | ||||
| ll = [] | ll = [] | ||||
| if len(data1_shape) == 2: | if len(data1_shape) == 2: | ||||
| ll = [1,] | |||||
| ll = [1] | |||||
| else: | else: | ||||
| ll = [32, 64] | ll = [32, 64] | ||||
| return ll | return ll | ||||
| @@ -497,6 +497,7 @@ class Im2Col(PrimitiveWithInfer): | |||||
| >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) | >>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2) | ||||
| >>> output = img2col(input_x) | >>> output = img2col(input_x) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, | def __init__(self, | ||||
| kernel_size, | kernel_size, | ||||
| @@ -556,9 +557,8 @@ class Im2Col(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| args = {'x': x_dtype} | |||||
| valid_types = [mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||||
| valid_dtypes = [mstype.float16, mstype.float32] | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -602,14 +602,17 @@ class UpdateThorGradient(PrimitiveWithInfer): | |||||
| return x2_shape | return x2_shape | ||||
| def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): | def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype): | ||||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, | |||||
| [mstype.float32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid( | |||||
| {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype}, | |||||
| [mstype.float32], self.name) | |||||
| return x2_dtype | return x2_dtype | ||||
| class Cholesky(PrimitiveWithInfer): | class Cholesky(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Inner API for resnet50 THOR GPU backend | Inner API for resnet50 THOR GPU backend | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, split_dim=0): | def __init__(self, split_dim=0): | ||||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | self.init_prim_io_names(inputs=['x1'], outputs=['y']) | ||||
| @@ -634,13 +637,15 @@ class Cholesky(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, x1_dtype): | def infer_dtype(self, x1_dtype): | ||||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) | |||||
| return x1_dtype | return x1_dtype | ||||
| class DetTriangle(PrimitiveWithInfer): | class DetTriangle(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Calculate the determinant of triangle matrices | Calculate the determinant of triangle matrices | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, fill_mode=0): | def __init__(self, fill_mode=0): | ||||
| self.init_prim_io_names(inputs=['x1'], outputs=['y']) | self.init_prim_io_names(inputs=['x1'], outputs=['y']) | ||||
| @@ -653,5 +658,5 @@ class DetTriangle(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, x1_dtype): | def infer_dtype(self, x1_dtype): | ||||
| validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid('x1', x1_dtype, [mstype.float32], self.name) | |||||
| return x1_dtype | return x1_dtype | ||||
| @@ -63,9 +63,9 @@ class _ScatterOp(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "updates": updates_dtype} | args = {"x": x_dtype, "updates": updates_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -73,6 +73,7 @@ class _ScatterNdOp(_ScatterOp): | |||||
| """ | """ | ||||
| Defines _ScatterNd operators | Defines _ScatterNd operators | ||||
| """ | """ | ||||
| def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): | ||||
| validator.check('the dimension of x', len(x_shape), | validator.check('the dimension of x', len(x_shape), | ||||
| 'the dimension of indices', indices_shape[-1], Rel.GE) | 'the dimension of indices', indices_shape[-1], Rel.GE) | ||||
| @@ -627,6 +628,7 @@ class Unique(Primitive): | |||||
| >>> out = P.Unique()(x) | >>> out = P.Unique()(x) | ||||
| (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) | (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32)) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| @@ -661,11 +663,11 @@ class GatherV2(PrimitiveWithCheck): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize index_select""" | """Initialize index_select""" | ||||
| self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) | ||||
| self.add_prim_attr("dynamic_shape_depends", [2,]) | |||||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||||
| def __check__(self, params, indices, axis): | def __check__(self, params, indices, axis): | ||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | ||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||
| @@ -727,6 +729,7 @@ class Padding(PrimitiveWithInfer): | |||||
| >>> out = P.Padding(pad_dim_size)(x) | >>> out = P.Padding(pad_dim_size)(x) | ||||
| [[8, 0, 0, 0], [10, 0, 0, 0]] | [[8, 0, 0, 0], [10, 0, 0, 0]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, pad_dim_size=8): | def __init__(self, pad_dim_size=8): | ||||
| """Initialize padding""" | """Initialize padding""" | ||||
| @@ -766,12 +769,13 @@ class UniqueWithPad(PrimitiveWithInfer): | |||||
| >>> out = P.UniqueWithPad()(x, pad_num) | >>> out = P.UniqueWithPad()(x, pad_num) | ||||
| ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) | ([1, 5, 4, 3, 2, 8, 8, 8, 8, 8], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """init UniqueWithPad""" | """init UniqueWithPad""" | ||||
| def __infer__(self, x, pad_num): | def __infer__(self, x, pad_num): | ||||
| validator.check_tensor_type_same({"x": x['dtype']}, [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) | validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name) | ||||
| x_shape = list(x['shape']) | x_shape = list(x['shape']) | ||||
| validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) | validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name) | ||||
| @@ -903,7 +907,7 @@ class TruncatedNormal(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, dtype=mstype.float32): | def __init__(self, seed=0, dtype=mstype.float32): | ||||
| """Initialize TruncatedNormal""" | """Initialize TruncatedNormal""" | ||||
| validator.check_value_type('seed', seed, [int], self.name) | validator.check_value_type('seed', seed, [int], self.name) | ||||
| validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name) | |||||
| validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name) | |||||
| def __infer__(self, shape): | def __infer__(self, shape): | ||||
| shape_value = shape['value'] | shape_value = shape['value'] | ||||
| @@ -984,10 +988,10 @@ class Fill(PrimitiveWithInfer): | |||||
| validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) | validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name) | ||||
| for i, item in enumerate(dims['value']): | for i, item in enumerate(dims['value']): | ||||
| validator.check_positive_int(item, f'dims[{i}]', self.name) | validator.check_positive_int(item, f'dims[{i}]', self.name) | ||||
| valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||||
| mstype.uint8, mstype.uint32, mstype.uint64, | |||||
| mstype.float16, mstype.float32, mstype.float64] | |||||
| validator.check_type_same({"value": dtype['value']}, valid_types, self.name) | |||||
| valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||||
| mstype.uint8, mstype.uint32, mstype.uint64, | |||||
| mstype.float16, mstype.float32, mstype.float64] | |||||
| validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name) | |||||
| x_nptype = mstype.dtype_to_nptype(dtype['value']) | x_nptype = mstype.dtype_to_nptype(dtype['value']) | ||||
| ret = np.full(dims['value'], x['value'], x_nptype) | ret = np.full(dims['value'], x['value'], x_nptype) | ||||
| out = { | out = { | ||||
| @@ -1026,7 +1030,7 @@ class OnesLike(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1059,7 +1063,7 @@ class ZerosLike(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1264,7 +1268,7 @@ class Argmax(PrimitiveWithInfer): | |||||
| """Initialize Argmax""" | """Initialize Argmax""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| validator.check_value_type("axis", axis, [int], self.name) | validator.check_value_type("axis", axis, [int], self.name) | ||||
| validator.check_type_same({'output': output_type}, [mstype.int32], self.name) | |||||
| validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name) | |||||
| self.axis = axis | self.axis = axis | ||||
| self.add_prim_attr('output_type', output_type) | self.add_prim_attr('output_type', output_type) | ||||
| @@ -1547,7 +1551,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize UnsortedSegmentSum""" | """Initialize UnsortedSegmentSum""" | ||||
| self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) | ||||
| self.add_prim_attr("dynamic_shape_depends", [2,]) | |||||
| self.add_prim_attr("dynamic_shape_depends", [2]) | |||||
| def __infer__(self, x, segment_ids, num_segments): | def __infer__(self, x, segment_ids, num_segments): | ||||
| x_type = x['dtype'] | x_type = x['dtype'] | ||||
| @@ -1570,7 +1574,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): | |||||
| num_segments_type = num_segments['dtype'] | num_segments_type = num_segments['dtype'] | ||||
| validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) | ||||
| if isinstance(num_segments_type, type(mstype.tensor)): | if isinstance(num_segments_type, type(mstype.tensor)): | ||||
| validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32], self.name) | |||||
| shp = [-1] | shp = [-1] | ||||
| else: | else: | ||||
| validator.check_value_type('num_segments', num_segments_v, [int], self.name) | validator.check_value_type('num_segments', num_segments_v, [int], self.name) | ||||
| @@ -1623,8 +1627,8 @@ class UnsortedSegmentMin(PrimitiveWithInfer): | |||||
| x_shape = x['shape'] | x_shape = x['shape'] | ||||
| segment_ids_shape = segment_ids['shape'] | segment_ids_shape = segment_ids['shape'] | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | valid_type = [mstype.float16, mstype.float32, mstype.int32] | ||||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) | |||||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | ||||
| validator.check(f'first shape of input_x', x_shape[0], | validator.check(f'first shape of input_x', x_shape[0], | ||||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | ||||
| @@ -1673,8 +1677,8 @@ class UnsortedSegmentMax(PrimitiveWithInfer): | |||||
| x_shape = x['shape'] | x_shape = x['shape'] | ||||
| segment_ids_shape = segment_ids['shape'] | segment_ids_shape = segment_ids['shape'] | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | valid_type = [mstype.float16, mstype.float32, mstype.int32] | ||||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | ||||
| validator.check(f'first shape of input_x', x_shape[0], | validator.check(f'first shape of input_x', x_shape[0], | ||||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | ||||
| @@ -1726,8 +1730,8 @@ class UnsortedSegmentProd(PrimitiveWithInfer): | |||||
| validator.check_subclass("input_x", x_type, mstype.tensor, self.name) | validator.check_subclass("input_x", x_type, mstype.tensor, self.name) | ||||
| validator.check_value_type("x_shape", x_shape, [list], self.name) | validator.check_value_type("x_shape", x_shape, [list], self.name) | ||||
| valid_type = [mstype.float16, mstype.float32, mstype.int32] | valid_type = [mstype.float16, mstype.float32, mstype.int32] | ||||
| validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name) | |||||
| validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name) | ||||
| validator.check(f'first shape of input_x', x_shape[0], | validator.check(f'first shape of input_x', x_shape[0], | ||||
| 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) | ||||
| @@ -1833,7 +1837,7 @@ class ParallelConcat(PrimitiveWithInfer): | |||||
| validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) | validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name) | ||||
| args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} | args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} | ||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| first_elem = x_shp[0] | first_elem = x_shp[0] | ||||
| for i, elem in enumerate(x_shp[1:]): | for i, elem in enumerate(x_shp[1:]): | ||||
| @@ -2070,7 +2074,7 @@ class ReverseV2(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2100,7 +2104,7 @@ class Rint(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2167,7 +2171,7 @@ class Select(PrimitiveWithInfer): | |||||
| self.add_prim_attr('T', x_type) | self.add_prim_attr('T', x_type) | ||||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | ||||
| validator.check_subclass("y_type", y_type, mstype.tensor, self.name) | validator.check_subclass("y_type", y_type, mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name) | |||||
| validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) | |||||
| if x_type != y_type: | if x_type != y_type: | ||||
| raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) | raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type)) | ||||
| return x_type | return x_type | ||||
| @@ -2542,7 +2546,7 @@ class Eye(PrimitiveWithInfer): | |||||
| validator.check_positive_int(n, "n", self.name) | validator.check_positive_int(n, "n", self.name) | ||||
| validator.check_positive_int(m, "m", self.name) | validator.check_positive_int(m, "m", self.name) | ||||
| args = {"dtype": t} | args = {"dtype": t} | ||||
| validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| np_type = mstype.dtype_to_nptype(t) | np_type = mstype.dtype_to_nptype(t) | ||||
| ret = np.eye(n, m, dtype=np_type) | ret = np.eye(n, m, dtype=np_type) | ||||
| return Tensor(ret) | return Tensor(ret) | ||||
| @@ -2581,7 +2585,7 @@ class ScatterNd(PrimitiveWithInfer): | |||||
| def __infer__(self, indices, update, shape): | def __infer__(self, indices, update, shape): | ||||
| shp = shape['value'] | shp = shape['value'] | ||||
| validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) | validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32], self.name) | |||||
| validator.check_value_type("shape", shp, [tuple], self.name) | validator.check_value_type("shape", shp, [tuple], self.name) | ||||
| for i, x in enumerate(shp): | for i, x in enumerate(shp): | ||||
| validator.check_positive_int(x, f'shape[{i}]', self.name) | validator.check_positive_int(x, f'shape[{i}]', self.name) | ||||
| @@ -2632,14 +2636,13 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||||
| validator.check_non_negative_int(value, f'{i}th value of size', self.name) | validator.check_non_negative_int(value, f'{i}th value of size', self.name) | ||||
| self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | ||||
| def infer_shape(self, x): | |||||
| validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) | |||||
| return tuple(x)[:-2] + tuple(self.size) | |||||
| def infer_shape(self, x_shape): | |||||
| validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name) | |||||
| return tuple(x_shape)[:-2] + tuple(self.size) | |||||
| def infer_dtype(self, x): | |||||
| validator.check_subclass("x", x, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name) | |||||
| return x | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| class GatherNd(PrimitiveWithInfer): | class GatherNd(PrimitiveWithInfer): | ||||
| @@ -2674,8 +2677,7 @@ class GatherNd(PrimitiveWithInfer): | |||||
| return indices_shape[:-1] + x_shape[indices_shape[-1]:] | return indices_shape[:-1] + x_shape[indices_shape[-1]:] | ||||
| def infer_dtype(self, x_dtype, indices_dtype): | def infer_dtype(self, x_dtype, indices_dtype): | ||||
| validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2715,9 +2717,9 @@ class TensorScatterUpdate(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | args = {"x": x_dtype, "value": value_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2763,9 +2765,9 @@ class ScatterUpdate(_ScatterOp): | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | ||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | args = {"x": x_dtype, "value": value_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2802,7 +2804,6 @@ class ScatterNdUpdate(_ScatterNdOp): | |||||
| [0.4 2.2 -3.2]] | [0.4 2.2 -3.2]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, use_locking=True): | def __init__(self, use_locking=True): | ||||
| """Initialize ScatterNdUpdate""" | """Initialize ScatterNdUpdate""" | ||||
| @@ -2810,9 +2811,9 @@ class ScatterNdUpdate(_ScatterNdOp): | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) | ||||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "value": value_dtype} | args = {"x": x_dtype, "value": value_dtype} | ||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3131,9 +3132,9 @@ class ScatterNonAliasingAdd(_ScatterNdOp): | |||||
| self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) | ||||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | ||||
| validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) | |||||
| args = {"x": x_dtype, "updates": updates_dtype} | args = {"x": x_dtype, "updates": updates_dtype} | ||||
| validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3304,7 +3305,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||||
| self.paddings = paddings | self.paddings = paddings | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| @@ -3376,7 +3377,7 @@ class BatchToSpace(PrimitiveWithInfer): | |||||
| self.crops = crops | self.crops = crops | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| @@ -3465,7 +3466,7 @@ class SpaceToBatchND(PrimitiveWithInfer): | |||||
| self.add_prim_attr("paddings", paddings_append) | self.add_prim_attr("paddings", paddings_append) | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| @@ -3558,7 +3559,7 @@ class BatchToSpaceND(PrimitiveWithInfer): | |||||
| self.add_prim_attr("crops", crops_append) | self.add_prim_attr("crops", crops_append) | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| @@ -3721,7 +3722,6 @@ class Meshgrid(PrimitiveWithInfer): | |||||
| out_shape = tuple(tuple(shape_0) for _ in range(n)) | out_shape = tuple(tuple(shape_0) for _ in range(n)) | ||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) | validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name) | ||||
| n = len(x_type) | n = len(x_type) | ||||
| @@ -3729,6 +3729,7 @@ class Meshgrid(PrimitiveWithInfer): | |||||
| validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) | validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError) | ||||
| return x_type | return x_type | ||||
| class InplaceUpdate(PrimitiveWithInfer): | class InplaceUpdate(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Updates specified rows with values in `v`. | Updates specified rows with values in `v`. | ||||
| @@ -3771,7 +3772,7 @@ class InplaceUpdate(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, v_dtype): | def infer_dtype(self, x_dtype, v_dtype): | ||||
| args = {'x': x_dtype, 'v': v_dtype} | args = {'x': x_dtype, 'v': v_dtype} | ||||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | valid_type = [mstype.int32, mstype.float16, mstype.float32] | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, v_shape): | def infer_shape(self, x_shape, v_shape): | ||||
| @@ -3831,8 +3832,8 @@ class ReverseSequence(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x, seq_lengths): | def infer_dtype(self, x, seq_lengths): | ||||
| validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name) | |||||
| return x | return x | ||||
| @@ -3899,9 +3900,9 @@ class EditDistance(PrimitiveWithInfer): | |||||
| validator.check_const_input('truth_shape', truth_shape['value'], self.name) | validator.check_const_input('truth_shape', truth_shape['value'], self.name) | ||||
| args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], | args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], | ||||
| "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} | "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} | ||||
| validator.check_tensor_type_same(args_int, [mstype.int64], self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name) | |||||
| args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} | args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] | hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] | ||||
| validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) | validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) | ||||
| @@ -3941,6 +3942,7 @@ class TransShape(PrimitiveWithInfer): | |||||
| Outputs: | Outputs: | ||||
| Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. | Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`. | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| self.__setattr_flag__ = True | self.__setattr_flag__ = True | ||||
| @@ -3948,7 +3950,7 @@ class TransShape(PrimitiveWithInfer): | |||||
| def __infer__(self, x, shape): | def __infer__(self, x, shape): | ||||
| shp = shape['value'] | shp = shape['value'] | ||||
| dtype = x['dtype'] | dtype = x['dtype'] | ||||
| validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name) | |||||
| self.add_prim_attr('out_shape', tuple(shp)) | self.add_prim_attr('out_shape', tuple(shp)) | ||||
| return {'shape': shp, | return {'shape': shp, | ||||
| 'dtype': dtype, | 'dtype': dtype, | ||||
| @@ -3989,7 +3991,7 @@ class Sort(PrimitiveWithInfer): | |||||
| return x_shape, x_shape | return x_shape, x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name) | |||||
| validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name) | |||||
| return x_dtype, mstype.tensor_type(mstype.int32) | return x_dtype, mstype.tensor_type(mstype.int32) | ||||
| @@ -4019,6 +4021,7 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||||
| >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) | >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) | ||||
| [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize index_select""" | """Initialize index_select""" | ||||
| @@ -4028,7 +4031,7 @@ class EmbeddingLookup(PrimitiveWithInfer): | |||||
| def __infer__(self, params, indices, offset): | def __infer__(self, params, indices, offset): | ||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) | ||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||
| if len(params_shp) != 2: | if len(params_shp) != 2: | ||||
| @@ -4060,6 +4063,7 @@ class GatherD(PrimitiveWithInfer): | |||||
| >>> out = P.GatherD()(x, dim, index) | >>> out = P.GatherD()(x, dim, index) | ||||
| [[1, 1], [4, 3]] | [[1, 1], [4, 3]] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize GatherD""" | """Initialize GatherD""" | ||||
| @@ -4067,7 +4071,7 @@ class GatherD(PrimitiveWithInfer): | |||||
| def __infer__(self, x, dim, index): | def __infer__(self, x, dim, index): | ||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name) | |||||
| validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) | validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name) | ||||
| x_shp = x['shape'] | x_shp = x['shape'] | ||||
| idx_shp = index['shape'] | idx_shp = index['shape'] | ||||
| @@ -4103,6 +4107,7 @@ class Identity(PrimitiveWithInfer): | |||||
| >>> y = P.Identity()(x) | >>> y = P.Identity()(x) | ||||
| [1, 2, 3, 4] | [1, 2, 3, 4] | ||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self): | def __init__(self): | ||||
| """Initialize identity""" | """Initialize identity""" | ||||
| @@ -105,7 +105,7 @@ class AllReduce(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -167,7 +167,7 @@ class AllGather(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def __call__(self, tensor): | def __call__(self, tensor): | ||||
| @@ -217,7 +217,7 @@ class _HostAllGather(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def __call__(self, tensor): | def __call__(self, tensor): | ||||
| @@ -279,7 +279,7 @@ class ReduceScatter(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def __call__(self, tensor): | def __call__(self, tensor): | ||||
| @@ -328,7 +328,7 @@ class _HostReduceScatter(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def __call__(self, tensor): | def __call__(self, tensor): | ||||
| @@ -390,7 +390,7 @@ class Broadcast(PrimitiveWithInfer): | |||||
| if not isinstance(x_dtype, tuple): | if not isinstance(x_dtype, tuple): | ||||
| raise TypeError(f"{self.name}'s input should be a tuple!") | raise TypeError(f"{self.name}'s input should be a tuple!") | ||||
| for _ele in x_dtype: | for _ele in x_dtype: | ||||
| validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -432,7 +432,7 @@ class _AlltoAll(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def __call__(self, tensor): | def __call__(self, tensor): | ||||
| @@ -132,8 +132,7 @@ class GeSwitch(PrimitiveWithInfer): | |||||
| def infer_dtype(self, data_type, pred_type): | def infer_dtype(self, data_type, pred_type): | ||||
| validator.check_subclass( | validator.check_subclass( | ||||
| "data", data_type, (mstype.tensor,) + mstype.number_type, self.name) | "data", data_type, (mstype.tensor,) + mstype.number_type, self.name) | ||||
| validator.check_tensor_type_same( | |||||
| {"pred": pred_type}, [mstype.bool_], self.name) | |||||
| validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name) | |||||
| return (data_type, data_type) | return (data_type, data_type) | ||||
| @@ -171,5 +170,5 @@ class Merge(PrimitiveWithInfer): | |||||
| for i, item in enumerate(inputs): | for i, item in enumerate(inputs): | ||||
| args['inputs[%d]' % i] = item | args['inputs[%d]' % i] = item | ||||
| validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name) | |||||
| return (inputs[0], mstype.int32) | return (inputs[0], mstype.int32) | ||||
| @@ -380,7 +380,7 @@ class Assert(PrimitiveWithInfer): | |||||
| return [1] | return [1] | ||||
| def infer_dtype(self, condition, inputs): | def infer_dtype(self, condition, inputs): | ||||
| validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name) | |||||
| validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name) | |||||
| for dtype in inputs: | for dtype in inputs: | ||||
| validator.check_subclass("input", dtype, [mstype.tensor], self.name) | validator.check_subclass("input", dtype, [mstype.tensor], self.name) | ||||
| return mstype.int32 | return mstype.int32 | ||||
| @@ -104,11 +104,11 @@ class CropAndResize(PrimitiveWithInfer): | |||||
| box_index_dtype = box_index['dtype'] | box_index_dtype = box_index['dtype'] | ||||
| crop_size_dtype = crop_size['dtype'] | crop_size_dtype = crop_size['dtype'] | ||||
| # check dytpe | # check dytpe | ||||
| validator.check_tensor_type_same({"x": x_dtype}, | |||||
| [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, | |||||
| mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) | |||||
| validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, | |||||
| [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, | |||||
| mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) | |||||
| validator.check_tensor_dtype_valid("boxes", boxes_dtype, [mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid("box_index", box_index_dtype, [mstype.int32], self.name) | |||||
| validator.check_value_type("crop_size", crop_size_value, [tuple], self.name) | validator.check_value_type("crop_size", crop_size_value, [tuple], self.name) | ||||
| # check input shape rank | # check input shape rank | ||||
| validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) | validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) | ||||
| @@ -16,6 +16,8 @@ | |||||
| """Operators for math.""" | """Operators for math.""" | ||||
| import copy | import copy | ||||
| from functools import partial | |||||
| import numpy as np | import numpy as np | ||||
| from ... import context | from ... import context | ||||
| from .. import signature as sig | from .. import signature as sig | ||||
| @@ -85,7 +87,7 @@ class _MathBinaryOp(_BinaryOp): | |||||
| @staticmethod | @staticmethod | ||||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): | def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): | ||||
| args_type = {"x": x_dtype, "y": y_dtype} | args_type = {"x": x_dtype, "y": y_dtype} | ||||
| validator.check_tensor_type_same(args_type, valid_dtype, prim_name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args_type, valid_dtype, prim_name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| @@ -105,8 +107,8 @@ class _BitwiseBinaryOp(_MathBinaryOp): | |||||
| @staticmethod | @staticmethod | ||||
| def _check_bitwise_op_input_type(x1_type, x2_type, prim): | def _check_bitwise_op_input_type(x1_type, x2_type, prim): | ||||
| args = {'x1': x1_type, 'x2': x2_type} | args = {'x1': x1_type, 'x2': x2_type} | ||||
| valid_types = mstype.int_type + mstype.uint_type | |||||
| validator.check_tensor_type_same(args, valid_types, prim) | |||||
| valid_dtypes = mstype.int_type + mstype.uint_type | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim) | |||||
| return x1_type | return x1_type | ||||
| def infer_dtype(self, x1_type, x2_type): | def infer_dtype(self, x1_type, x2_type): | ||||
| @@ -198,7 +200,7 @@ class AssignAdd(PrimitiveWithInfer): | |||||
| def infer_dtype(self, variable, value): | def infer_dtype(self, variable, value): | ||||
| args = {"variable": variable, "value": value} | args = {"variable": variable, "value": value} | ||||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) | |||||
| return value | return value | ||||
| @@ -248,7 +250,7 @@ class AssignSub(PrimitiveWithInfer): | |||||
| def infer_dtype(self, variable, value): | def infer_dtype(self, variable, value): | ||||
| args = {"variable": variable, "value": value} | args = {"variable": variable, "value": value} | ||||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_types_same(args, mstype.number_type, self.name) | |||||
| return value | return value | ||||
| @@ -283,7 +285,7 @@ class _Reduce(PrimitiveWithInfer): | |||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| input_shp = input_x['shape'] | input_shp = input_x['shape'] | ||||
| args = {'input_x': input_x['dtype']} | args = {'input_x': input_x['dtype']} | ||||
| validator.check_tensor_type_same(args, valid_dtype, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtype, self.name) | |||||
| if axis_v is None: | if axis_v is None: | ||||
| raise ValueError(f"For {self.name}, axis must be const.") | raise ValueError(f"For {self.name}, axis must be const.") | ||||
| @@ -504,6 +506,7 @@ class ReduceMax(_Reduce): | |||||
| def __infer__(self, input_x, axis): | def __infer__(self, input_x, axis): | ||||
| return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,)) | return self.do_infer(input_x, axis, mstype.number_type + (mstype.bool_,)) | ||||
| class ReduceMin(_Reduce): | class ReduceMin(_Reduce): | ||||
| """ | """ | ||||
| Reduce a dimension of a tensor by the minimum value in the dimension. | Reduce a dimension of a tensor by the minimum value in the dimension. | ||||
| @@ -612,7 +615,7 @@ class CumProd(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_type, axis_type): | def infer_dtype(self, x_type, axis_type): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) | |||||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, cls_name) | |||||
| validator.check_subclass("axis", axis_type, mstype.int_, cls_name) | validator.check_subclass("axis", axis_type, mstype.int_, cls_name) | ||||
| return x_type | return x_type | ||||
| @@ -689,7 +692,7 @@ class MatMul(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x1, x2): | def infer_dtype(self, x1, x2): | ||||
| args = {"x1": x1, "x2": x2} | args = {"x1": x1, "x2": x2} | ||||
| validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type + mstype.int_type, self.name) | |||||
| if x1.element_type() == mstype.int8: | if x1.element_type() == mstype.int8: | ||||
| return mstype.tensor_type(mstype.int32) | return mstype.tensor_type(mstype.int32) | ||||
| return x1 | return x1 | ||||
| @@ -801,10 +804,10 @@ class TensorDot(PrimitiveWithInfer): | |||||
| self.axes = axes | self.axes = axes | ||||
| validator.check_value_type('axes', axes, [int, tuple, list], self.name) | validator.check_value_type('axes', axes, [int, tuple, list], self.name) | ||||
| if not isinstance(self.axes, int): | if not isinstance(self.axes, int): | ||||
| self.axes = list(self.axes) # to avoid immutability issues | |||||
| self.axes = list(self.axes) # to avoid immutability issues | |||||
| if len(self.axes) != 2: | if len(self.axes) != 2: | ||||
| raise ValueError("Require two axes inputs, given less") | raise ValueError("Require two axes inputs, given less") | ||||
| self.int_to_tuple_conv() # convert before length checks | |||||
| self.int_to_tuple_conv() # convert before length checks | |||||
| if len(self.axes[0]) != len(self.axes[1]): | if len(self.axes[0]) != len(self.axes[1]): | ||||
| raise ValueError("Axes have to be the same size/length") | raise ValueError("Axes have to be the same size/length") | ||||
| if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): | if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): | ||||
| @@ -825,7 +828,7 @@ class TensorDot(PrimitiveWithInfer): | |||||
| if isinstance(self.axes, int): | if isinstance(self.axes, int): | ||||
| if self.axes <= 0: | if self.axes <= 0: | ||||
| # outer product, no input validation required | # outer product, no input validation required | ||||
| self.axes = ([], []) # no axes selected for either | |||||
| self.axes = ([], []) # no axes selected for either | |||||
| return | return | ||||
| if self.axes > len(x1_shape) or self.axes > len(x2_shape): | if self.axes > len(x1_shape) or self.axes > len(x2_shape): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -877,8 +880,8 @@ class TensorDot(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x1, x2): | def infer_dtype(self, x1, x2): | ||||
| args = {"x1": x1, "x2": x2} | args = {"x1": x1, "x2": x2} | ||||
| valid_types = [mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same(args, valid_types, self.name) | |||||
| valid_dtypes = [mstype.float16, mstype.float32] | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) | |||||
| return x1 | return x1 | ||||
| @@ -922,8 +925,8 @@ class CumSum(PrimitiveWithInfer): | |||||
| if axis['value'] is None: | if axis['value'] is None: | ||||
| raise ValueError(f"For {self.name}, axis must be const.") | raise ValueError(f"For {self.name}, axis must be const.") | ||||
| validator.check_value_type('axis', axis['value'], [int], cls_name) | validator.check_value_type('axis', axis['value'], [int], cls_name) | ||||
| valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | |||||
| valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||||
| validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name) | |||||
| return {'shape': x_shp, | return {'shape': x_shp, | ||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| @@ -989,7 +992,7 @@ class AddN(PrimitiveWithInfer): | |||||
| if dtype == mstype.undetermined: | if dtype == mstype.undetermined: | ||||
| contains_undetermined = True | contains_undetermined = True | ||||
| if not contains_undetermined: | if not contains_undetermined: | ||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) | |||||
| return inputs[0] | return inputs[0] | ||||
| def infer_value(self, inputs): | def infer_value(self, inputs): | ||||
| @@ -1068,7 +1071,7 @@ class AccumulateNV2(PrimitiveWithInfer): | |||||
| args = {} | args = {} | ||||
| for i, dtype in enumerate(inputs): | for i, dtype in enumerate(inputs): | ||||
| args[f"inputs[{i}]"] = dtype | args[f"inputs[{i}]"] = dtype | ||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), cls_name) | |||||
| return inputs[0] | return inputs[0] | ||||
| @@ -1094,12 +1097,12 @@ class Neg(PrimitiveWithInfer): | |||||
| """Initialize Neg""" | """Initialize Neg""" | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | self.init_prim_io_names(inputs=['x'], outputs=['y']) | ||||
| def infer_shape(self, input_x): | |||||
| return input_x | |||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, input_x): | |||||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) | |||||
| return input_x | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| def infer_value(self, input_x): | def infer_value(self, input_x): | ||||
| if input_x is not None: | if input_x is not None: | ||||
| @@ -1151,7 +1154,7 @@ class InplaceAdd(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, v_dtype): | def infer_dtype(self, x_dtype, v_dtype): | ||||
| args = {'x': x_dtype, 'v': v_dtype} | args = {'x': x_dtype, 'v': v_dtype} | ||||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | valid_type = [mstype.int32, mstype.float16, mstype.float32] | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, v_shape): | def infer_shape(self, x_shape, v_shape): | ||||
| @@ -1209,7 +1212,7 @@ class InplaceSub(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, v_dtype): | def infer_dtype(self, x_dtype, v_dtype): | ||||
| args = {'x': x_dtype, 'v': v_dtype} | args = {'x': x_dtype, 'v': v_dtype} | ||||
| valid_type = [mstype.int32, mstype.float16, mstype.float32] | valid_type = [mstype.int32, mstype.float16, mstype.float32] | ||||
| validator.check_tensor_type_same(args, valid_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_shape(self, x_shape, v_shape): | def infer_shape(self, x_shape, v_shape): | ||||
| @@ -1363,9 +1366,9 @@ class Square(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | |||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||||
| return x_type | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| if x is not None: | if x is not None: | ||||
| @@ -1401,9 +1404,9 @@ class Rsqrt(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | |||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||||
| return x_type | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| if x is not None: | if x is not None: | ||||
| @@ -1437,7 +1440,7 @@ class Sqrt(PrimitiveWithCheck): | |||||
| self.init_prim_io_names(inputs=['x'], outputs=['output']) | self.init_prim_io_names(inputs=['x'], outputs=['output']) | ||||
| def check_dtype(self, x_type): | def check_dtype(self, x_type): | ||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_type, mstype.number_type, self.name) | |||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| if x is not None: | if x is not None: | ||||
| @@ -1599,8 +1602,7 @@ class Expm1(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x", x_type, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) | |||||
| return x_type | return x_type | ||||
| @@ -1641,10 +1643,9 @@ class HistogramFixedWidth(PrimitiveWithInfer): | |||||
| return (self.nbins,) | return (self.nbins,) | ||||
| def infer_dtype(self, x_dtype, range_dtype): | def infer_dtype(self, x_dtype, range_dtype): | ||||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||||
| valid_types = (mstype.float16, mstype.float32, mstype.int32) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"range": range_dtype}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float16, mstype.float32, mstype.int32) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name) | |||||
| validator.check_tensor_dtype_valid("range", range_dtype, valid_dtypes, self.name) | |||||
| y_dtype = mstype.int32 | y_dtype = mstype.int32 | ||||
| return y_dtype | return y_dtype | ||||
| @@ -1707,13 +1708,13 @@ class Log1p(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| self.init_prim_io_names(inputs=['x'], outputs=['y']) | self.init_prim_io_names(inputs=['x'], outputs=['y']) | ||||
| def infer_shape(self, x): | |||||
| return x | |||||
| def infer_shape(self, x_shape): | |||||
| return x_shape | |||||
| def infer_dtype(self, x): | |||||
| validator.check_subclass("x", x, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({"x": x}, [mstype.float16, mstype.float32], self.name) | |||||
| return x | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class Erf(PrimitiveWithInfer): | class Erf(PrimitiveWithInfer): | ||||
| @@ -1741,9 +1742,9 @@ class Erf(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | |||||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||||
| return x_type | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | |||||
| class Erfc(PrimitiveWithInfer): | class Erfc(PrimitiveWithInfer): | ||||
| @@ -1772,7 +1773,7 @@ class Erfc(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name) | |||||
| return x_type | return x_type | ||||
| @@ -2126,7 +2127,7 @@ class Floor(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, mstype.float_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2185,7 +2186,7 @@ class Ceil(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2281,7 +2282,7 @@ class Acosh(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2310,7 +2311,7 @@ class Cosh(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2339,7 +2340,7 @@ class Asinh(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2368,7 +2369,7 @@ class Sinh(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2380,7 +2381,7 @@ class _LogicBinaryOp(_BinaryOp): | |||||
| @staticmethod | @staticmethod | ||||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): | def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): | ||||
| args_dtype = {"x": x_dtype, "y": y_dtype} | args_dtype = {"x": x_dtype, "y": y_dtype} | ||||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name) | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| @@ -2461,7 +2462,7 @@ class ApproximateEqual(_LogicBinaryOp): | |||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| args_dtype = {"x": x_dtype, "y": y_dtype} | args_dtype = {"x": x_dtype, "y": y_dtype} | ||||
| valid_type = [mstype.float32, mstype.float16] | valid_type = [mstype.float32, mstype.float16] | ||||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name=self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args_dtype, valid_type, prim_name=self.name) | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| @@ -2498,7 +2499,7 @@ class EqualCount(PrimitiveWithInfer): | |||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| args = {'x': x_dtype, 'y': y_dtype} | args = {'x': x_dtype, 'y': y_dtype} | ||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2711,7 +2712,7 @@ class LogicalNot(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, [mstype.bool_], self.name) | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| @@ -2859,8 +2860,7 @@ class IsFinite(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_subclass("x", x_dtype, mstype.tensor, self.name) | |||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type + (mstype.bool_,), self.name) | |||||
| return mstype.bool_ | return mstype.bool_ | ||||
| @@ -2890,7 +2890,7 @@ class FloatStatus(PrimitiveWithInfer): | |||||
| return [1] | return [1] | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.float16], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -2959,7 +2959,7 @@ class NPUGetFloatStatus(PrimitiveWithInfer): | |||||
| return [8] | return [8] | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return mstype.float32 | return mstype.float32 | ||||
| @@ -3002,7 +3002,7 @@ class NPUClearFloatStatus(PrimitiveWithInfer): | |||||
| return [8] | return [8] | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return mstype.float32 | return mstype.float32 | ||||
| @@ -3030,7 +3030,7 @@ class Cos(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3058,7 +3058,7 @@ class ACos(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3087,7 +3087,7 @@ class Sin(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3116,7 +3116,7 @@ class Asin(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3175,7 +3175,7 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| def infer_dtype(self, bboxes_dtype): | def infer_dtype(self, bboxes_dtype): | ||||
| validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_dtype_valid("bboxes", bboxes_dtype, [mstype.float16, mstype.float32], self.name) | |||||
| return (bboxes_dtype, mstype.int32, mstype.bool_) | return (bboxes_dtype, mstype.int32, mstype.bool_) | ||||
| @@ -3205,7 +3205,7 @@ class Abs(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||||
| return x_type | return x_type | ||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| @@ -3247,7 +3247,7 @@ class Sign(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3276,9 +3276,9 @@ class Round(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | |||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||||
| return x_type | |||||
| def infer_dtype(self, x_dtype): | |||||
| validator.check_tensor_dtype_valid('x', x_dtype, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| class Tan(PrimitiveWithInfer): | class Tan(PrimitiveWithInfer): | ||||
| @@ -3306,8 +3306,8 @@ class Tan(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| valid_types = [mstype.float16, mstype.float32, mstype.int32] | |||||
| validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) | |||||
| valid_dtypes = [mstype.float16, mstype.float32, mstype.int32] | |||||
| validator.check_tensor_dtype_valid('x', x_type, valid_dtypes, self.name) | |||||
| return x_type | return x_type | ||||
| @@ -3338,7 +3338,7 @@ class Atan(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||||
| return x_type | return x_type | ||||
| @@ -3367,7 +3367,7 @@ class Atanh(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x_type, mstype.number_type, self.name) | |||||
| return x_type | return x_type | ||||
| @@ -3431,8 +3431,9 @@ class SquareSumAll(PrimitiveWithInfer): | |||||
| return [], [] | return [], [] | ||||
| def infer_dtype(self, x_type, y_type): | def infer_dtype(self, x_type, y_type): | ||||
| validator.check_tensor_type_same({'x1_type': x_type}, [mstype.float16, mstype.float32], self.name) | |||||
| validator.check_tensor_type_same({'x2_type': y_type}, [mstype.float16, mstype.float32], self.name) | |||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_dtype_valid('x1_type', x_type, valid_types, self.name) | |||||
| validator.check_tensor_dtype_valid('x2_type', y_type, valid_types, self.name) | |||||
| return x_type, y_type | return x_type, y_type | ||||
| @@ -3539,7 +3540,7 @@ class BesselI0e(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -3568,7 +3569,7 @@ class BesselI1e(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid('x', x, mstype.number_type, self.name) | |||||
| return x | return x | ||||
| @@ -3598,7 +3599,7 @@ class Inv(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float16, mstype.float32, | |||||
| validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.float16, mstype.float32, | |||||
| mstype.int32], self.name) | mstype.int32], self.name) | ||||
| return x_dtype | return x_dtype | ||||
| @@ -3628,7 +3629,7 @@ class Invert(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.int16, mstype.uint16], self.name) | |||||
| validator.check_tensor_dtype_valid('x_dtype', x_dtype, [mstype.int16, mstype.uint16], self.name) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -3654,8 +3655,8 @@ class Eps(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['input_x'], outputs=['y']) | self.init_prim_io_names(inputs=['input_x'], outputs=['y']) | ||||
| def __infer__(self, input_x): | def __infer__(self, input_x): | ||||
| valid_types = [mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same({'input_x': input_x['dtype']}, valid_types, self.name) | |||||
| valid_dtypes = [mstype.float16, mstype.float32] | |||||
| validator.check_tensor_dtype_valid('input_x', input_x['dtype'], valid_dtypes, self.name) | |||||
| x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type()) | x_nptype = mstype.dtype_to_nptype(input_x['dtype'].element_type()) | ||||
| if x_nptype == np.float16: | if x_nptype == np.float16: | ||||
| @@ -3725,9 +3726,9 @@ class IFMR(PrimitiveWithInfer): | |||||
| return (1,), (1,) | return (1,), (1,) | ||||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | ||||
| valid_types = [mstype.float32, mstype.float16] | |||||
| validator.check_tensor_type_same({"input_value": data_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_min": data_min_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_max": data_max_dtype}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"input_bins": cumsum_dtype}, [mstype.int32], self.name) | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("input_value", "input_min", "input_max"), | |||||
| (data_dtype, data_min_dtype, data_max_dtype))) | |||||
| validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) | |||||
| return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | ||||
| @@ -61,8 +61,8 @@ class Assign(PrimitiveWithCheck): | |||||
| def check_dtype(self, variable, value): | def check_dtype(self, variable, value): | ||||
| if variable != mstype.type_refkey: | if variable != mstype.type_refkey: | ||||
| validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) | |||||
| validator.check_tensor_dtype_valid("variable", variable, mstype.number_type, self.name) | |||||
| validator.check_scalar_or_tensor_types_same({"value": value}, mstype.number_type, self.name) | |||||
| class BoundingBoxEncode(PrimitiveWithInfer): | class BoundingBoxEncode(PrimitiveWithInfer): | ||||
| @@ -112,7 +112,7 @@ class BoundingBoxEncode(PrimitiveWithInfer): | |||||
| def infer_dtype(self, anchor_box, groundtruth_box): | def infer_dtype(self, anchor_box, groundtruth_box): | ||||
| args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} | args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return anchor_box | return anchor_box | ||||
| @@ -169,7 +169,7 @@ class BoundingBoxDecode(PrimitiveWithInfer): | |||||
| def infer_dtype(self, anchor_box, deltas): | def infer_dtype(self, anchor_box, deltas): | ||||
| args = {"anchor_box": anchor_box, "deltas": deltas} | args = {"anchor_box": anchor_box, "deltas": deltas} | ||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) | |||||
| return anchor_box | return anchor_box | ||||
| @@ -221,8 +221,8 @@ class CheckValid(PrimitiveWithInfer): | |||||
| def infer_dtype(self, bboxes_type, metas_type): | def infer_dtype(self, bboxes_type, metas_type): | ||||
| valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] | valid_type = [mstype.float32, mstype.float16, mstype.int16, mstype.uint8] | ||||
| validator.check_tensor_type_same({"bboxes_type": bboxes_type}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"metas_type": metas_type}, valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("bboxes_type", bboxes_type, valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("metas_type", metas_type, valid_type, self.name) | |||||
| return mstype.bool_ | return mstype.bool_ | ||||
| @@ -281,8 +281,8 @@ class IOU(PrimitiveWithInfer): | |||||
| def infer_dtype(self, anchor_boxes, gt_boxes): | def infer_dtype(self, anchor_boxes, gt_boxes): | ||||
| valid_type = [mstype.float32, mstype.float16] | valid_type = [mstype.float32, mstype.float16] | ||||
| validator.check_tensor_type_same({"anchor_boxes": anchor_boxes}, valid_type, self.name) | |||||
| validator.check_tensor_type_same({"gt_boxes": gt_boxes}, valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("anchor_boxes", anchor_boxes, valid_type, self.name) | |||||
| validator.check_tensor_dtype_valid("gt_boxes", gt_boxes, valid_type, self.name) | |||||
| return anchor_boxes | return anchor_boxes | ||||
| @@ -478,7 +478,7 @@ class ConfusionMatrix(PrimitiveWithInfer): | |||||
| if weights is not None: | if weights is not None: | ||||
| validator.check_subclass('weights', weights, mstype.tensor, self.name) | validator.check_subclass('weights', weights, mstype.tensor, self.name) | ||||
| args = {"labels": labels, "predictions": predictions} | args = {"labels": labels, "predictions": predictions} | ||||
| validator.check_tensor_type_same(args, (mstype.number_type), self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid(args, (mstype.number_type), self.name) | |||||
| return labels | return labels | ||||
| @@ -506,8 +506,7 @@ class PopulationCount(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| args = {"x": x_dtype} | |||||
| validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name) | |||||
| validator.check_tensor_dtype_valid("x", x_dtype, (mstype.int16, mstype.uint16,), self.name) | |||||
| return mstype.tensor_type(mstype.uint8) | return mstype.tensor_type(mstype.uint8) | ||||
| class Push(PrimitiveWithInfer): | class Push(PrimitiveWithInfer): | ||||
| @@ -151,8 +151,8 @@ class Gamma(PrimitiveWithInfer): | |||||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | Validator.check_value_type("shape", shape_v, [tuple], self.name) | ||||
| for i, shape_i in enumerate(shape_v): | for i, shape_i in enumerate(shape_v): | ||||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | ||||
| Validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name) | |||||
| Validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name) | |||||
| Validator.check_tensor_dtype_valid("alpha", alpha["dtype"], [mstype.float32], self.name) | |||||
| Validator.check_tensor_dtype_valid("beta", beta["dtype"], [mstype.float32], self.name) | |||||
| broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) | broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name) | ||||
| broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) | broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name) | ||||
| out = { | out = { | ||||
| @@ -203,7 +203,7 @@ class Poisson(PrimitiveWithInfer): | |||||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | Validator.check_value_type("shape", shape_v, [tuple], self.name) | ||||
| for i, shape_i in enumerate(shape_v): | for i, shape_i in enumerate(shape_v): | ||||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | ||||
| Validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name) | |||||
| Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name) | |||||
| broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) | broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name) | ||||
| out = { | out = { | ||||
| 'shape': broadcast_shape, | 'shape': broadcast_shape, | ||||
| @@ -259,8 +259,8 @@ class UniformInt(PrimitiveWithInfer): | |||||
| Validator.check_value_type("shape", shape_v, [tuple], self.name) | Validator.check_value_type("shape", shape_v, [tuple], self.name) | ||||
| for i, shape_i in enumerate(shape_v): | for i, shape_i in enumerate(shape_v): | ||||
| Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | Validator.check_positive_int(shape_i, f'shape[{i}]', self.name) | ||||
| Validator.check_tensor_type_same({"minval": minval["dtype"]}, [mstype.int32], self.name) | |||||
| Validator.check_tensor_type_same({"maxval": maxval["dtype"]}, [mstype.int32], self.name) | |||||
| Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name) | |||||
| Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name) | |||||
| minval_shape = minval['shape'] | minval_shape = minval['shape'] | ||||
| maxval_shape = maxval['shape'] | maxval_shape = maxval['shape'] | ||||
| Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) | Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name) | ||||
| @@ -361,7 +361,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): | |||||
| return ([self.count, len(x_shape)], [self.count]) | return ([self.count, len(x_shape)], [self.count]) | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| Validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) | |||||
| Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name) | |||||
| return (mstype.int32, mstype.bool_) | return (mstype.int32, mstype.bool_) | ||||
| @@ -407,8 +407,8 @@ class RandomCategorical(PrimitiveWithInfer): | |||||
| def __infer__(self, logits, num_samples, seed): | def __infer__(self, logits, num_samples, seed): | ||||
| logits_dtype = logits['dtype'] | logits_dtype = logits['dtype'] | ||||
| valid_types = (mstype.float32, mstype.float16, mstype.float64) | |||||
| Validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) | |||||
| valid_dtypes = (mstype.float32, mstype.float16, mstype.float64) | |||||
| Validator.check_tensor_dtype_valid('logits', logits_dtype, valid_dtypes, self.name) | |||||
| num_samples_v = num_samples['value'] | num_samples_v = num_samples['value'] | ||||
| seed_v = seed['value'] | seed_v = seed['value'] | ||||
| Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) | Validator.check_value_type('num_samples', num_samples_v, (int,), self.name) | ||||
| @@ -460,7 +460,7 @@ class Multinomial(PrimitiveWithInfer): | |||||
| input_shape = inputs["shape"] | input_shape = inputs["shape"] | ||||
| if len(input_shape) != 1 and len(input_shape) != 2: | if len(input_shape) != 1 and len(input_shape) != 2: | ||||
| raise ValueError("input dim must be 1 or 2") | raise ValueError("input dim must be 1 or 2") | ||||
| Validator.check_tensor_type_same({'inputs': inputs['dtype']}, [mstype.float32], self.name) | |||||
| Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name) | |||||
| num_samples_value = num_samples["value"] | num_samples_value = num_samples["value"] | ||||
| if num_samples_value is None: | if num_samples_value is None: | ||||
| raise ValueError(f"For {self.name}, shape nust be const") | raise ValueError(f"For {self.name}, shape nust be const") | ||||
| @@ -588,8 +588,8 @@ def _quant_export(network, *inputs, file_format, **kwargs): | |||||
| if quant_mode not in quant_mode_formats: | if quant_mode not in quant_mode_formats: | ||||
| raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') | raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') | ||||
| mean = Validator.check_type("mean", mean, (int, float)) | |||||
| std_dev = Validator.check_type("std_dev", std_dev, (int, float)) | |||||
| mean = Validator.check_value_type("mean", mean, (int, float)) | |||||
| std_dev = Validator.check_value_type("std_dev", std_dev, (int, float)) | |||||
| if context.get_context('device_target') not in supported_device: | if context.get_context('device_target') not in supported_device: | ||||
| raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) | ||||
| @@ -117,7 +117,7 @@ class MySparseGatherV2(PrimitiveWithInfer): | |||||
| def __infer__(self, params, indices, axis): | def __infer__(self, params, indices, axis): | ||||
| validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) | ||||
| validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name) | |||||
| validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name) | ||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||