| @@ -15,6 +15,7 @@ | |||||
| """Check parameters.""" | """Check parameters.""" | ||||
| import re | import re | ||||
| from enum import Enum | from enum import Enum | ||||
| from functools import reduce | |||||
| from itertools import repeat | from itertools import repeat | ||||
| from collections import Iterable | from collections import Iterable | ||||
| @@ -93,8 +94,131 @@ rel_strs = { | |||||
| } | } | ||||
| class Validator: | |||||
| """validator for checking input parameters""" | |||||
| @staticmethod | |||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): | |||||
| """ | |||||
| Method for judging relation between two int values or list/tuple made up of ints. | |||||
| This method is not suitable for judging relation between floats, since it does not consider float error. | |||||
| """ | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||||
| msg_prefix = f'For {prim_name} the' if prim_name else "The" | |||||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | |||||
| """Integer value judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' | |||||
| f' but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||||
| """Method for checking whether an int value is in some range.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) | |||||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||||
| f' but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_subclass(arg_name, type_, template_type, prim_name): | |||||
| """Check whether some type is sublcass of another type""" | |||||
| if not isinstance(template_type, Iterable): | |||||
| template_type = (template_type,) | |||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | |||||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||||
| @staticmethod | |||||
| def check_tensor_type_same(args, valid_values, prim_name): | |||||
| """check whether the element types of input tensors are the same.""" | |||||
| def _check_tensor_type(arg): | |||||
| arg_key, arg_val = arg | |||||
| Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) | |||||
| elem_type = arg_val.element_type() | |||||
| if not elem_type in valid_values: | |||||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' | |||||
| f' but `{arg_key}` is {elem_type}.') | |||||
| return (arg_key, elem_type) | |||||
| def _check_types_same(arg1, arg2): | |||||
| arg1_name, arg1_type = arg1 | |||||
| arg2_name, arg2_type = arg2 | |||||
| if arg1_type != arg2_type: | |||||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' | |||||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||||
| return arg1 | |||||
| elem_types = map(_check_tensor_type, args.items()) | |||||
| reduce(_check_types_same, elem_types) | |||||
| @staticmethod | |||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name): | |||||
| """check whether the types of inputs are the same. if the input args are tensors, check their element types""" | |||||
| def _check_argument_type(arg): | |||||
| arg_key, arg_val = arg | |||||
| if isinstance(arg_val, type(mstype.tensor)): | |||||
| arg_val = arg_val.element_type() | |||||
| if not arg_val in valid_values: | |||||
| raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},' | |||||
| f' but `{arg_key}` is {arg_val}.') | |||||
| return arg | |||||
| def _check_types_same(arg1, arg2): | |||||
| arg1_name, arg1_type = arg1 | |||||
| arg2_name, arg2_type = arg2 | |||||
| excp_flag = False | |||||
| if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): | |||||
| arg1_type = arg1_type.element_type() | |||||
| arg2_type = arg2_type.element_type() | |||||
| elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): | |||||
| pass | |||||
| else: | |||||
| excp_flag = True | |||||
| if excp_flag or arg1_type != arg2_type: | |||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||||
| return arg1 | |||||
| reduce(_check_types_same, map(_check_argument_type, args.items())) | |||||
| @staticmethod | |||||
| def check_value_type(arg_name, arg_value, valid_types, prim_name): | |||||
| """Check whether a values is instance of some types.""" | |||||
| def raise_error_msg(): | |||||
| """func for raising error message when check failed""" | |||||
| type_names = [t.__name__ for t in valid_types] | |||||
| num_types = len(valid_types) | |||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' | |||||
| f'{"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | |||||
| # `check_value_type('x', True, [bool, int])` will check pass | |||||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||||
| raise_error_msg() | |||||
| if isinstance(arg_value, tuple(valid_types)): | |||||
| return arg_value | |||||
| raise_error_msg() | |||||
| class ParamValidator: | class ParamValidator: | ||||
| """Parameter validator.""" | |||||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||||
| @staticmethod | @staticmethod | ||||
| def equal(arg_name, arg_value, cond_str, cond): | def equal(arg_name, arg_value, cond_str, cond): | ||||
| @@ -16,13 +16,14 @@ | |||||
| """broadcast""" | """broadcast""" | ||||
| def _get_broadcast_shape(x_shape, y_shape): | |||||
| def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||||
| """ | """ | ||||
| Doing broadcast between tensor x and tensor y. | Doing broadcast between tensor x and tensor y. | ||||
| Args: | Args: | ||||
| x_shape (list): The shape of tensor x. | x_shape (list): The shape of tensor x. | ||||
| y_shape (list): The shape of tensor y. | y_shape (list): The shape of tensor y. | ||||
| prim_name (str): Primitive name. | |||||
| Returns: | Returns: | ||||
| List, the shape that broadcast between tensor x and tensor y. | List, the shape that broadcast between tensor x and tensor y. | ||||
| @@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape): | |||||
| elif x_shape[i] == y_shape[i]: | elif x_shape[i] == y_shape[i]: | ||||
| broadcast_shape_back.append(x_shape[i]) | broadcast_shape_back.append(x_shape[i]) | ||||
| else: | else: | ||||
| raise ValueError("The x_shape {} and y_shape {} can not broadcast.".format(x_shape, y_shape)) | |||||
| raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( | |||||
| prim_name, x_shape, y_shape)) | |||||
| broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | ||||
| broadcast_shape = broadcast_shape_front + broadcast_shape_back | broadcast_shape = broadcast_shape_front + broadcast_shape_back | ||||
| @@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ..operations.math_ops import _check_infer_attr_reduce, _infer_shape_reduce | |||||
| from ..operations.math_ops import _infer_shape_reduce | |||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | ||||
| def _check_infer_attr_reduce(axis, keep_dims): | |||||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| validator.check_type('axis', axis, [int, tuple]) | |||||
| if isinstance(axis, tuple): | |||||
| for index, value in enumerate(axis): | |||||
| validator.check_type('axis[%d]' % index, value, [int]) | |||||
| class ExpandDims(PrimitiveWithInfer): | class ExpandDims(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -1090,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): | |||||
| axis = self.axis | axis = self.axis | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | ||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims) | |||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) | |||||
| return ouput_shape, ouput_shape | return ouput_shape, ouput_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| @@ -1136,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||||
| axis = self.axis | axis = self.axis | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | ||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims) | |||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) | |||||
| return ouput_shape, ouput_shape | return ouput_shape, ouput_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| @@ -19,7 +19,7 @@ import numpy as np | |||||
| from ..._c_expression import signature_rw as sig_rw | from ..._c_expression import signature_rw as sig_rw | ||||
| from ..._c_expression import signature_kind as sig_kind | from ..._c_expression import signature_kind as sig_kind | ||||
| from ..._c_expression import signature_dtype as sig_dtype | from ..._c_expression import signature_dtype as sig_dtype | ||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| @@ -27,16 +27,16 @@ from .._utils import _get_broadcast_shape | |||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | from ..primitive import PrimitiveWithInfer, prim_attr_register | ||||
| def _infer_shape_reduce(x, axis, keep_dims): | |||||
| def _infer_shape_reduce(x, axis, keep_dims, prim_name): | |||||
| """Common infer for reduce operator""" | """Common infer for reduce operator""" | ||||
| def reduce_one_axis(one_axis): | def reduce_one_axis(one_axis): | ||||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT) | |||||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name) | |||||
| if one_axis < 0: | if one_axis < 0: | ||||
| one_axis += dim | one_axis += dim | ||||
| axis_reduce.add(one_axis) | axis_reduce.add(one_axis) | ||||
| validator.check_type('axis', axis, [int, tuple, list]) | |||||
| validator.check_value_type('axis', axis, [int, tuple, list], prim_name) | |||||
| dim = len(x) | dim = len(x) | ||||
| axis_reduce = set() | axis_reduce = set() | ||||
| @@ -48,7 +48,7 @@ def _infer_shape_reduce(x, axis, keep_dims): | |||||
| return [1] * dim | return [1] * dim | ||||
| return [] | return [] | ||||
| for index, one_axis in enumerate(axis): | for index, one_axis in enumerate(axis): | ||||
| validator.check_type('axis[%d]' % index, one_axis, [int]) | |||||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], prim_name) | |||||
| reduce_one_axis(one_axis) | reduce_one_axis(one_axis) | ||||
| out_shape = [] | out_shape = [] | ||||
| @@ -61,14 +61,6 @@ def _infer_shape_reduce(x, axis, keep_dims): | |||||
| return out_shape | return out_shape | ||||
| def _check_infer_attr_reduce(axis, keep_dims): | |||||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| validator.check_type('axis', axis, [int, tuple]) | |||||
| if isinstance(axis, tuple): | |||||
| for index, value in enumerate(axis): | |||||
| validator.check_type('axis[%d]' % index, value, [int]) | |||||
| class _BinaryOp(PrimitiveWithInfer): | class _BinaryOp(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| Define binary operators. | Define binary operators. | ||||
| @@ -82,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | ||||
| def infer_shape(self, x_shape, y_shape): | def infer_shape(self, x_shape, y_shape): | ||||
| return _get_broadcast_shape(x_shape, y_shape) | |||||
| return _get_broadcast_shape(x_shape, y_shape, self.prim_name()) | |||||
| class _MathBinaryOp(_BinaryOp): | class _MathBinaryOp(_BinaryOp): | ||||
| @@ -91,15 +83,13 @@ class _MathBinaryOp(_BinaryOp): | |||||
| """ | """ | ||||
| @staticmethod | @staticmethod | ||||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type): | |||||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): | |||||
| args_type = {"x": x_dtype, "y": y_dtype} | args_type = {"x": x_dtype, "y": y_dtype} | ||||
| validator.check_args_tensor(args_type) | |||||
| args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype} | |||||
| validator.check_type_same(args_dtype, valid_dtype) | |||||
| validator.check_tensor_type_same(args_type, valid_dtype, prim_name) | |||||
| return x_dtype | return x_dtype | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype) | |||||
| return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name()) | |||||
| class TensorAdd(_MathBinaryOp): | class TensorAdd(_MathBinaryOp): | ||||
| @@ -167,7 +157,7 @@ class AssignAdd(PrimitiveWithInfer): | |||||
| def infer_dtype(self, variable, value): | def infer_dtype(self, variable, value): | ||||
| args = {"value": value} | args = {"value": value} | ||||
| validator.check_type_same(args, mstype.number_type) | |||||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) | |||||
| return value | return value | ||||
| @@ -208,7 +198,7 @@ class AssignSub(PrimitiveWithInfer): | |||||
| def infer_dtype(self, variable, value): | def infer_dtype(self, variable, value): | ||||
| args = {"value": value} | args = {"value": value} | ||||
| validator.check_type_same(args, mstype.number_type) | |||||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) | |||||
| return value | return value | ||||
| @@ -229,15 +219,16 @@ class _Reduce(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, keep_dims=False): | def __init__(self, keep_dims=False): | ||||
| """init Reduce""" | """init Reduce""" | ||||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||||
| validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name()) | |||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | ||||
| def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): | def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): | ||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| input_shp = input_x['shape'] | input_shp = input_x['shape'] | ||||
| validator.check_subclass('input_x', input_x['dtype'], mstype.tensor) | |||||
| validator.check_typename('input_x', input_x['dtype'], valid_dtype) | |||||
| input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims) | |||||
| args = {'input_x': input_x['dtype']} | |||||
| validator.check_tensor_type_same(args, valid_dtype, self.prim_name()) | |||||
| input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name()) | |||||
| return {'shape': input_shp, | return {'shape': input_shp, | ||||
| 'dtype': input_x['dtype'], | 'dtype': input_x['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| @@ -472,16 +463,17 @@ class CumProd(PrimitiveWithInfer): | |||||
| """ | """ | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, exclusive=False, reverse=False): | def __init__(self, exclusive=False, reverse=False): | ||||
| self.exclusive = validator.check_type("exclusive", exclusive, [bool]) | |||||
| self.reverse = validator.check_type("reverse", reverse, [bool]) | |||||
| cls_name = self.prim_name() | |||||
| self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) | |||||
| self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) | |||||
| def infer_shape(self, x_shape, axis_shape): | def infer_shape(self, x_shape, axis_shape): | ||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, axis_type): | def infer_dtype(self, x_type, axis_type): | ||||
| validator.check_subclass('x_type', x_type, mstype.tensor) | |||||
| validator.check_typename('x_type', x_type, mstype.number_type) | |||||
| validator.check_subclass("axis_type", axis_type, mstype.int_) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) | |||||
| validator.check_subclass("axis", axis_type, mstype.int_, cls_name) | |||||
| return x_type | return x_type | ||||
| @@ -515,8 +507,9 @@ class MatMul(PrimitiveWithInfer): | |||||
| def __init__(self, transpose_a=False, transpose_b=False): | def __init__(self, transpose_a=False, transpose_b=False): | ||||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | self.__setattr_flag__ = True | ||||
| validator.check_type("transpose_a", transpose_a, [bool]) | |||||
| validator.check_type("transpose_b", transpose_b, [bool]) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||||
| def check_shape_size(self, x, y): | def check_shape_size(self, x, y): | ||||
| if len(x) != 2 or len(y) != 2: | if len(x) != 2 or len(y) != 2: | ||||
| @@ -525,11 +518,11 @@ class MatMul(PrimitiveWithInfer): | |||||
| def infer_shape(self, x, y): | def infer_shape(self, x, y): | ||||
| self.check_shape_size(x, y) | self.check_shape_size(x, y) | ||||
| cls_name = self.__class__.__name__ | |||||
| cls_name = self.prim_name() | |||||
| # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two | # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two | ||||
| for i in range(len(x) - 2): | for i in range(len(x) - 2): | ||||
| if x[i] != y[i]: | if x[i] != y[i]: | ||||
| raise ValueError(f'{cls_name} shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}') | |||||
| raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}') | |||||
| # validate whether last two dims satifing matrix multiply | # validate whether last two dims satifing matrix multiply | ||||
| x_last = x[-2:] | x_last = x[-2:] | ||||
| @@ -538,8 +531,8 @@ class MatMul(PrimitiveWithInfer): | |||||
| x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0] | x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0] | ||||
| y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1] | y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1] | ||||
| if x_col != y_row: | if x_col != y_row: | ||||
| raise ValueError(f'{cls_name} evaluator shapes of inputs can not do this operator, got {x_col} and {y_row}' | |||||
| + f' for {cls_name}, with x shape {x}(transpose_a={self.transpose_a})' | |||||
| raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,' | |||||
| + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})' | |||||
| + f', y shape {y}(transpose_b={self.transpose_b}).') | + f', y shape {y}(transpose_b={self.transpose_b}).') | ||||
| # set attribute | # set attribute | ||||
| self.add_prim_attr('transpose_x1', self.transpose_a) | self.add_prim_attr('transpose_x1', self.transpose_a) | ||||
| @@ -549,10 +542,8 @@ class MatMul(PrimitiveWithInfer): | |||||
| return ret_dims | return ret_dims | ||||
| def infer_dtype(self, x, y): | def infer_dtype(self, x, y): | ||||
| validator.check_subclass("x", x, mstype.tensor) | |||||
| validator.check_subclass("y", y, mstype.tensor) | |||||
| args = {"x dtype": x, "y dtype": y} | |||||
| validator.check_type_same(args, mstype.float_type + mstype.int_type) | |||||
| args = {"x": x, "y": y} | |||||
| validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -596,12 +587,13 @@ class BatchMatMul(MatMul): | |||||
| def __init__(self, transpose_a=False, transpose_b=False): | def __init__(self, transpose_a=False, transpose_b=False): | ||||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | ||||
| self.__setattr_flag__ = True | self.__setattr_flag__ = True | ||||
| validator.check_type("transpose_a", transpose_a, [bool]) | |||||
| validator.check_type("transpose_b", transpose_b, [bool]) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||||
| def check_shape_size(self, x, y): | def check_shape_size(self, x, y): | ||||
| if len(x) != len(y) or len(x) < 3: | if len(x) != len(y) or len(x) < 3: | ||||
| raise ValueError('BatchMatMul input x, y should be the same dimension size and should be ' | |||||
| raise ValueError('For \'BatchMatMul\' input x, y should be the same dimension size and should be ' | |||||
| 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') | 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') | ||||
| @@ -633,18 +625,17 @@ class CumSum(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, exclusive=False, reverse=False): | def __init__(self, exclusive=False, reverse=False): | ||||
| """init cumsum""" | """init cumsum""" | ||||
| self.exclusive = validator.check_type('exclusive', exclusive, [bool]) | |||||
| self.add_prim_attr("exclusive", self.exclusive) | |||||
| self.reverse = validator.check_type('reverse', reverse, [bool]) | |||||
| self.add_prim_attr("reverse", self.reverse) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_value_type('exclusive', exclusive, [bool], cls_name) | |||||
| validator.check_value_type('reverse', reverse, [bool], cls_name) | |||||
| self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) | ||||
| def __infer__(self, x, axis): | def __infer__(self, x, axis): | ||||
| cls_name = self.prim_name() | |||||
| x_shp = x['shape'] | x_shp = x['shape'] | ||||
| validator.check_type('axis', axis['value'], [int]) | |||||
| validator.check_subclass('x', x['dtype'], mstype.tensor) | |||||
| validator.check_typename('x', x['dtype'], [mstype.uint8, mstype.int8, | |||||
| mstype.int32, mstype.float16, mstype.float32]) | |||||
| validator.check_value_type('axis', axis['value'], [int], cls_name) | |||||
| valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | |||||
| return {'shape': x_shp, | return {'shape': x_shp, | ||||
| 'dtype': x['dtype'], | 'dtype': x['dtype'], | ||||
| 'value': None} | 'value': None} | ||||
| @@ -685,21 +676,22 @@ class AddN(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | ||||
| def infer_shape(self, inputs): | def infer_shape(self, inputs): | ||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||||
| self.add_prim_attr('n', len(inputs)) | self.add_prim_attr('n', len(inputs)) | ||||
| shp0 = inputs[0] | shp0 = inputs[0] | ||||
| for i, shp in enumerate(inputs): | for i, shp in enumerate(inputs): | ||||
| validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0) | |||||
| validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0, Rel.EQ, cls_name) | |||||
| return shp0 | return shp0 | ||||
| def infer_dtype(self, inputs): | def infer_dtype(self, inputs): | ||||
| validator.check_type("inputs", inputs, [tuple, list]) | |||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_value_type("inputs", inputs, [tuple, list], cls_name) | |||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||||
| args = {} | args = {} | ||||
| for i, dtype in enumerate(inputs): | for i, dtype in enumerate(inputs): | ||||
| validator.check_subclass(f"inputs[{i}]", dtype, mstype.tensor) | |||||
| args[f"inputs[{i}]"] = dtype | args[f"inputs[{i}]"] = dtype | ||||
| validator.check_type_same(args, mstype.number_type + (mstype.bool_,)) | |||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||||
| return inputs[0] | return inputs[0] | ||||
| @@ -723,8 +715,7 @@ class Neg(PrimitiveWithInfer): | |||||
| return input_x | return input_x | ||||
| def infer_dtype(self, input_x): | def infer_dtype(self, input_x): | ||||
| validator.check_subclass("input_x", input_x, mstype.tensor) | |||||
| validator.check_typename("input_x", input_x, mstype.number_type) | |||||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name()) | |||||
| return input_x | return input_x | ||||
| @@ -807,8 +798,7 @@ class Square(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x", x_type, mstype.tensor) | |||||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| @@ -837,8 +827,7 @@ class Rsqrt(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x", x_type, mstype.tensor) | |||||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| @@ -867,8 +856,7 @@ class Sqrt(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x", x_type, mstype.tensor) | |||||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| @@ -898,7 +886,7 @@ class Reciprocal(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x", x, mstype.tensor) | |||||
| validator.check_subclass("x", x, mstype.tensor, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -936,8 +924,7 @@ class Pow(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x, power): | def infer_dtype(self, x, power): | ||||
| validator.check_subclass("x", x, mstype.tensor) | |||||
| validator.check_typename("power", power, mstype.number_type) | |||||
| validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -967,7 +954,7 @@ class Exp(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x", x_type, mstype.tensor) | |||||
| validator.check_subclass("x", x_type, mstype.tensor, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| @@ -996,7 +983,7 @@ class Log(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x", x, mstype.tensor) | |||||
| validator.check_subclass("x", x, mstype.tensor, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -1178,8 +1165,7 @@ class Floor(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_subclass("x", x_dtype, mstype.tensor) | |||||
| validator.check_typename("x_dtype", x_dtype, mstype.float_type) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name()) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1234,8 +1220,7 @@ class Acosh(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -1245,15 +1230,13 @@ class _LogicBinaryOp(_BinaryOp): | |||||
| """ | """ | ||||
| @staticmethod | @staticmethod | ||||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type): | |||||
| args_type = {"x": x_dtype, "y": y_dtype} | |||||
| validator.check_args_tensor(args_type) | |||||
| args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype} | |||||
| validator.check_type_same(args_dtype, valid_type) | |||||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): | |||||
| args_dtype = {"x": x_dtype, "y": y_dtype} | |||||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name) | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype) | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name()) | |||||
| class Equal(_LogicBinaryOp): | class Equal(_LogicBinaryOp): | ||||
| @@ -1289,7 +1272,7 @@ class Equal(_LogicBinaryOp): | |||||
| """ | """ | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,)) | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||||
| class EqualCount(PrimitiveWithInfer): | class EqualCount(PrimitiveWithInfer): | ||||
| @@ -1318,11 +1301,13 @@ class EqualCount(PrimitiveWithInfer): | |||||
| """init EqualCount""" | """init EqualCount""" | ||||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | ||||
| def infer_shape(self, x_shape, w_shape): | |||||
| def infer_shape(self, x_shape, y_shape): | |||||
| output_shape = (1,) | output_shape = (1,) | ||||
| return output_shape | return output_shape | ||||
| def infer_dtype(self, x_dtype, w_dtype): | |||||
| def infer_dtype(self, x_dtype, y_dtype): | |||||
| args = {'x': x_dtype, 'y': y_dtype} | |||||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1359,7 +1344,7 @@ class NotEqual(_LogicBinaryOp): | |||||
| """ | """ | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,)) | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||||
| class Greater(_LogicBinaryOp): | class Greater(_LogicBinaryOp): | ||||
| @@ -1495,8 +1480,7 @@ class LogicalNot(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_subclass("x", x_dtype, mstype.tensor) | |||||
| validator.check_typename("x_dtype", x_dtype, [mstype.bool_]) | |||||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name()) | |||||
| return mstype.tensor_type(mstype.bool_) | return mstype.tensor_type(mstype.bool_) | ||||
| @@ -1526,7 +1510,7 @@ class LogicalAnd(_LogicBinaryOp): | |||||
| """ | """ | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,)) | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) | |||||
| class LogicalOr(_LogicBinaryOp): | class LogicalOr(_LogicBinaryOp): | ||||
| @@ -1555,7 +1539,7 @@ class LogicalOr(_LogicBinaryOp): | |||||
| """ | """ | ||||
| def infer_dtype(self, x_dtype, y_dtype): | def infer_dtype(self, x_dtype, y_dtype): | ||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,)) | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) | |||||
| class NPUAllocFloatStatus(PrimitiveWithInfer): | class NPUAllocFloatStatus(PrimitiveWithInfer): | ||||
| @@ -1616,13 +1600,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer): | |||||
| self.add_prim_attr("_side_effect_flag", True) | self.add_prim_attr("_side_effect_flag", True) | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ) | |||||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||||
| return [8] | return [8] | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| args = {"x_dtype": x_dtype} | |||||
| validator.check_type_same(args, [mstype.float32]) | |||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) | |||||
| return mstype.float32 | return mstype.float32 | ||||
| @@ -1658,13 +1642,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer): | |||||
| self.add_prim_attr("_side_effect_flag", True) | self.add_prim_attr("_side_effect_flag", True) | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ) | |||||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||||
| return [8] | return [8] | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| args = {"x_dtype": x_dtype} | |||||
| validator.check_type_same(args, [mstype.float32]) | |||||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) | |||||
| return mstype.float32 | return mstype.float32 | ||||
| @@ -1692,8 +1676,7 @@ class Cos(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -1721,8 +1704,7 @@ class ACos(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -1750,8 +1732,7 @@ class Sin(PrimitiveWithInfer): | |||||
| return x | return x | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||||
| return x | return x | ||||
| @@ -1796,19 +1777,19 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, iou_threshold=0.5): | def __init__(self, iou_threshold=0.5): | ||||
| """Init NMSWithMask""" | """Init NMSWithMask""" | ||||
| validator.check_type("iou_threshold", iou_threshold, [float]) | |||||
| validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name()) | |||||
| self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | ||||
| def infer_shape(self, bboxes_shape): | def infer_shape(self, bboxes_shape): | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ) | |||||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT) | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ) | |||||
| cls_name = self.prim_name() | |||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | |||||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||||
| num = bboxes_shape[0] | num = bboxes_shape[0] | ||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| def infer_dtype(self, bboxes_dtype): | def infer_dtype(self, bboxes_dtype): | ||||
| validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor) | |||||
| validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) | |||||
| validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name()) | |||||
| return (bboxes_dtype, mstype.int32, mstype.bool_) | return (bboxes_dtype, mstype.int32, mstype.bool_) | ||||
| @@ -1837,8 +1818,7 @@ class Abs(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x_dtype", x_type, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x_type, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| def infer_value(self, x): | def infer_value(self, x): | ||||
| @@ -1880,8 +1860,7 @@ class Sign(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_subclass('x', x_dtype, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x_dtype, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name()) | |||||
| return x_dtype | return x_dtype | ||||
| @@ -1910,8 +1889,7 @@ class Round(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type): | def infer_dtype(self, x_type): | ||||
| validator.check_subclass("x_dtype", x_type, mstype.tensor) | |||||
| validator.check_typename('x_dtype', x_type, mstype.number_type) | |||||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) | |||||
| return x_type | return x_type | ||||
| @@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive): | |||||
| Primitive.__init__(self, name) | Primitive.__init__(self, name) | ||||
| self.set_prim_type(prim_type.py_infer_shape) | self.set_prim_type(prim_type.py_infer_shape) | ||||
| def prim_name(self): | |||||
| return self.__class__.__name__ | |||||
| def _clone(self): | def _clone(self): | ||||
| """ | """ | ||||
| Deeply clones the primitive object. | Deeply clones the primitive object. | ||||
| @@ -23,20 +23,25 @@ from ...utils import keyword | |||||
| class CheckExceptionsEC(IExectorComponent): | class CheckExceptionsEC(IExectorComponent): | ||||
| """ | """ | ||||
| Check if the function raises the expected Exception. | |||||
| Check if the function raises the expected Exception and the error message contains specified keywords if not None. | |||||
| Examples: | Examples: | ||||
| { | { | ||||
| 'block': f, | 'block': f, | ||||
| 'exception': Exception | |||||
| 'exception': Exception, | |||||
| 'error_keywords': ['TensorAdd', 'shape'] | |||||
| } | } | ||||
| """ | """ | ||||
| def run_function(self, function, inputs, verification_set): | def run_function(self, function, inputs, verification_set): | ||||
| f = function[keyword.block] | f = function[keyword.block] | ||||
| args = inputs[keyword.desc_inputs] | args = inputs[keyword.desc_inputs] | ||||
| e = function.get(keyword.exception, Exception) | e = function.get(keyword.exception, Exception) | ||||
| error_kws = function.get(keyword.error_keywords, None) | |||||
| try: | try: | ||||
| with pytest.raises(e): | |||||
| with pytest.raises(e) as exec_info: | |||||
| f(*args) | f(*args) | ||||
| except: | except: | ||||
| raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}") | raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}") | ||||
| if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws): | |||||
| raise ValueError('Error message `{}` does not contain all keywords `{}`'.format( | |||||
| str(exec_info.value), error_kws)) | |||||
| @@ -87,8 +87,9 @@ def get_function_config(function): | |||||
| init_param_with = function.get(keyword.init_param_with, None) | init_param_with = function.get(keyword.init_param_with, None) | ||||
| split_outputs = function.get(keyword.split_outputs, True) | split_outputs = function.get(keyword.split_outputs, True) | ||||
| exception = function.get(keyword.exception, Exception) | exception = function.get(keyword.exception, Exception) | ||||
| error_keywords = function.get(keyword.error_keywords, None) | |||||
| return delta, max_error, input_selector, output_selector, sampling_times, \ | return delta, max_error, input_selector, output_selector, sampling_times, \ | ||||
| reduce_output, init_param_with, split_outputs, exception | |||||
| reduce_output, init_param_with, split_outputs, exception, error_keywords | |||||
| def get_grad_checking_options(function, inputs): | def get_grad_checking_options(function, inputs): | ||||
| """ | """ | ||||
| @@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs): | |||||
| """ | """ | ||||
| f = function[keyword.block] | f = function[keyword.block] | ||||
| args = inputs[keyword.desc_inputs] | args = inputs[keyword.desc_inputs] | ||||
| delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _ = \ | |||||
| delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \ | |||||
| get_function_config(function) | get_function_config(function) | ||||
| return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output | return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output | ||||
| @@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex | |||||
| block = block_config | block = block_config | ||||
| delta, max_error, input_selector, output_selector, \ | delta, max_error, input_selector, output_selector, \ | ||||
| sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config({}) | |||||
| sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({}) | |||||
| if isinstance(block_config, tuple) and isinstance(block_config[-1], dict): | if isinstance(block_config, tuple) and isinstance(block_config[-1], dict): | ||||
| block = block_config[0] | block = block_config[0] | ||||
| delta, max_error, input_selector, output_selector, \ | delta, max_error, input_selector, output_selector, \ | ||||
| sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config(block_config[-1]) | |||||
| sampling_times, reduce_output, init_param_with, \ | |||||
| split_outputs, exception, error_keywords = get_function_config(block_config[-1]) | |||||
| if block: | if block: | ||||
| func_list.append({ | func_list.append({ | ||||
| @@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex | |||||
| keyword.const_first: const_first, | keyword.const_first: const_first, | ||||
| keyword.add_fake_input: add_fake_input, | keyword.add_fake_input: add_fake_input, | ||||
| keyword.split_outputs: split_outputs, | keyword.split_outputs: split_outputs, | ||||
| keyword.exception: exception | |||||
| keyword.exception: exception, | |||||
| keyword.error_keywords: error_keywords | |||||
| }) | }) | ||||
| if desc_inputs or desc_const: | if desc_inputs or desc_const: | ||||
| @@ -73,5 +73,6 @@ keyword.const_first = "const_first" | |||||
| keyword.add_fake_input = "add_fake_input" | keyword.add_fake_input = "add_fake_input" | ||||
| keyword.fake_input_type = "fake_input_type" | keyword.fake_input_type = "fake_input_type" | ||||
| keyword.exception = "exception" | keyword.exception = "exception" | ||||
| keyword.error_keywords = "error_keywords" | |||||
| sys.modules[__name__] = keyword | sys.modules[__name__] = keyword | ||||
| @@ -234,7 +234,7 @@ raise_set = [ | |||||
| 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}), | 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}), | ||||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | ||||
| ('ReduceSum_Error', { | ('ReduceSum_Error', { | ||||
| 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | ||||
| ] | ] | ||||
| @@ -0,0 +1,751 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test ops """ | |||||
| import functools | |||||
| import numpy as np | |||||
| from mindspore import ops | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| import mindspore.ops.composite as C | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.common.parameter import Parameter | |||||
| from ..ut_filter import non_graph_engine | |||||
| from mindspore.common.api import _executor | |||||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||||
| from ....mindspore_test_framework.pipeline.forward.compile_forward\ | |||||
| import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, | |||||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||||
| from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ | |||||
| import pipeline_for_compile_grad_ge_graph_for_case_by_case_config | |||||
| class AssignAddNet(nn.Cell): | |||||
| def __init__(self,): | |||||
| super(AssignAddNet, self).__init__() | |||||
| self.op = P.AssignAdd() | |||||
| self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_add1") | |||||
| def construct(self, x): | |||||
| self.op(self.inputdata, x) | |||||
| return self.inputdata | |||||
| class AssignSubNet(nn.Cell): | |||||
| def __init__(self,): | |||||
| super(AssignSubNet, self).__init__() | |||||
| self.op = P.AssignSub() | |||||
| self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_sub1") | |||||
| def construct(self, x): | |||||
| self.op(self.inputdata, x) | |||||
| return self.inputdata | |||||
| class ReduceNet(nn.Cell): | |||||
| def __init__(self, op_class, keep_dims, axis): | |||||
| super(ReduceNet, self).__init__() | |||||
| self.axis = axis | |||||
| self.op = op_class(keep_dims=keep_dims) | |||||
| def construct(self, x): | |||||
| return self.op(x, self.axis) | |||||
| class CumProdNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(CumProdNet, self).__init__() | |||||
| self.op = P.CumProd() | |||||
| def construct(self, x, axis): | |||||
| return self.op(x, axis) | |||||
| class CumSumNet(nn.Cell): | |||||
| def __init__(self, axis): | |||||
| super(CumSumNet, self).__init__() | |||||
| self.axis = axis | |||||
| self.op = P.CumSum() | |||||
| def construct(self, x): | |||||
| return self.op(x, self.axis) | |||||
| raise_set = [ | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('TensorAdd0', { | |||||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('TensorAdd1', { | |||||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('TensorAdd2', { | |||||
| 'block': (P.TensorAdd(), {'exception': ValueError, 'error_keywords': ['TensorAdd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # check input Tensor(bool_) | |||||
| ('AssignAdd', { | |||||
| 'block': (AssignAddNet(), {'exception': TypeError, 'error_keywords': ['AssignAdd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)], | |||||
| 'skip': ['backward']}), | |||||
| # check input Tensor(bool_) | |||||
| ('AssignSub', { | |||||
| 'block': (AssignSubNet(), {'exception': TypeError, 'error_keywords': ['AssignSub']}), | |||||
| 'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceMean1', { | |||||
| 'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceMean']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceMean2', { | |||||
| 'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceMean']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceSum1', { | |||||
| 'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceSum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceSum2', { | |||||
| 'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceSum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceAll1', { | |||||
| 'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceAll']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceAll2', { | |||||
| 'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceAll']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceMax1', { | |||||
| 'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceMax']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceMax2', { | |||||
| 'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceMax']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceMin1', { | |||||
| 'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceMin']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceMin2', { | |||||
| 'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceMin']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis is float, not int | |||||
| ('ReduceProd1', { | |||||
| 'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5.0), | |||||
| {'exception': TypeError, 'error_keywords': ['ReduceProd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # axis is out of range | |||||
| ('ReduceProd2', { | |||||
| 'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5), | |||||
| {'exception': ValueError, 'error_keywords': ['ReduceProd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x is Tensor(bool) | |||||
| ('CumProd1', { | |||||
| 'block': (CumProdNet(), | |||||
| {'exception': TypeError, 'error_keywords': ['CumProd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool)), 1], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis in float, not int | |||||
| ('CumProd2', { | |||||
| 'block': (CumProdNet(), | |||||
| {'exception': TypeError, 'error_keywords': ['CumProd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32)), 5.0], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y are Tensor(uint32) | |||||
| ('MatMul1', { | |||||
| 'block': (P.MatMul(), | |||||
| {'exception': TypeError, 'error_keywords': ['MatMul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.uint32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('MatMul2', { | |||||
| 'block': (P.MatMul(), | |||||
| {'exception': TypeError, 'error_keywords': ['MatMul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('MatMul3', { | |||||
| 'block': (P.MatMul(), | |||||
| {'exception': ValueError, 'error_keywords': ['MatMul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([2, 3]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # dims of x and y are less than 3 | |||||
| ('BatchMatMul1', { | |||||
| 'block': (P.BatchMatMul(), | |||||
| {'exception': ValueError, 'error_keywords': ['BatchMatMul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x is Tensor(bool) | |||||
| ('CumSum1', { | |||||
| 'block': (CumSumNet(axis=1), | |||||
| {'exception': TypeError, 'error_keywords': ['CumSum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))], | |||||
| 'skip': ['backward']}), | |||||
| # type of axis in float, not int | |||||
| ('CumSum2', { | |||||
| 'block': (CumSumNet(axis=1.0), | |||||
| {'exception': TypeError, 'error_keywords': ['CumSum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))], | |||||
| 'skip': ['backward']}), | |||||
| # intput is not tuple or list | |||||
| ('AddN1', { | |||||
| 'block': (P.AddN(), | |||||
| {'exception': TypeError, 'error_keywords': ['AddN']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32))], | |||||
| 'skip': ['backward']}), | |||||
| # type not match | |||||
| ('AddN2', { | |||||
| 'block': (P.AddN(), | |||||
| {'exception': TypeError, 'error_keywords': ['AddN']}), | |||||
| 'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.int32)))], | |||||
| 'skip': ['backward']}), | |||||
| # shape not match | |||||
| ('AddN3', { | |||||
| 'block': (P.AddN(), | |||||
| {'exception': ValueError, 'error_keywords': ['AddN']}), | |||||
| 'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32)))], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Neg1', { | |||||
| 'block': (P.Neg(), | |||||
| {'exception': TypeError, 'error_keywords': ['Neg']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Sub0', { | |||||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Sub1', { | |||||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Sub2', { | |||||
| 'block': (P.Sub(), {'exception': ValueError, 'error_keywords': ['Sub']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Mul0', { | |||||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Mul1', { | |||||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Mul2', { | |||||
| 'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Square1', { | |||||
| 'block': (P.Square(), | |||||
| {'exception': TypeError, 'error_keywords': ['Square']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Rsqrt1', { | |||||
| 'block': (P.Rsqrt(), | |||||
| {'exception': TypeError, 'error_keywords': ['Rsqrt']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Sqrt1', { | |||||
| 'block': (P.Sqrt(), | |||||
| {'exception': TypeError, 'error_keywords': ['Sqrt']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not Tensor | |||||
| ('Reciprocal1', { | |||||
| 'block': (P.Reciprocal(), | |||||
| {'exception': TypeError, 'error_keywords': ['Reciprocal']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input x is Tensor(bool) | |||||
| ('Pow1', { | |||||
| 'block': (P.Pow(), | |||||
| {'exception': TypeError, 'error_keywords': ['Pow']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is not Tensor | |||||
| ('Exp1', { | |||||
| 'block': (P.Exp(), | |||||
| {'exception': TypeError, 'error_keywords': ['Exp']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is not Tensor | |||||
| ('Log1', { | |||||
| 'block': (P.Log(), | |||||
| {'exception': TypeError, 'error_keywords': ['Log']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Minimum0', { | |||||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Minimum1', { | |||||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Minimum2', { | |||||
| 'block': (P.Minimum(), {'exception': ValueError, 'error_keywords': ['Minimum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Maximum0', { | |||||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Maximum1', { | |||||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Maximum2', { | |||||
| 'block': (P.Maximum(), {'exception': ValueError, 'error_keywords': ['Maximum']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('RealDiv0', { | |||||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('RealDiv1', { | |||||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('RealDiv2', { | |||||
| 'block': (P.RealDiv(), {'exception': ValueError, 'error_keywords': ['RealDiv']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Div0', { | |||||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Div1', { | |||||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Div2', { | |||||
| 'block': (P.Div(), {'exception': ValueError, 'error_keywords': ['Div']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('FloorDiv0', { | |||||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('FloorDiv1', { | |||||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('FloorDiv2', { | |||||
| 'block': (P.FloorDiv(), {'exception': ValueError, 'error_keywords': ['FloorDiv']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input x is Tensor(int32), not Tensor(float) | |||||
| ('Floor1', { | |||||
| 'block': (P.Floor(), | |||||
| {'exception': TypeError, 'error_keywords': ['Floor']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('FloorMod0', { | |||||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('FloorMod1', { | |||||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('FFloorMod2', { | |||||
| 'block': (P.FloorMod(), {'exception': ValueError, 'error_keywords': ['FloorMod']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input x is Tensor(int32), not Tensor(float) | |||||
| ('Acosh1', { | |||||
| 'block': (P.Acosh(), | |||||
| {'exception': TypeError, 'error_keywords': ['Acosh']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Equal0', { | |||||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('Equal1', { | |||||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('Equal2', { | |||||
| 'block': (P.Equal(), {'exception': ValueError, 'error_keywords': ['Equal']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('EqualCount0', { | |||||
| 'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('EqualCount1', { | |||||
| 'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| # input is not tensor | |||||
| ('NotEqual0', { | |||||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('NotEqual1', { | |||||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('NotEqual2', { | |||||
| 'block': (P.NotEqual(), {'exception': ValueError, 'error_keywords': ['NotEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Greater0', { | |||||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('Greater1', { | |||||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('Greater2', { | |||||
| 'block': (P.Greater(), {'exception': ValueError, 'error_keywords': ['Greater']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('GreaterEqual0', { | |||||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('GreaterEqual1', { | |||||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('GreaterEqual2', { | |||||
| 'block': (P.GreaterEqual(), {'exception': ValueError, 'error_keywords': ['GreaterEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Less0', { | |||||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('Less1', { | |||||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('Less2', { | |||||
| 'block': (P.Less(), {'exception': ValueError, 'error_keywords': ['Less']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('LessEqual0', { | |||||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('LessEqual1', { | |||||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('LessEqual2', { | |||||
| 'block': (P.LessEqual(), {'exception': ValueError, 'error_keywords': ['LessEqual']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input x is not Tensor(bool) | |||||
| ('LogicalNot1', { | |||||
| 'block': (P.LogicalNot(), | |||||
| {'exception': TypeError, 'error_keywords': ['LogicalNot']}), | |||||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('LogicalAnd1', { | |||||
| 'block': (P.LogicalAnd(), {'exception': TypeError, 'error_keywords': ['LogicalAnd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('LogicalAnd2', { | |||||
| 'block': (P.LogicalAnd(), {'exception': ValueError, 'error_keywords': ['LogicalAnd']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # type of x and y not match | |||||
| ('LogicalOr1', { | |||||
| 'block': (P.LogicalOr(), {'exception': TypeError, 'error_keywords': ['LogicalOr']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # shape of x and y not match | |||||
| ('LogicalOr2', { | |||||
| 'block': (P.LogicalOr(), {'exception': ValueError, 'error_keywords': ['LogicalOr']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('NPUGetFloatStatus0', { | |||||
| 'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(int32), not Tensor(float32) | |||||
| ('NPUGetFloatStatus1', { | |||||
| 'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # dims is not 1 | |||||
| ('NPUGetFloatStatus2', { | |||||
| 'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape[0] is not 8 | |||||
| ('NPUGetFloatStatus3', { | |||||
| 'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('NPUClearFloatStatus0', { | |||||
| 'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(int32), not Tensor(float32) | |||||
| ('NPUClearFloatStatus1', { | |||||
| 'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # dims is not 1 | |||||
| ('NPUClearFloatStatus2', { | |||||
| 'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape[0] is not 8 | |||||
| ('NPUClearFloatStatus3', { | |||||
| 'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Cos0', { | |||||
| 'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Cos1', { | |||||
| 'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('ACos0', { | |||||
| 'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('ACos1', { | |||||
| 'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Sin0', { | |||||
| 'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Sin1', { | |||||
| 'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('NMSWithMask0', { | |||||
| 'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is not Tensor(float16) or Tensor(float32) | |||||
| ('NMSWithMask1', { | |||||
| 'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||||
| 'skip': ['backward']}), | |||||
| # dims is not 2 | |||||
| ('NMSWithMask2', { | |||||
| 'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # shape[1] is not 5 | |||||
| ('NMSWithMask3', { | |||||
| 'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 2]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Abs0', { | |||||
| 'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Abs1', { | |||||
| 'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Sign0', { | |||||
| 'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Sign1', { | |||||
| 'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # input is not tensor | |||||
| ('Round0', { | |||||
| 'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}), | |||||
| 'desc_inputs': [5.0], | |||||
| 'skip': ['backward']}), | |||||
| # input is Tensor(bool) | |||||
| ('Round1', { | |||||
| 'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||||
| 'skip': ['backward']}), | |||||
| # one input is scalar, and another is Tensor(float32) | |||||
| ('Atan20', { | |||||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, but element types are not same | |||||
| ('Atan21', { | |||||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| # input two tensors, their shapes do not match | |||||
| ('Atan22', { | |||||
| 'block': (P.Atan2(), {'exception': ValueError, 'error_keywords': ['Atan2']}), | |||||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||||
| 'skip': ['backward']}), | |||||
| ] | |||||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||||
| def test_check_exception(): | |||||
| return raise_set | |||||