From 6dd72f654acb2daaeb04503dbdb6b12ee61ddb91 Mon Sep 17 00:00:00 2001 From: fary86 Date: Tue, 7 Apr 2020 15:50:48 +0800 Subject: [PATCH] Add prim name to error message for nn_ops.py --- mindspore/_checkparam.py | 46 +- mindspore/context.py | 2 +- mindspore/ops/operations/nn_ops.py | 766 ++++++++++------------- tests/ut/python/nn/test_dynamic_lr.py | 20 +- tests/ut/python/nn/test_ssim.py | 2 +- tests/ut/python/ops/test_nn_ops.py | 20 +- tests/ut/python/ops/test_nn_ops_check.py | 463 ++++++++++++++ 7 files changed, 821 insertions(+), 498 deletions(-) create mode 100755 tests/ut/python/ops/test_nn_ops_check.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 7b8c89351c..f0b7fa0af1 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -117,10 +117,12 @@ class Validator: """Integer value judgment.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" - raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') + raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' + f' with type `{type(arg_value).__name__}`.') return arg_value @staticmethod @@ -137,10 +139,11 @@ class Validator: """Method for checking whether an int value is in some range.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' - f' but got {arg_value}.') + raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' + f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') return arg_value @staticmethod @@ -192,19 +195,23 @@ class Validator: @staticmethod def check_const_input(arg_name, arg_value, prim_name): - """Check valid value.""" + """Checks valid value.""" if arg_value is None: raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') @staticmethod - def check_scalar_type_same(args, valid_values, prim_name): - """check whether the types of inputs are the same.""" + def check_type_same(args, valid_values, prim_name): + """Checks whether the types of inputs are the same.""" def _check_tensor_type(arg): arg_key, arg_val = arg elem_type = arg_val + type_names = [] if not elem_type in valid_values: - raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},' - f' but `{arg_key}` is {elem_type}.') + for t in valid_values: + type_names.append(str(t)) + types_info = '[' + ", ".join(type_names) + ']' + raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' + f' but got {elem_type}.') return (arg_key, elem_type) def _check_types_same(arg1, arg2): @@ -212,7 +219,7 @@ class Validator: arg2_name, arg2_type = arg2 if arg1_type != arg2_type: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' - f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') + f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') return 
arg1 elem_types = map(_check_tensor_type, args.items()) @@ -221,25 +228,8 @@ class Validator: @staticmethod def check_tensor_type_same(args, valid_values, prim_name): """Checks whether the element types of input tensors are the same.""" - def _check_tensor_type(arg): - arg_key, arg_val = arg - Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) - elem_type = arg_val.element_type() - if not elem_type in valid_values: - raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' - f' but element type of `{arg_key}` is {elem_type}.') - return (arg_key, elem_type) - - def _check_types_same(arg1, arg2): - arg1_name, arg1_type = arg1 - arg2_name, arg2_type = arg2 - if arg1_type != arg2_type: - raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' - f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') - return arg1 - - elem_types = map(_check_tensor_type, args.items()) - reduce(_check_types_same, elem_types) + tensor_types = [mstype.tensor_type(t) for t in valid_values] + Validator.check_type_same(args, tensor_types, prim_name) @staticmethod def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): diff --git a/mindspore/context.py b/mindspore/context.py index f6fe8705fd..159522a87a 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -34,7 +34,7 @@ GRAPH_MODE = 0 PYNATIVE_MODE = 1 -def _make_directory(path: str): +def _make_directory(path): """Make directory.""" real_path = None if path is None or not isinstance(path, str) or path.strip() == "": diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1fb65e3b76..5fd2c24a6e 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -24,12 +24,39 @@ import numpy as np from ... import context from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_bool, check_int_positive +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): + """ + Checks whether an argument is a positive int or a tuple with 2 or 4 (when allow_four is True) positive int elements. + """ + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be a positive int number or a tuple of two " + f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) + elif len(arg_value) == 2: + ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value + elif len(arg_value) == 4: + if not allow_four: + _raise_message() + ret = arg_value if ret_four else (arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and item > 0: + continue + _raise_message() + return ret_value + class Flatten(PrimitiveWithInfer): r""" Flattens a tensor without changing its batch size on the 0-th axis.
@@ -53,12 +80,12 @@ class Flatten(PrimitiveWithInfer): pass def infer_shape(self, input_x): - validator.check('input_x rank', len(input_x), '', 1, Rel.GE) + validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) return input_x[0], prod def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -88,21 +115,21 @@ class Softmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type("axis", axis, [int, tuple]) + validator.check_value_type("axis", axis, [int, tuple], self.name) if isinstance(axis, int): self.add_prim_attr('axis', (axis,)) for item in self.axis: - validator.check_type("item of axis", item, [int]) + validator.check_value_type("item of axis", item, [int], self.name) def infer_shape(self, logits): - validator.check_shape_length("axis shape", len(self.axis), 1, Rel.GE) + validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) rank = len(logits) for axis_v in self.axis: - validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -131,15 +158,15 @@ class LogSoftmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -171,8 +198,7 @@ class ReLU(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, mstype.number_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name) return input_x @@ -203,8 +229,7 @@ class ReLU6(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -233,14 +258,13 @@ class Elu(PrimitiveWithInfer): @prim_attr_register def __init__(self, alpha=1.0): """Init Elu""" - validator.check_type("alpha", alpha, [float]) + validator.check_value_type("alpha", alpha, [float], self.name) def infer_shape(self, input_x): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, mstype.float_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) return input_x @@ -272,8 +296,7 @@ class HSwish(PrimitiveWithInfer): return xshape def infer_dtype(self, x_dtype): - 
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -305,8 +328,7 @@ class Sigmoid(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -339,8 +361,7 @@ class HSigmoid(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -370,7 +391,7 @@ class Tanh(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -418,9 +439,9 @@ class FusedBatchNorm(Primitive): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) - self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH) + self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) + self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) class BatchNorm(PrimitiveWithInfer): @@ -464,32 +485,34 @@ class BatchNorm(PrimitiveWithInfer): @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + validator.check_value_type('is_training', is_training, (bool,), self.name) + validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', 'reserve_space_3']) def infer_shape(self, input_x, scale, bias, mean, variance): - validator.check("BatchNorm scale shape length", len(scale), "1", 1, Rel.EQ) - validator.check("BatchNorm scale shape", scale, "BatchNorm bias shape", bias) - validator.check("BatchNorm scale shape", scale[0], "BatchNorm input_x shape[1]", input_x[1]) + validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) + validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) if not self.is_training: - validator.check("BatchNorm mean shape length", len(mean), "1", 1, Rel.EQ) - validator.check("BatchNorm mean shape", mean, "BatchNorm variance shape", variance) - validator.check("BatchNorm mean shape", mean, "BatchNorm scale shape", 
scale) + validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) + validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) + validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) return (input_x, scale, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): - args = {"BatchNorm scale type": scale, "BatchNorm bias type": bias} - args_moving = {"BatchNorm mean type": mean, "BatchNorm variance type": variance} - validator.check_typename("input_x", input_x, [mstype.float32, mstype.float16]) - validator.check_type_same(args, [mstype.float32, mstype.float16]) + validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + args = {"scale": scale, "bias": bias} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + args_moving = {"mean": mean, "variance": variance} if self.is_training: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16, None]) + valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] + validator.check_type_same(args_moving, valid_types, self.name) else: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16]) + args_moving = {"mean": mean, "variance": variance} + validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) return (input_x, scale, bias, input_x, input_x, input_x) @@ -559,53 +582,28 @@ class Conv2D(PrimitiveWithInfer): group=1): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = 
_check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 0, Rel.GE) + validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.group = validator.check_integer('group', group, 0, Rel.GT) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1]) - validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -647,10 +645,9 @@ class Conv2D(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_subclass('input', x_dtype, mstype.tensor) - validator.check_subclass('weight', w_dtype, mstype.tensor) - validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + args = {'x': x_dtype, 'w': w_dtype} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) return x_dtype @@ -697,49 +694,25 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): group=1): """init DepthwiseConv2dNative""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of 
\'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (dilation, dilation) - if len(self.dilation) != 2 or (not isinstance(self.dilation[0], int)) or \ - (not isinstance(self.dilation[1], int)) or \ - self.dilation[0] < 1 or self.dilation[1] < 1: - raise ValueError(f"The \'dilation\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name) self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1])) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - if pad_mode not in ("same", "valid", "pad"): - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - f"{pad_mode} not in valid, same, pad.") - self.pad_mode = pad_mode - self.mode = validator.check_integer("mode", mode, 3, Rel.EQ) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT) - self.group = validator.check_integer("group", group, 0, Rel.GT) - self.pad = pad + self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT, + self.name) + self.group = validator.check_integer("group", group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -772,9 +745,6 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) - else: - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - "{pad_mode} not in valid, same, pad.") 
self.pad_list = (pad_top, pad_bottom, pad_left, pad_right) self.add_prim_attr('pads', self.pad_list) @@ -784,8 +754,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_type_same(args, mstype.number_type) + args = {'x': x_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -805,48 +775,26 @@ class _Pool(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksize=1, strides=1, padding="valid"): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") if not self.is_maxpoolwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - self.ksize = (1, 1, ksize, ksize) - else: - if (len(ksize) != 2 or - (not isinstance(ksize[0], int)) or - (not isinstance(ksize[1], int)) or - ksize[0] <= 0 or - ksize[1] <= 0): - raise ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int numbers, but got {ksize}") - self.ksize = (1, 1, ksize[0], ksize[1]) + self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.ksize = (1, self.ksize[-2], self.ksize[-1], 1) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - self.strides = (1, 1, strides, strides) - else: - if (len(strides) != 2 or - (not isinstance(strides[0], int)) or - (not isinstance(strides[1], int)) or - strides[0] <= 0 or - strides[1] <= 0): - raise ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int numbers, but got {strides}") - self.strides = (1, 1, strides[0], strides[1]) + self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.strides = (1, self.strides[-2], self.strides[-1], 1) self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) batch, channel, input_h, input_w = x_shape if self.is_maxpoolwithargmax: _, kernel_h, kernel_w, _ = self.ksize @@ -861,18 +809,16 @@ class _Pool(PrimitiveWithInfer): elif self.padding == "SAME": out_h = math.ceil(input_h / stride_h) out_w = math.ceil(input_w / stride_w) - else: - raise ValueError(f"The padding of operator {self.name} should be a str and must be 'SAME' or 'VALID', " - f"but got {self.padding}.") out_shape = [batch, channel, out_h, out_w] for shape_value in out_shape: if shape_value <= 0: - raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") + raise ValueError(f"For '{self.name}' The kernel size is not valid, " + 
f"please check it if is larger than data's shape size.") return out_shape def infer_dtype(self, x_dtype): - validator.check_subclass("input", x_dtype, mstype.tensor) + validator.check_subclass("input", x_dtype, mstype.tensor, self.name) return x_dtype @@ -987,7 +933,7 @@ class MaxPoolWithArgmax(_Pool): def infer_dtype(self, x_dtype): out_dtype = x_dtype - validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.uint16 return out_dtype, argmax_dtype @@ -1071,56 +1017,33 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """init Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - elif isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {stride}") + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.mode = validator.check_integer('mode', mode, 1, 
Rel.EQ) - self.group = validator.check_integer('group', group, 0, Rel.GT) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: - self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) + for x in pad_list: + validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) + self.pad_list = pad_list def __infer__(self, doutput, w, x_size): x_size_v = x_size['value'] - validator.check_type('x_size', x_size_v, [tuple]) + validator.check_value_type('x_size', x_size_v, [tuple], self.name) for i, dim_len in enumerate(x_size_v): - validator.check_type("x_size[%d]" % i, dim_len, [int]) - validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) + validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) + args = {'doutput': doutput['dtype'], 'w': w['dtype']} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) # infer shape dout_shape = doutput['shape'] @@ -1173,16 +1096,15 @@ class BiasAdd(PrimitiveWithInfer): self.add_prim_attr('data_format', 'NCHW') def infer_shape(self, x_shape, b_shape): - if len(b_shape) != 1 or len(x_shape) < 2 or b_shape[0] != x_shape[1]: - raise ValueError("Input_x and bias shapes do not match", - "(require: rank of input_x must be at least 2, rank of bias must be 1, " - "input_x.dim[1] must equal bias.dim[0])," - " but got input_x shape {}, bias shape {}.".format(x_shape, b_shape)) + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name) + validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, b_type): - args = {"input_x type": x_type, "bias type": b_type} - validator.check_type_same(args, (mstype.float16, mstype.float32, mstype.int8, mstype.int32)) + args = {"input_x": x_type, "bias": b_type} + valid_types = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) return x_type @@ -1215,22 +1137,21 @@ class TopK(PrimitiveWithInfer): @prim_attr_register def __init__(self, sorted=False): - validator.check_type("sorted", sorted, [bool]) + validator.check_value_type("sorted", sorted, [bool], self.name) self.init_prim_io_names(inputs=['input', 'k'], outputs=['values', 'indices']) def __infer__(self, input_x, k): + x_dtype = input_x['dtype'] + valid_types = (mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name) + k_v = k['value'] + validator.check_value_type('k', k_v, (int,), self.name) x_shape = list(input_x['shape']) ndim = len(x_shape) - 1 - k_v = k['value'] x_shape[ndim] = k_v - input_dtype = input_x['dtype'] - validator.check_typename("TopK input_dtype", - input_dtype, (mstype.float16, mstype.float32, mstype.int32)) - if not isinstance(k_v, int): - 
raise ValueError('The k must int.', k) return {'shape': (x_shape, x_shape), - 'dtype': (input_dtype, mstype.int32), + 'dtype': (x_dtype, mstype.int32), 'value': None} @@ -1260,16 +1181,14 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): pass def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SoftmaxCrossEntropyWithLogits logits_shape", logits_shape, - "SoftmaxCrossEntropyWithLogits labels_shape", labels_shape) + validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name) loss_shape = [logits_shape[0]] dlogits_shape = logits_shape return (loss_shape, dlogits_shape) def infer_dtype(self, logits_type, labels_type): - args = {"SoftmaxCrossEntropyWithLogits logits_type": logits_type, - "SoftmaxCrossEntropyWithLogits labels_type": labels_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {"logits": logits_type, "labels": labels_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return (logits_type, logits_type) @@ -1308,18 +1227,15 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): self.add_prim_attr('sens', 1.0) def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SparseSoftmaxCrossEntropyWithLogits logits_shape", logits_shape[0], - "SparseSoftmaxCrossEntropyWithLogits labels_shape", labels_shape[0]) + validator.check("logits_shape[0]", logits_shape[0], "labels_shape[0]", labels_shape[0], Rel.EQ, self.name) loss_shape = [] if self.is_grad: return logits_shape return loss_shape def infer_dtype(self, logits_type, labels_type): - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits logits_type", - logits_type, (mstype.float16, mstype.float32)) - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits labels_type", - labels_type, (mstype.int32, mstype.int64)) + validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name) return logits_type @@ -1364,14 +1280,13 @@ class ApplyMomentum(PrimitiveWithInfer): return v_shape def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): + valid_types = [mstype.float16, mstype.float32, mstype.float64] if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: - validator.check_subclass("v_dtype", v_dtype, mstype.tensor) - validator.check_subclass("a_dtype", a_dtype, mstype.tensor) - validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) return g_dtype @@ -1403,17 +1318,17 @@ class SmoothL1Loss(PrimitiveWithInfer): @prim_attr_register def __init__(self, 
sigma=1.0): - validator.check_type('sigma', sigma, [float]) - validator.check('sigma', sigma, '', 0, Rel.GT) + validator.check_value_type('sigma', sigma, [float], self.name) + validator.check('sigma', sigma, '', 0, Rel.GT, self.name) self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) def infer_shape(self, prediction, target): - validator.check_param_equal('prediction shape', prediction, 'target shape', target) + validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target): args = {"prediction": prediction, "target": target} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return prediction @@ -1446,29 +1361,30 @@ class SGD(PrimitiveWithInfer): @prim_attr_register def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False): - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.name) self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], outputs=['output']) def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): - validator.check(f'parameters shape {parameters_shape}', len(parameters_shape), '', 0, Rel.GT) - validator.check(f'gradient shape {gradient_shape}', len(gradient_shape), '', 0, Rel.GE) - validator.check(f'learning rate shape {learning_rate_shape}', len(learning_rate_shape), '', 0, Rel.GE) - validator.check(f'accumulation shape {accum_shape}', len(accum_shape), '', 0, Rel.GT) - validator.check(f'momentum shape {momentum_shape}', len(momentum_shape), '', 0, Rel.GE) - validator.check(f'stat shape {stat_shape}', len(stat_shape), '', 0, Rel.GE) - validator.check("gradient shape", gradient_shape, "stat shape", stat_shape) + validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name) + validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) + validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) + validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name) + validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) + validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) + validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) return parameters_shape def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype): - validator.check_typename("parameters_dtype", parameters_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("gradient_dtype", gradient_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("accum_dtype", accum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("momentum_dtype", momentum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) + valid_types = [mstype.float16, mstype.float32] + validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name) + 
validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name) return parameters_dtype class ApplyRMSProp(PrimitiveWithInfer): @@ -1514,28 +1430,23 @@ class ApplyRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, - "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype, - 'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "decay": decay_dtype, + 'momentum': momentum_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1587,30 +1498,25 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape) - validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, 
"mean_gradient_shape", mean_gradient_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype, - "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype, - "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype, + "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype, + "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1651,8 +1557,8 @@ class LayerNorm(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): - validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - validator.check_type('begin_params_axis', begin_params_axis, [int]) + validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) class L2Normalize(PrimitiveWithInfer): @@ -1679,16 +1585,16 @@ class L2Normalize(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x): dim = len(input_x) - validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) + validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) return input_x def infer_dtype(self, input_x): - validator.check_subclass("x", input_x, mstype.tensor) + validator.check_subclass("x", input_x, mstype.tensor, self.name) return input_x @@ -1718,8 +1624,8 @@ class DropoutGenMask(Primitive): @prim_attr_register def __init__(self, Seed0=0, Seed1=0): self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) - validator.check_type("Seed0", Seed0, [int]) - validator.check_type("Seed1", Seed1, [int]) + validator.check_value_type("Seed0", Seed0, [int], self.name) + validator.check_value_type("Seed1", Seed1, [int], self.name) class DropoutDoMask(PrimitiveWithInfer): @@ -1759,7 +1665,7 @@ class DropoutDoMask(PrimitiveWithInfer): 
input_x_shape = input_x['shape'] mask_shape = mask['shape'] keep_prob_shape = keep_prob['shape'] - validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0) + validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0, Rel.EQ, self.name) size_x = reduce(lambda x, y: x * y, input_x_shape) if len(mask_shape) != 1: raise ValueError("DropoutDoMask mask shape should be 1-dimension.") @@ -1768,13 +1674,13 @@ class DropoutDoMask(PrimitiveWithInfer): raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:" "{input_x_shape}, mask shape: {mask_shape}.") - validator.check_typename("input_x type", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32]) - validator.check_typename("input_mask type", mask['dtype'], [mstype.uint8]) + validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32], + self.name) + validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name) keep_prob_v = keep_prob['value'] if keep_prob_v is not None: - validator.check_const_input('keep_prob', keep_prob_v) - validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH) + validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) out = {'shape': input_x_shape, 'dtype': input_x['dtype'], @@ -1858,23 +1764,20 @@ class OneHot(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output']) - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def __infer__(self, indices, depth, on_value, off_value): # check type - validator.check_subclass("indices", indices['dtype'], mstype.tensor) - validator.check_typename("indices", indices['dtype'], (mstype.int32,)) - validator.check_typename("depth", depth['dtype'], mstype.int_type) - validator.check_subclass("on_value", on_value['dtype'], mstype.tensor) - validator.check_subclass("off_value", off_value['dtype'], mstype.tensor) - args = {"on_value dtype": on_value['dtype'], "off_value dtype": off_value['dtype']} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name) + validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) + args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) # check shape indices_shp = indices['shape'] - validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH) + validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) depth_val = depth['value'] - validator.check_integer("depth", depth_val, 0, Rel.GE) + validator.check_integer("depth", depth_val, 0, Rel.GE, self.name) # create new dimension at end if self.axis is -1 indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) @@ -1919,8 +1822,7 @@ class Gelu(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -1953,10 +1855,10 @@ class GetNext(PrimitiveWithInfer): @prim_attr_register def 
__init__(self, types, shapes, output_num, shared_name): - validator.check_type("types", types, [list, tuple]) - validator.check_type("shapes", shapes, [list, tuple]) - validator.check("types length", len(types), "shapes length", len(shapes)) - validator.check_type("output_num", output_num, [int]) + validator.check_value_type("types", types, [list, tuple], self.name) + validator.check_value_type("shapes", shapes, [list, tuple], self.name) + validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name) + validator.check_value_type("output_num", output_num, [int], self.name) def infer_shape(self): return tuple(self.shapes) @@ -1997,24 +1899,22 @@ class PReLU(PrimitiveWithInfer): weight_dim = len(weight_shape) if weight_dim != 1: - raise ValueError(f'weight_dim must be 1, while weight_dim is {weight_dim}.') + raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.') if input_x_dim == 1 and weight_shape[0] != 1: - raise ValueError(f'when input_x_dim is 1, weight_shape[0] must be 1, ' + raise ValueError(f'For \'{self.name}\' when input_x_dim is 1, weight_shape[0] must be 1, ' f'while weight_shape[0] is {weight_shape[0]}.') if input_x_dim != 1 and weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1: - raise ValueError(f'channel of input_x and weight must be matched,' + raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,' f' while channel of input_x is {input_x_shape[1]},' f' weight_shape[0] is {weight_shape[0]}.') return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - validator.check_subclass("input_x_dtype", input_x_dtype, mstype.tensor) - validator.check_subclass("weight_dtype", weight_dtype, mstype.tensor) - validator.check_typename("input_x_dtype", input_x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("weight_dtype", weight_dtype, (mstype.float16, mstype.float32)) + args = {"input_x": input_x_dtype, "weight": weight_dtype} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return input_x_dtype @@ -2027,13 +1927,13 @@ class LSTM(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -2042,19 +1942,16 @@ class LSTM(PrimitiveWithInfer): def infer_shape(self, x_shape, h_shape, c_shape, w_shape): # (batch, seq, feature) - validator.check_integer("x_shape", 
len(x_shape), 3, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) # h and c should be same shape - validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(h_shape), len(c_shape), Rel.EQ) - validator.check_integer("h_shape", h_shape[0], c_shape[0], Rel.EQ) - validator.check_integer("h_shape", h_shape[1], c_shape[1], Rel.EQ) - validator.check_integer("h_shape", h_shape[2], c_shape[2], Rel.EQ) + validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) + validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) # (num_layers * num_directions, batch, hidden_size) - validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ) - validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name) + validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name) y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) @@ -2064,13 +1961,8 @@ class LSTM(PrimitiveWithInfer): return (y_shape, h_shape, c_shape, reserved_shape, state_shape) def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): - validator.check_typename("x_dtype", x_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("h_dtype", h_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("c_dtype", c_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("w_dtype", w_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", x_dtype, (h_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (c_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (w_dtype.element_type(),)) + args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) @@ -2101,12 +1993,12 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer): self.init_prim_io_names(inputs=['predict', 'target'], outputs=['loss']) def infer_shape(self, x_shape, y_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -2150,7 +2042,7 @@ class Pad(PrimitiveWithInfer): def infer_shape(self, x): paddings = np.array(self.paddings) - validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name) if not np.all(paddings >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2159,7 +2051,7 @@ class Pad(PrimitiveWithInfer): return y_shape def infer_dtype(self, x): - validator.check_subclass("input_x", x, mstype.tensor) + validator.check_subclass("input_x", x, mstype.tensor, self.name) return x @@ -2210,16 +2102,16 @@ class MirrorPad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode='REFLECT'): """Init Pad""" - validator.check_string('mode', mode, ['REFLECT', 
'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, input_x, paddings): - validator.check_subclass("input_x", input_x['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) + validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) x_shape = list(input_x['shape']) paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size - validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) if not np.all(paddings_size >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2270,10 +2162,10 @@ class ROIAlign(PrimitiveWithInfer): @prim_attr_register def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlign""" - validator.check_type("pooled_height", pooled_height, [int]) - validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) self.pooled_height = pooled_height self.pooled_width = pooled_width self.spatial_scale = spatial_scale @@ -2338,24 +2230,24 @@ class Adam(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False, use_nesterov=False): - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape): - validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape) - validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape, m_shape, v_shape def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): - args = {"var_dtype": var_dtype, "m_dtype": m_dtype, "v_dtype": v_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) + args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) - args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype, - "beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"beta1_power": beta1_power_dtype, "beta2_power": 
beta2_power_dtype, 'lr': lr_dtype, + "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) return var_dtype, m_dtype, v_dtype @@ -2397,12 +2289,12 @@ class BinaryCrossEntropy(PrimitiveWithInfer): @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) if self.reduction in ('mean', 'sum'): shape = [] else: @@ -2410,10 +2302,11 @@ class BinaryCrossEntropy(PrimitiveWithInfer): return shape def infer_dtype(self, x_type, y_type, weight_type): - args = {'x_type': x_type, 'y_type': y_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {'x': x_type, 'y': y_type} + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name) return x_type @@ -2445,27 +2338,22 @@ class SparseApplyAdagrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, lr, use_locking=False): - self.lr = validator.check_type("lr", lr, [float]) - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.lr = validator.check_value_type("lr", lr, [float], self.name) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): - validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) - validator.check_param_equal('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape)) + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) if len(var_shape) > 1: - validator.check_param_equal('var_shape', var_shape[1:], 'grad_shape', grad_shape[1:]) - validator.check_integer("len of indices shape", len(indices_shape), 1, Rel.EQ) - validator.check('the first dimension of grad', grad_shape[0], - 'the shape of indices', indices_shape[0], Rel.EQ) + validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape def infer_dtype(self, var_type, accum_type, grad_type, indices_type): - validator.check_subclass("var_type", var_type, mstype.tensor) - validator.check_subclass("accum_type", accum_type, mstype.tensor) - validator.check_subclass("grad_type", grad_type, mstype.tensor) - validator.check_subclass("indices_type", indices_type, mstype.tensor) - args = {'var_type': var_type, 'accum_type': accum_type, 'grad_type': grad_type} - 
validator.check_type_same(args, (mstype.float32,)) - validator.check_typename('indices_type', indices_type, [mstype.int32]) + args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} + validator.check_tensor_type_same(args, (mstype.float32,), self.name) + validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) return var_type @@ -2493,34 +2381,34 @@ class LARSUpdate(PrimitiveWithInfer): @prim_attr_register def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False): """init""" - validator.check_type("epsilon", epsilon, [float]) - validator.check_type("hyperpara", hyperpara, [float]) - validator.check_type("use_clip", use_clip, [bool]) + validator.check_value_type("epsilon", epsilon, [float], self.name) + validator.check_value_type("hyperpara", hyperpara, [float], self.name) + validator.check_value_type("use_clip", use_clip, [bool], self.name) def infer_shape(self, weight_shape, gradient_shape, norm_weight_shape, norm_gradient_shape, weight_decay_shape, learning_rate_shape): - validator.check_param_equal("Weight shape", weight_shape, "gradient shape", gradient_shape) - validator.check_param_equal("Norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape) + validator.check("weight shape", weight_shape, "gradient shape", gradient_shape, Rel.EQ, self.name) + validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, + self.name) shp_len = len(weight_decay_shape) - validator.check_shape_length("Weight decay's shape", shp_len, 1, Rel.LE) + validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Weight decay's shape", weight_decay_shape[0], 1, Rel.EQ) + validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name) shp_len = len(learning_rate_shape) - validator.check_shape_length("Learning rate's shape", shp_len, 1, Rel.LE) + validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Learning rate's shape", learning_rate_shape[0], 1, Rel.EQ) + validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name) return weight_shape def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, weight_decay_dtype, learning_rate_dtype): args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype, "norm gradient dtype": norm_gradient_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32]) - validator.check_args_tensor(args) - validator.check_typename("weight_decay_dtype", weight_decay_dtype, - [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, - [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name) + validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) + validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) return weight_dtype @@ -2553,26 +2441,23 @@ class ApplyFtrl(PrimitiveWithInfer): def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 
'lr_power'], outputs=['output'])
-        self.use_locking = validator.check_type("use_locking", use_locking, [bool])
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
 
     def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                     lr_power_shape):
-        validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape)
-        validator.check_param_equal('var shape', var_shape, 'linear shape', linear_shape)
+        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
+        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
         return var_shape
 
     def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
-        validator.check_subclass("var_type", var_type, mstype.tensor)
-        validator.check_subclass("accum_type", accum_type, mstype.tensor)
-        validator.check_subclass("linear_type", linear_type, mstype.tensor)
-        validator.check_subclass("grad_type", grad_type, mstype.tensor)
-        args = {'var_type': var_type, 'accum_type': accum_type, 'linear_type': linear_type, 'grad_type': grad_type}
-        validator.check_type_same(args, (mstype.float32, mstype.float16))
-
-        validator.check_typename("lr", lr_type, [mstype.float16, mstype.float32])
-        validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32])
-        validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32])
-        validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32])
+        valid_types = [mstype.float16, mstype.float32]
+        args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
+        validator.check_tensor_type_same(args, valid_types, self.name)
+
+        validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name)
+        validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
+        validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
+        validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
         return var_type
 
 
@@ -2607,36 +2492,22 @@ class ExtractImagePatches(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, ksizes, strides, rates, padding="valid"):
         """init"""
-        validator.check_type("ksizes", ksizes, [tuple, list])
-        validator.check_type("strides", strides, [tuple, list])
-        validator.check_type("rates", rates, [tuple, list])
-        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
+        def _check_tuple_or_list(arg_name, arg_val, prim_name):
+            validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], prim_name)
+            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
+                raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
+                                 f"{arg_name}_col, 1], but got {arg_val}.")
+            if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
+                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s must be "
+                                 f"positive integers, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
+                                 f"is {arg_val[2]}")
+
+        _check_tuple_or_list("ksize", ksizes, self.name)
+        _check_tuple_or_list("stride", strides, self.name)
+        _check_tuple_or_list("rate", rates, self.name)
+        self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
        self.add_prim_attr("padding", self.padding)
 
-        if len(ksizes) != 4 or
ksizes[0] != 1 or ksizes[3] != 1: - raise ValueError("The format of ksizes should be [1, ksize_row, ksize_col, 1], " - f"but got {ksizes}.") - if not isinstance(ksizes[1], int) or not isinstance(ksizes[2], int) or \ - ksizes[1] < 1 or ksizes[2] < 1: - raise ValueError("The ksize_row and ksize_col in ksizes should be an positive integer number, " - f"but got ksize_row is {ksizes[1]}, ksize_col is {ksizes[2]}") - - if len(strides) != 4 or strides[0] != 1 or strides[3] != 1: - raise ValueError("The format of strides should be [1, stride_row, stride_col, 1], " - f"but got {strides}.") - if not isinstance(strides[1], int) or not isinstance(strides[2], int) or \ - strides[1] < 1 or strides[2] < 1: - raise ValueError("The stride_row and stride_col in strides should be an positive integer number, " - f"but got stride_row is {strides[1]}, stride_col is {strides[2]}") - - if len(rates) != 4 or rates[0] != 1 or rates[3] != 1: - raise ValueError("The format of rates should be [1, rate_row, rate_col, 1], " - f"but got {rates}.") - if not isinstance(rates[1], int) or not isinstance(rates[2], int) or \ - rates[1] < 1 or rates[2] < 1: - raise ValueError("The rate_row and rate_col in rates should be an positive integer number, " - f"but got rate_row is {rates[1]}, rate_col is {rates[2]}") - def infer_shape(self, input_x): in_batch, in_row, in_col, in_depth = input_x _, ksize_row, ksize_col, _ = self.ksizes @@ -2662,6 +2533,5 @@ class ExtractImagePatches(PrimitiveWithInfer): return out_shape def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, (mstype.int8, mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) return input_x diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index 96f9d5afde..8d03be1766 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -41,7 +41,7 @@ class TestInputs: dr.piecewise_constant_lr(milestone1, learning_rates) milestone2 = [1.0, 2.0, True] - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone2, learning_rates) def test_learning_rates1(self): @@ -92,13 +92,13 @@ class TestInputs: def test_total_step1(self): total_step1 = 2.0 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power) def test_total_step2(self): @@ -114,13 +114,13 @@ class TestInputs: def test_step_per_epoch1(self): step_per_epoch1 = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power) def test_step_per_epoch2(self): @@ -136,13 +136,13 @@ class TestInputs: def 
test_decay_epoch1(self): decay_epoch1 = 'm' - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power) def test_decay_epoch2(self): diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index cf946a1617..77d065b100 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -60,7 +60,7 @@ def test_ssim_max_val_zero(): net = SSIMNet(max_val) def test_ssim_filter_size_float(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = SSIMNet(filter_size=1.1) def test_ssim_filter_size_zero(): diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 09a4248c19..ab6f31095d 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -516,7 +516,7 @@ test_cases = [ test_cases_for_verify_exception = [ ('Conv2d_ValueError_1', { - 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), + 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('Conv2d_ValueError_2', { @@ -528,7 +528,7 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_2', { - 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_3', { @@ -540,7 +540,7 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_1', { - 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_2', { @@ -560,31 +560,31 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('Softmax_ValueError_1', { - 'block': (lambda _: P.Softmax("1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax("1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('Softmax_ValueError_2', { - 'block': (lambda _: P.Softmax(1.1), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(1.1), {'exception': TypeError}), 'desc_inputs': [0], }), ('Softmax_ValueError_3', { - 'block': (lambda _: P.Softmax(axis="1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(axis="1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_1', { - 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_2', { - 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_3', { - 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_4', { - 'block': 
(lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('MaxPool2d_ValueError_1', { diff --git a/tests/ut/python/ops/test_nn_ops_check.py b/tests/ut/python/ops/test_nn_ops_check.py new file mode 100755 index 0000000000..c2a751aa0c --- /dev/null +++ b/tests/ut/python/ops/test_nn_ops_check.py @@ -0,0 +1,463 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test ops """ +import functools +import numpy as np +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from ..ut_filter import non_graph_engine +from mindspore.common.api import _executor + +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward\ + import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, + pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ + import pipeline_for_compile_grad_ge_graph_for_case_by_case_config + + +class Conv2DBackpropInputNet(nn.Cell): + def __init__(self, net, x_shape): + super(Conv2DBackpropInputNet, self).__init__() + self.net = net + self.x_shape = x_shape + + def construct(self, dout, w): + return self.net(dout, w, self.x_shape) + + +class TopKNet(nn.Cell): + def __init__(self, net, k): + super(TopKNet, self).__init__() + self.net = net + self.k = k + + def construct(self, x): + return self.net(x, self.k) + + +raise_set = [ + # input is scalar + ('Flatten0', { + 'block': (P.Flatten(), {'exception': TypeError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # dim of input is zero + ('Flatten1', { + 'block': (P.Flatten(), {'exception': ValueError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [F.scalar_to_tensor(5.0)], + 'skip': ['backward']}), + + # input is scalar + ('Softmax0', { + 'block': (P.Softmax(), {'exception': TypeError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # axis is empty tuple + ('Softmax1', { + 'block': (P.Softmax(axis=()), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + # axis value is not in range + ('Softmax2', { + 'block': (P.Softmax(axis=2), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('LogSoftmax0', { + 'block': (P.LogSoftmax(), {'exception': 
TypeError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # axis value is not in range + ('LogSoftmax1', { + 'block': (P.LogSoftmax(axis=2), {'exception': ValueError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU0', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(Bool) + ('ReLU1', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU60', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('ReLU61', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Elu0', { + 'block': (P.Elu(), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Elu1', { + 'block': (P.Elu(alpha=0.9), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Sigmoid0', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Sigmoid1', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Tanh0', { + 'block': (P.Tanh(), {'exception': TypeError, 'error_keywords': ['Tanh']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + + # input is scalar + ('BatchNorm0', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [5.0, 5.0, 5.0, 5.0, 5.0], + 'skip': ['backward']}), + # is_training=False and mean=None + ('BatchNorm1', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), None, None], + 'skip': ['backward']}), + # is_training=True and mean=None + ('BatchNorm2', { + 'block': (P.BatchNorm(is_training=True), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float16)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias rank > 1 + ('BatchNorm3', { + 'block': (P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias shape not match + ('BatchNorm4', { + 'block': (P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 
'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([7]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and variance shape not match + ('BatchNorm5', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and scale shape not match + ('BatchNorm6', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2D0', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2D1', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('Conv2D2', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('Conv2D3', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('Conv2D4', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] / group != w_shape[1] + ('Conv2D5', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # out_channel != w_shape[0] + ('Conv2D6', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('Conv2D7', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('DepthwiseConv2dNative0', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('DepthwiseConv2dNative1', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': 
TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('DepthwiseConv2dNative2', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('DepthwiseConv2dNative3', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('DepthwiseConv2dNative4', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] != w_shape[1] + ('DepthwiseConv2dNative5', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('DepthwiseConv2dNative6', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPoolWithArgmax0', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('MaxPoolWithArgmax1', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPoolWithArgmax2', { + 'block': (P.MaxPoolWithArgmax(), {'exception': ValueError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # kernel size is invalid(very large) + ('MaxPoolWithArgmax3', { + 'block': (P.MaxPoolWithArgmax(ksize=50), {'exception': ValueError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPool0', { + 'block': (P.MaxPool(), {'exception': TypeError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool1', { + 'block': (P.MaxPool(), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool2', { + 'block': (P.MaxPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('AvgPool0', { + 'block': (P.AvgPool(), {'exception': TypeError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + 
('AvgPool1', { + 'block': (P.AvgPool(), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('AvgPool2', { + 'block': (P.AvgPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2DBackpropInput0', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2DBackpropInput1', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of doutput and w mismatch + ('Conv2DBackpropInput2', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple + ('Conv2DBackpropInput3', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), 2), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple(int,...) + ('Conv2DBackpropInput4', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2, 3.0)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('BiasAdd0', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('BiasAdd1', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of x and bias mismatch + ('BiasAdd2', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x less than 2 + ('BiasAdd3', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of bias is not equal to 1 + ('BiasAdd4', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32))], + 'skip': ['backward']}), + # b_shape[0] != x_shape[1] + ('BiasAdd5', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input x is scalar + ('TopK0', { + 'block': 
(TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input x is Tensor(bool) + ('TopK1', { + 'block': (TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.bool_))], + 'skip': ['backward']}), + # k is not integer + ('TopK2', { + 'block': (TopKNet(P.TopK(), 5.0), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SoftmaxCrossEntropyWithLogits0', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('SoftmaxCrossEntropyWithLogits1', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits2', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float16)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # shapes of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits3', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SparseSoftmaxCrossEntropyWithLogits0', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # logits is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits1', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # labels is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits2', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # logits_shape[0] != labels_shape[0] + ('SparseSoftmaxCrossEntropyWithLogits3', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.int32))], + 'skip': ['backward']}), +] + + +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +def test_check_exception(): + return raise_set
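
Note (editorial illustration, not part of the patch): the behaviour pinned down by these tests is that a wrongly typed attribute or input now raises TypeError, an out-of-range value keeps raising ValueError, and the message is prefixed with the primitive name. Below is a minimal pytest sketch of that contract. It assumes MindSpore with this change applied; the negative-ksize case additionally assumes MaxPoolWithArgmax routes its ksize attribute through the new positive-int/tuple check, so a well-typed but non-positive value still fails with ValueError.

import pytest
from mindspore.ops import operations as P


def test_wrong_attr_type_now_raises_type_error():
    # axis="1" is the wrong type, so construction fails with TypeError
    # (before this patch the same call raised ValueError); the message is
    # expected to name the primitive, e.g. "For 'Softmax' ...".
    with pytest.raises(TypeError):
        P.Softmax(axis="1")


def test_out_of_range_attr_still_raises_value_error():
    # ksize=-2 has the right type but an invalid value, so ValueError is kept
    # (assumption: ksize is validated as a positive int or tuple of positive ints).
    with pytest.raises(ValueError):
        P.MaxPoolWithArgmax(ksize=-2)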