| @@ -15,6 +15,7 @@ | |||
| """Check parameters.""" | |||
| import re | |||
| from enum import Enum | |||
| from functools import reduce | |||
| from itertools import repeat | |||
| from collections import Iterable | |||
| @@ -93,8 +94,131 @@ rel_strs = { | |||
| } | |||
| class Validator: | |||
| """validator for checking input parameters""" | |||
| @staticmethod | |||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): | |||
| """ | |||
| Method for judging relation between two int values or list/tuple made up of ints. | |||
| This method is not suitable for judging relation between floats, since it does not consider float error. | |||
| """ | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | |||
| msg_prefix = f'For {prim_name} the' if prim_name else "The" | |||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | |||
| """Integer value judgment.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' | |||
| f' but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||
| """Method for checking whether an int value is in some range.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) | |||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||
| f' but got {arg_value}.') | |||
| return arg_value | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_type, prim_name): | |||
| """Check whether some type is sublcass of another type""" | |||
| if not isinstance(template_type, Iterable): | |||
| template_type = (template_type,) | |||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | |||
| type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) | |||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | |||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | |||
| @staticmethod | |||
| def check_tensor_type_same(args, valid_values, prim_name): | |||
| """check whether the element types of input tensors are the same.""" | |||
| def _check_tensor_type(arg): | |||
| arg_key, arg_val = arg | |||
| Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) | |||
| elem_type = arg_val.element_type() | |||
| if not elem_type in valid_values: | |||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' | |||
| f' but `{arg_key}` is {elem_type}.') | |||
| return (arg_key, elem_type) | |||
| def _check_types_same(arg1, arg2): | |||
| arg1_name, arg1_type = arg1 | |||
| arg2_name, arg2_type = arg2 | |||
| if arg1_type != arg2_type: | |||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||
| return arg1 | |||
| elem_types = map(_check_tensor_type, args.items()) | |||
| reduce(_check_types_same, elem_types) | |||
| @staticmethod | |||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name): | |||
| """check whether the types of inputs are the same. if the input args are tensors, check their element types""" | |||
| def _check_argument_type(arg): | |||
| arg_key, arg_val = arg | |||
| if isinstance(arg_val, type(mstype.tensor)): | |||
| arg_val = arg_val.element_type() | |||
| if not arg_val in valid_values: | |||
| raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},' | |||
| f' but `{arg_key}` is {arg_val}.') | |||
| return arg | |||
| def _check_types_same(arg1, arg2): | |||
| arg1_name, arg1_type = arg1 | |||
| arg2_name, arg2_type = arg2 | |||
| excp_flag = False | |||
| if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): | |||
| arg1_type = arg1_type.element_type() | |||
| arg2_type = arg2_type.element_type() | |||
| elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): | |||
| pass | |||
| else: | |||
| excp_flag = True | |||
| if excp_flag or arg1_type != arg2_type: | |||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||
| return arg1 | |||
| reduce(_check_types_same, map(_check_argument_type, args.items())) | |||
| @staticmethod | |||
| def check_value_type(arg_name, arg_value, valid_types, prim_name): | |||
| """Check whether a values is instance of some types.""" | |||
| def raise_error_msg(): | |||
| """func for raising error message when check failed""" | |||
| type_names = [t.__name__ for t in valid_types] | |||
| num_types = len(valid_types) | |||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' | |||
| f'{"one of " if num_types > 1 else ""}' | |||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | |||
| # `check_value_type('x', True, [bool, int])` will check pass | |||
| if isinstance(arg_value, bool) and bool not in tuple(valid_types): | |||
| raise_error_msg() | |||
| if isinstance(arg_value, tuple(valid_types)): | |||
| return arg_value | |||
| raise_error_msg() | |||
| class ParamValidator: | |||
| """Parameter validator.""" | |||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | |||
| @staticmethod | |||
| def equal(arg_name, arg_value, cond_str, cond): | |||
| @@ -16,13 +16,14 @@ | |||
| """broadcast""" | |||
| def _get_broadcast_shape(x_shape, y_shape): | |||
| def _get_broadcast_shape(x_shape, y_shape, prim_name): | |||
| """ | |||
| Doing broadcast between tensor x and tensor y. | |||
| Args: | |||
| x_shape (list): The shape of tensor x. | |||
| y_shape (list): The shape of tensor y. | |||
| prim_name (str): Primitive name. | |||
| Returns: | |||
| List, the shape that broadcast between tensor x and tensor y. | |||
| @@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape): | |||
| elif x_shape[i] == y_shape[i]: | |||
| broadcast_shape_back.append(x_shape[i]) | |||
| else: | |||
| raise ValueError("The x_shape {} and y_shape {} can not broadcast.".format(x_shape, y_shape)) | |||
| raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format( | |||
| prim_name, x_shape, y_shape)) | |||
| broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length] | |||
| broadcast_shape = broadcast_shape_front + broadcast_shape_back | |||
| @@ -28,9 +28,16 @@ from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| from ..operations.math_ops import _check_infer_attr_reduce, _infer_shape_reduce | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||
| def _check_infer_attr_reduce(axis, keep_dims): | |||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||
| validator.check_type('axis', axis, [int, tuple]) | |||
| if isinstance(axis, tuple): | |||
| for index, value in enumerate(axis): | |||
| validator.check_type('axis[%d]' % index, value, [int]) | |||
| class ExpandDims(PrimitiveWithInfer): | |||
| """ | |||
| @@ -1090,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): | |||
| axis = self.axis | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) | |||
| return ouput_shape, ouput_shape | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -1136,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||
| axis = self.axis | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) | |||
| return ouput_shape, ouput_shape | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -19,7 +19,7 @@ import numpy as np | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..._c_expression import signature_dtype as sig_dtype | |||
| from ..._checkparam import ParamValidator as validator | |||
| from ..._checkparam import Validator as validator | |||
| from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.tensor import Tensor | |||
| @@ -27,16 +27,16 @@ from .._utils import _get_broadcast_shape | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| def _infer_shape_reduce(x, axis, keep_dims): | |||
| def _infer_shape_reduce(x, axis, keep_dims, prim_name): | |||
| """Common infer for reduce operator""" | |||
| def reduce_one_axis(one_axis): | |||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT) | |||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name) | |||
| if one_axis < 0: | |||
| one_axis += dim | |||
| axis_reduce.add(one_axis) | |||
| validator.check_type('axis', axis, [int, tuple, list]) | |||
| validator.check_value_type('axis', axis, [int, tuple, list], prim_name) | |||
| dim = len(x) | |||
| axis_reduce = set() | |||
| @@ -48,7 +48,7 @@ def _infer_shape_reduce(x, axis, keep_dims): | |||
| return [1] * dim | |||
| return [] | |||
| for index, one_axis in enumerate(axis): | |||
| validator.check_type('axis[%d]' % index, one_axis, [int]) | |||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], prim_name) | |||
| reduce_one_axis(one_axis) | |||
| out_shape = [] | |||
| @@ -61,14 +61,6 @@ def _infer_shape_reduce(x, axis, keep_dims): | |||
| return out_shape | |||
| def _check_infer_attr_reduce(axis, keep_dims): | |||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||
| validator.check_type('axis', axis, [int, tuple]) | |||
| if isinstance(axis, tuple): | |||
| for index, value in enumerate(axis): | |||
| validator.check_type('axis[%d]' % index, value, [int]) | |||
| class _BinaryOp(PrimitiveWithInfer): | |||
| """ | |||
| Define binary operators. | |||
| @@ -82,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | |||
| def infer_shape(self, x_shape, y_shape): | |||
| return _get_broadcast_shape(x_shape, y_shape) | |||
| return _get_broadcast_shape(x_shape, y_shape, self.prim_name()) | |||
| class _MathBinaryOp(_BinaryOp): | |||
| @@ -91,15 +83,13 @@ class _MathBinaryOp(_BinaryOp): | |||
| """ | |||
| @staticmethod | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type): | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_dtype=mstype.number_type, prim_name=None): | |||
| args_type = {"x": x_dtype, "y": y_dtype} | |||
| validator.check_args_tensor(args_type) | |||
| args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype} | |||
| validator.check_type_same(args_dtype, valid_dtype) | |||
| validator.check_tensor_type_same(args_type, valid_dtype, prim_name) | |||
| return x_dtype | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype) | |||
| return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name()) | |||
| class TensorAdd(_MathBinaryOp): | |||
| @@ -167,7 +157,7 @@ class AssignAdd(PrimitiveWithInfer): | |||
| def infer_dtype(self, variable, value): | |||
| args = {"value": value} | |||
| validator.check_type_same(args, mstype.number_type) | |||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) | |||
| return value | |||
| @@ -208,7 +198,7 @@ class AssignSub(PrimitiveWithInfer): | |||
| def infer_dtype(self, variable, value): | |||
| args = {"value": value} | |||
| validator.check_type_same(args, mstype.number_type) | |||
| validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) | |||
| return value | |||
| @@ -229,15 +219,16 @@ class _Reduce(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, keep_dims=False): | |||
| """init Reduce""" | |||
| validator.check_type('keep_dims', keep_dims, [bool]) | |||
| validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name()) | |||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | |||
| def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): | |||
| axis_v = axis['value'] | |||
| input_shp = input_x['shape'] | |||
| validator.check_subclass('input_x', input_x['dtype'], mstype.tensor) | |||
| validator.check_typename('input_x', input_x['dtype'], valid_dtype) | |||
| input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims) | |||
| args = {'input_x': input_x['dtype']} | |||
| validator.check_tensor_type_same(args, valid_dtype, self.prim_name()) | |||
| input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name()) | |||
| return {'shape': input_shp, | |||
| 'dtype': input_x['dtype'], | |||
| 'value': None} | |||
| @@ -472,16 +463,17 @@ class CumProd(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, exclusive=False, reverse=False): | |||
| self.exclusive = validator.check_type("exclusive", exclusive, [bool]) | |||
| self.reverse = validator.check_type("reverse", reverse, [bool]) | |||
| cls_name = self.prim_name() | |||
| self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) | |||
| self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) | |||
| def infer_shape(self, x_shape, axis_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, axis_type): | |||
| validator.check_subclass('x_type', x_type, mstype.tensor) | |||
| validator.check_typename('x_type', x_type, mstype.number_type) | |||
| validator.check_subclass("axis_type", axis_type, mstype.int_) | |||
| cls_name = self.prim_name() | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) | |||
| validator.check_subclass("axis", axis_type, mstype.int_, cls_name) | |||
| return x_type | |||
| @@ -515,8 +507,9 @@ class MatMul(PrimitiveWithInfer): | |||
| def __init__(self, transpose_a=False, transpose_b=False): | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| validator.check_type("transpose_a", transpose_a, [bool]) | |||
| validator.check_type("transpose_b", transpose_b, [bool]) | |||
| cls_name = self.prim_name() | |||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||
| def check_shape_size(self, x, y): | |||
| if len(x) != 2 or len(y) != 2: | |||
| @@ -525,11 +518,11 @@ class MatMul(PrimitiveWithInfer): | |||
| def infer_shape(self, x, y): | |||
| self.check_shape_size(x, y) | |||
| cls_name = self.__class__.__name__ | |||
| cls_name = self.prim_name() | |||
| # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two | |||
| for i in range(len(x) - 2): | |||
| if x[i] != y[i]: | |||
| raise ValueError(f'{cls_name} shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}') | |||
| raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}') | |||
| # validate whether last two dims satifing matrix multiply | |||
| x_last = x[-2:] | |||
| @@ -538,8 +531,8 @@ class MatMul(PrimitiveWithInfer): | |||
| x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0] | |||
| y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1] | |||
| if x_col != y_row: | |||
| raise ValueError(f'{cls_name} evaluator shapes of inputs can not do this operator, got {x_col} and {y_row}' | |||
| + f' for {cls_name}, with x shape {x}(transpose_a={self.transpose_a})' | |||
| raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,' | |||
| + f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})' | |||
| + f', y shape {y}(transpose_b={self.transpose_b}).') | |||
| # set attribute | |||
| self.add_prim_attr('transpose_x1', self.transpose_a) | |||
| @@ -549,10 +542,8 @@ class MatMul(PrimitiveWithInfer): | |||
| return ret_dims | |||
| def infer_dtype(self, x, y): | |||
| validator.check_subclass("x", x, mstype.tensor) | |||
| validator.check_subclass("y", y, mstype.tensor) | |||
| args = {"x dtype": x, "y dtype": y} | |||
| validator.check_type_same(args, mstype.float_type + mstype.int_type) | |||
| args = {"x": x, "y": y} | |||
| validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name()) | |||
| return x | |||
| @@ -596,12 +587,13 @@ class BatchMatMul(MatMul): | |||
| def __init__(self, transpose_a=False, transpose_b=False): | |||
| self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) | |||
| self.__setattr_flag__ = True | |||
| validator.check_type("transpose_a", transpose_a, [bool]) | |||
| validator.check_type("transpose_b", transpose_b, [bool]) | |||
| cls_name = self.prim_name() | |||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | |||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | |||
| def check_shape_size(self, x, y): | |||
| if len(x) != len(y) or len(x) < 3: | |||
| raise ValueError('BatchMatMul input x, y should be the same dimension size and should be ' | |||
| raise ValueError('For \'BatchMatMul\' input x, y should be the same dimension size and should be ' | |||
| 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') | |||
| @@ -633,18 +625,17 @@ class CumSum(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, exclusive=False, reverse=False): | |||
| """init cumsum""" | |||
| self.exclusive = validator.check_type('exclusive', exclusive, [bool]) | |||
| self.add_prim_attr("exclusive", self.exclusive) | |||
| self.reverse = validator.check_type('reverse', reverse, [bool]) | |||
| self.add_prim_attr("reverse", self.reverse) | |||
| cls_name = self.prim_name() | |||
| validator.check_value_type('exclusive', exclusive, [bool], cls_name) | |||
| validator.check_value_type('reverse', reverse, [bool], cls_name) | |||
| self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) | |||
| def __infer__(self, x, axis): | |||
| cls_name = self.prim_name() | |||
| x_shp = x['shape'] | |||
| validator.check_type('axis', axis['value'], [int]) | |||
| validator.check_subclass('x', x['dtype'], mstype.tensor) | |||
| validator.check_typename('x', x['dtype'], [mstype.uint8, mstype.int8, | |||
| mstype.int32, mstype.float16, mstype.float32]) | |||
| validator.check_value_type('axis', axis['value'], [int], cls_name) | |||
| valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] | |||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) | |||
| return {'shape': x_shp, | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| @@ -685,21 +676,22 @@ class AddN(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | |||
| def infer_shape(self, inputs): | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE) | |||
| cls_name = self.prim_name() | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| self.add_prim_attr('n', len(inputs)) | |||
| shp0 = inputs[0] | |||
| for i, shp in enumerate(inputs): | |||
| validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0) | |||
| validator.check(f"shape of inputs[{i}]", shp, 'shape of inputs[0]', shp0, Rel.EQ, cls_name) | |||
| return shp0 | |||
| def infer_dtype(self, inputs): | |||
| validator.check_type("inputs", inputs, [tuple, list]) | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE) | |||
| cls_name = self.prim_name() | |||
| validator.check_value_type("inputs", inputs, [tuple, list], cls_name) | |||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | |||
| args = {} | |||
| for i, dtype in enumerate(inputs): | |||
| validator.check_subclass(f"inputs[{i}]", dtype, mstype.tensor) | |||
| args[f"inputs[{i}]"] = dtype | |||
| validator.check_type_same(args, mstype.number_type + (mstype.bool_,)) | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), cls_name) | |||
| return inputs[0] | |||
| @@ -723,8 +715,7 @@ class Neg(PrimitiveWithInfer): | |||
| return input_x | |||
| def infer_dtype(self, input_x): | |||
| validator.check_subclass("input_x", input_x, mstype.tensor) | |||
| validator.check_typename("input_x", input_x, mstype.number_type) | |||
| validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name()) | |||
| return input_x | |||
| @@ -807,8 +798,7 @@ class Square(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||
| return x_type | |||
| @@ -837,8 +827,7 @@ class Rsqrt(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||
| return x_type | |||
| @@ -867,8 +856,7 @@ class Sqrt(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_type, mstype.number_type) | |||
| validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) | |||
| return x_type | |||
| @@ -898,7 +886,7 @@ class Reciprocal(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x", x, mstype.tensor) | |||
| validator.check_subclass("x", x, mstype.tensor, self.prim_name()) | |||
| return x | |||
| @@ -936,8 +924,7 @@ class Pow(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x, power): | |||
| validator.check_subclass("x", x, mstype.tensor) | |||
| validator.check_typename("power", power, mstype.number_type) | |||
| validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name()) | |||
| return x | |||
| @@ -967,7 +954,7 @@ class Exp(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x", x_type, mstype.tensor) | |||
| validator.check_subclass("x", x_type, mstype.tensor, self.prim_name()) | |||
| return x_type | |||
| @@ -996,7 +983,7 @@ class Log(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x", x, mstype.tensor) | |||
| validator.check_subclass("x", x, mstype.tensor, self.prim_name()) | |||
| return x | |||
| @@ -1178,8 +1165,7 @@ class Floor(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_dtype, mstype.float_type) | |||
| validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name()) | |||
| return x_dtype | |||
| @@ -1234,8 +1220,7 @@ class Acosh(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||
| return x | |||
| @@ -1245,15 +1230,13 @@ class _LogicBinaryOp(_BinaryOp): | |||
| """ | |||
| @staticmethod | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type): | |||
| args_type = {"x": x_dtype, "y": y_dtype} | |||
| validator.check_args_tensor(args_type) | |||
| args_dtype = {"x_dtype": x_dtype, "y_dtype": y_dtype} | |||
| validator.check_type_same(args_dtype, valid_type) | |||
| def do_infer_dtype(x_dtype, y_dtype, valid_type=mstype.number_type, prim_name=None): | |||
| args_dtype = {"x": x_dtype, "y": y_dtype} | |||
| validator.check_tensor_type_same(args_dtype, valid_type, prim_name) | |||
| return mstype.tensor_type(mstype.bool_) | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype) | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name()) | |||
| class Equal(_LogicBinaryOp): | |||
| @@ -1289,7 +1272,7 @@ class Equal(_LogicBinaryOp): | |||
| """ | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,)) | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||
| class EqualCount(PrimitiveWithInfer): | |||
| @@ -1318,11 +1301,13 @@ class EqualCount(PrimitiveWithInfer): | |||
| """init EqualCount""" | |||
| self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) | |||
| def infer_shape(self, x_shape, w_shape): | |||
| def infer_shape(self, x_shape, y_shape): | |||
| output_shape = (1,) | |||
| return output_shape | |||
| def infer_dtype(self, x_dtype, w_dtype): | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| args = {'x': x_dtype, 'y': y_dtype} | |||
| validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||
| return x_dtype | |||
| @@ -1359,7 +1344,7 @@ class NotEqual(_LogicBinaryOp): | |||
| """ | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,)) | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) | |||
| class Greater(_LogicBinaryOp): | |||
| @@ -1495,8 +1480,7 @@ class LogicalNot(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass("x", x_dtype, mstype.tensor) | |||
| validator.check_typename("x_dtype", x_dtype, [mstype.bool_]) | |||
| validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name()) | |||
| return mstype.tensor_type(mstype.bool_) | |||
| @@ -1526,7 +1510,7 @@ class LogicalAnd(_LogicBinaryOp): | |||
| """ | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,)) | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) | |||
| class LogicalOr(_LogicBinaryOp): | |||
| @@ -1555,7 +1539,7 @@ class LogicalOr(_LogicBinaryOp): | |||
| """ | |||
| def infer_dtype(self, x_dtype, y_dtype): | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,)) | |||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) | |||
| class NPUAllocFloatStatus(PrimitiveWithInfer): | |||
| @@ -1616,13 +1600,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer): | |||
| self.add_prim_attr("_side_effect_flag", True) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ) | |||
| cls_name = self.prim_name() | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| args = {"x_dtype": x_dtype} | |||
| validator.check_type_same(args, [mstype.float32]) | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) | |||
| return mstype.float32 | |||
| @@ -1658,13 +1642,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer): | |||
| self.add_prim_attr("_side_effect_flag", True) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ) | |||
| cls_name = self.prim_name() | |||
| validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) | |||
| validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) | |||
| return [8] | |||
| def infer_dtype(self, x_dtype): | |||
| args = {"x_dtype": x_dtype} | |||
| validator.check_type_same(args, [mstype.float32]) | |||
| validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) | |||
| return mstype.float32 | |||
| @@ -1692,8 +1676,7 @@ class Cos(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||
| return x | |||
| @@ -1721,8 +1704,7 @@ class ACos(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||
| return x | |||
| @@ -1750,8 +1732,7 @@ class Sin(PrimitiveWithInfer): | |||
| return x | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("x_dtype", x, mstype.tensor) | |||
| validator.check_typename('x_dtype', x, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) | |||
| return x | |||
| @@ -1796,19 +1777,19 @@ class NMSWithMask(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, iou_threshold=0.5): | |||
| """Init NMSWithMask""" | |||
| validator.check_type("iou_threshold", iou_threshold, [float]) | |||
| validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name()) | |||
| self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | |||
| def infer_shape(self, bboxes_shape): | |||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ) | |||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT) | |||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ) | |||
| cls_name = self.prim_name() | |||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | |||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||
| num = bboxes_shape[0] | |||
| return (bboxes_shape, (num,), (num,)) | |||
| def infer_dtype(self, bboxes_dtype): | |||
| validator.check_subclass("bboxes_dtype", bboxes_dtype, mstype.tensor) | |||
| validator.check_typename("bboxes_dtype", bboxes_dtype, [mstype.float16, mstype.float32]) | |||
| validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name()) | |||
| return (bboxes_dtype, mstype.int32, mstype.bool_) | |||
| @@ -1837,8 +1818,7 @@ class Abs(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x_dtype", x_type, mstype.tensor) | |||
| validator.check_typename('x_dtype', x_type, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) | |||
| return x_type | |||
| def infer_value(self, x): | |||
| @@ -1880,8 +1860,7 @@ class Sign(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| validator.check_subclass('x', x_dtype, mstype.tensor) | |||
| validator.check_typename('x_dtype', x_dtype, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name()) | |||
| return x_dtype | |||
| @@ -1910,8 +1889,7 @@ class Round(PrimitiveWithInfer): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| validator.check_subclass("x_dtype", x_type, mstype.tensor) | |||
| validator.check_typename('x_dtype', x_type, mstype.number_type) | |||
| validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) | |||
| return x_type | |||
| @@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive): | |||
| Primitive.__init__(self, name) | |||
| self.set_prim_type(prim_type.py_infer_shape) | |||
| def prim_name(self): | |||
| return self.__class__.__name__ | |||
| def _clone(self): | |||
| """ | |||
| Deeply clones the primitive object. | |||
| @@ -23,20 +23,25 @@ from ...utils import keyword | |||
| class CheckExceptionsEC(IExectorComponent): | |||
| """ | |||
| Check if the function raises the expected Exception. | |||
| Check if the function raises the expected Exception and the error message contains specified keywords if not None. | |||
| Examples: | |||
| { | |||
| 'block': f, | |||
| 'exception': Exception | |||
| 'exception': Exception, | |||
| 'error_keywords': ['TensorAdd', 'shape'] | |||
| } | |||
| """ | |||
| def run_function(self, function, inputs, verification_set): | |||
| f = function[keyword.block] | |||
| args = inputs[keyword.desc_inputs] | |||
| e = function.get(keyword.exception, Exception) | |||
| error_kws = function.get(keyword.error_keywords, None) | |||
| try: | |||
| with pytest.raises(e): | |||
| with pytest.raises(e) as exec_info: | |||
| f(*args) | |||
| except: | |||
| raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}") | |||
| if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws): | |||
| raise ValueError('Error message `{}` does not contain all keywords `{}`'.format( | |||
| str(exec_info.value), error_kws)) | |||
| @@ -87,8 +87,9 @@ def get_function_config(function): | |||
| init_param_with = function.get(keyword.init_param_with, None) | |||
| split_outputs = function.get(keyword.split_outputs, True) | |||
| exception = function.get(keyword.exception, Exception) | |||
| error_keywords = function.get(keyword.error_keywords, None) | |||
| return delta, max_error, input_selector, output_selector, sampling_times, \ | |||
| reduce_output, init_param_with, split_outputs, exception | |||
| reduce_output, init_param_with, split_outputs, exception, error_keywords | |||
| def get_grad_checking_options(function, inputs): | |||
| """ | |||
| @@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs): | |||
| """ | |||
| f = function[keyword.block] | |||
| args = inputs[keyword.desc_inputs] | |||
| delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _ = \ | |||
| delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \ | |||
| get_function_config(function) | |||
| return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output | |||
| @@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex | |||
| block = block_config | |||
| delta, max_error, input_selector, output_selector, \ | |||
| sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config({}) | |||
| sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({}) | |||
| if isinstance(block_config, tuple) and isinstance(block_config[-1], dict): | |||
| block = block_config[0] | |||
| delta, max_error, input_selector, output_selector, \ | |||
| sampling_times, reduce_output, init_param_with, split_outputs, exception = get_function_config(block_config[-1]) | |||
| sampling_times, reduce_output, init_param_with, \ | |||
| split_outputs, exception, error_keywords = get_function_config(block_config[-1]) | |||
| if block: | |||
| func_list.append({ | |||
| @@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex | |||
| keyword.const_first: const_first, | |||
| keyword.add_fake_input: add_fake_input, | |||
| keyword.split_outputs: split_outputs, | |||
| keyword.exception: exception | |||
| keyword.exception: exception, | |||
| keyword.error_keywords: error_keywords | |||
| }) | |||
| if desc_inputs or desc_const: | |||
| @@ -73,5 +73,6 @@ keyword.const_first = "const_first" | |||
| keyword.add_fake_input = "add_fake_input" | |||
| keyword.fake_input_type = "fake_input_type" | |||
| keyword.exception = "exception" | |||
| keyword.error_keywords = "error_keywords" | |||
| sys.modules[__name__] = keyword | |||
| @@ -234,7 +234,7 @@ raise_set = [ | |||
| 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}), | |||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | |||
| ('ReduceSum_Error', { | |||
| 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': ValueError}), | |||
| 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}), | |||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | |||
| ] | |||
| @@ -0,0 +1,751 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test ops """ | |||
| import functools | |||
| import numpy as np | |||
| from mindspore import ops | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| import mindspore.ops.composite as C | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.parameter import Parameter | |||
| from ..ut_filter import non_graph_engine | |||
| from mindspore.common.api import _executor | |||
| from ....mindspore_test_framework.mindspore_test import mindspore_test | |||
| from ....mindspore_test_framework.pipeline.forward.compile_forward\ | |||
| import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, | |||
| pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||
| from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ | |||
| import pipeline_for_compile_grad_ge_graph_for_case_by_case_config | |||
| class AssignAddNet(nn.Cell): | |||
| def __init__(self,): | |||
| super(AssignAddNet, self).__init__() | |||
| self.op = P.AssignAdd() | |||
| self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_add1") | |||
| def construct(self, x): | |||
| self.op(self.inputdata, x) | |||
| return self.inputdata | |||
| class AssignSubNet(nn.Cell): | |||
| def __init__(self,): | |||
| super(AssignSubNet, self).__init__() | |||
| self.op = P.AssignSub() | |||
| self.inputdata = Parameter(Tensor(np.zeros([1]).astype(np.bool_), mstype.bool_), name="assign_sub1") | |||
| def construct(self, x): | |||
| self.op(self.inputdata, x) | |||
| return self.inputdata | |||
| class ReduceNet(nn.Cell): | |||
| def __init__(self, op_class, keep_dims, axis): | |||
| super(ReduceNet, self).__init__() | |||
| self.axis = axis | |||
| self.op = op_class(keep_dims=keep_dims) | |||
| def construct(self, x): | |||
| return self.op(x, self.axis) | |||
| class CumProdNet(nn.Cell): | |||
| def __init__(self): | |||
| super(CumProdNet, self).__init__() | |||
| self.op = P.CumProd() | |||
| def construct(self, x, axis): | |||
| return self.op(x, axis) | |||
| class CumSumNet(nn.Cell): | |||
| def __init__(self, axis): | |||
| super(CumSumNet, self).__init__() | |||
| self.axis = axis | |||
| self.op = P.CumSum() | |||
| def construct(self, x): | |||
| return self.op(x, self.axis) | |||
| raise_set = [ | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('TensorAdd0', { | |||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('TensorAdd1', { | |||
| 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('TensorAdd2', { | |||
| 'block': (P.TensorAdd(), {'exception': ValueError, 'error_keywords': ['TensorAdd']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # check input Tensor(bool_) | |||
| ('AssignAdd', { | |||
| 'block': (AssignAddNet(), {'exception': TypeError, 'error_keywords': ['AssignAdd']}), | |||
| 'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)], | |||
| 'skip': ['backward']}), | |||
| # check input Tensor(bool_) | |||
| ('AssignSub', { | |||
| 'block': (AssignSubNet(), {'exception': TypeError, 'error_keywords': ['AssignSub']}), | |||
| 'desc_inputs': [Tensor(np.ones([1]).astype(np.bool_), mstype.bool_)], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceMean1', { | |||
| 'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceMean']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceMean2', { | |||
| 'block': (ReduceNet(P.ReduceMean, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceMean']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceSum1', { | |||
| 'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceSum']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceSum2', { | |||
| 'block': (ReduceNet(P.ReduceSum, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceSum']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceAll1', { | |||
| 'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceAll']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceAll2', { | |||
| 'block': (ReduceNet(P.ReduceAll, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceAll']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceMax1', { | |||
| 'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceMax']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceMax2', { | |||
| 'block': (ReduceNet(P.ReduceMax, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceMax']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceMin1', { | |||
| 'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceMin']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceMin2', { | |||
| 'block': (ReduceNet(P.ReduceMin, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceMin']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of axis is float, not int | |||
| ('ReduceProd1', { | |||
| 'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5.0), | |||
| {'exception': TypeError, 'error_keywords': ['ReduceProd']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # axis is out of range | |||
| ('ReduceProd2', { | |||
| 'block': (ReduceNet(P.ReduceProd, keep_dims=True, axis=5), | |||
| {'exception': ValueError, 'error_keywords': ['ReduceProd']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x is Tensor(bool) | |||
| ('CumProd1', { | |||
| 'block': (CumProdNet(), | |||
| {'exception': TypeError, 'error_keywords': ['CumProd']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool)), 1], | |||
| 'skip': ['backward']}), | |||
| # type of axis in float, not int | |||
| ('CumProd2', { | |||
| 'block': (CumProdNet(), | |||
| {'exception': TypeError, 'error_keywords': ['CumProd']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.float32)), 5.0], | |||
| 'skip': ['backward']}), | |||
| # type of x and y are Tensor(uint32) | |||
| ('MatMul1', { | |||
| 'block': (P.MatMul(), | |||
| {'exception': TypeError, 'error_keywords': ['MatMul']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.uint32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('MatMul2', { | |||
| 'block': (P.MatMul(), | |||
| {'exception': TypeError, 'error_keywords': ['MatMul']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('MatMul3', { | |||
| 'block': (P.MatMul(), | |||
| {'exception': ValueError, 'error_keywords': ['MatMul']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.float32)), Tensor(np.ones([2, 3]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # dims of x and y are less than 3 | |||
| ('BatchMatMul1', { | |||
| 'block': (P.BatchMatMul(), | |||
| {'exception': ValueError, 'error_keywords': ['BatchMatMul']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # type of x is Tensor(bool) | |||
| ('CumSum1', { | |||
| 'block': (CumSumNet(axis=1), | |||
| {'exception': TypeError, 'error_keywords': ['CumSum']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))], | |||
| 'skip': ['backward']}), | |||
| # type of axis in float, not int | |||
| ('CumSum2', { | |||
| 'block': (CumSumNet(axis=1.0), | |||
| {'exception': TypeError, 'error_keywords': ['CumSum']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3, 5]).astype(np.bool))], | |||
| 'skip': ['backward']}), | |||
| # intput is not tuple or list | |||
| ('AddN1', { | |||
| 'block': (P.AddN(), | |||
| {'exception': TypeError, 'error_keywords': ['AddN']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.uint32))], | |||
| 'skip': ['backward']}), | |||
| # type not match | |||
| ('AddN2', { | |||
| 'block': (P.AddN(), | |||
| {'exception': TypeError, 'error_keywords': ['AddN']}), | |||
| 'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.uint32)), Tensor(np.ones([3, 2]).astype(np.int32)))], | |||
| 'skip': ['backward']}), | |||
| # shape not match | |||
| ('AddN3', { | |||
| 'block': (P.AddN(), | |||
| {'exception': ValueError, 'error_keywords': ['AddN']}), | |||
| 'desc_inputs': [(Tensor(np.ones([2, 3]).astype(np.int32)), Tensor(np.ones([3, 2]).astype(np.int32)))], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Neg1', { | |||
| 'block': (P.Neg(), | |||
| {'exception': TypeError, 'error_keywords': ['Neg']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Sub0', { | |||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Sub1', { | |||
| 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Sub2', { | |||
| 'block': (P.Sub(), {'exception': ValueError, 'error_keywords': ['Sub']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Mul0', { | |||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Mul1', { | |||
| 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Mul2', { | |||
| 'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Square1', { | |||
| 'block': (P.Square(), | |||
| {'exception': TypeError, 'error_keywords': ['Square']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Rsqrt1', { | |||
| 'block': (P.Rsqrt(), | |||
| {'exception': TypeError, 'error_keywords': ['Rsqrt']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Sqrt1', { | |||
| 'block': (P.Sqrt(), | |||
| {'exception': TypeError, 'error_keywords': ['Sqrt']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not Tensor | |||
| ('Reciprocal1', { | |||
| 'block': (P.Reciprocal(), | |||
| {'exception': TypeError, 'error_keywords': ['Reciprocal']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input x is Tensor(bool) | |||
| ('Pow1', { | |||
| 'block': (P.Pow(), | |||
| {'exception': TypeError, 'error_keywords': ['Pow']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_)), 2.0], | |||
| 'skip': ['backward']}), | |||
| # input is not Tensor | |||
| ('Exp1', { | |||
| 'block': (P.Exp(), | |||
| {'exception': TypeError, 'error_keywords': ['Exp']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is not Tensor | |||
| ('Log1', { | |||
| 'block': (P.Log(), | |||
| {'exception': TypeError, 'error_keywords': ['Log']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Minimum0', { | |||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Minimum1', { | |||
| 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Minimum2', { | |||
| 'block': (P.Minimum(), {'exception': ValueError, 'error_keywords': ['Minimum']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Maximum0', { | |||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Maximum1', { | |||
| 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Maximum2', { | |||
| 'block': (P.Maximum(), {'exception': ValueError, 'error_keywords': ['Maximum']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('RealDiv0', { | |||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('RealDiv1', { | |||
| 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('RealDiv2', { | |||
| 'block': (P.RealDiv(), {'exception': ValueError, 'error_keywords': ['RealDiv']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Div0', { | |||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Div1', { | |||
| 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Div2', { | |||
| 'block': (P.Div(), {'exception': ValueError, 'error_keywords': ['Div']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('FloorDiv0', { | |||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('FloorDiv1', { | |||
| 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('FloorDiv2', { | |||
| 'block': (P.FloorDiv(), {'exception': ValueError, 'error_keywords': ['FloorDiv']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input x is Tensor(int32), not Tensor(float) | |||
| ('Floor1', { | |||
| 'block': (P.Floor(), | |||
| {'exception': TypeError, 'error_keywords': ['Floor']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('FloorMod0', { | |||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('FloorMod1', { | |||
| 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('FFloorMod2', { | |||
| 'block': (P.FloorMod(), {'exception': ValueError, 'error_keywords': ['FloorMod']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input x is Tensor(int32), not Tensor(float) | |||
| ('Acosh1', { | |||
| 'block': (P.Acosh(), | |||
| {'exception': TypeError, 'error_keywords': ['Acosh']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Equal0', { | |||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Equal1', { | |||
| 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('Equal2', { | |||
| 'block': (P.Equal(), {'exception': ValueError, 'error_keywords': ['Equal']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('EqualCount0', { | |||
| 'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('EqualCount1', { | |||
| 'block': (P.EqualCount(), {'exception': TypeError, 'error_keywords': ['EqualCount']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| # input is not tensor | |||
| ('NotEqual0', { | |||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('NotEqual1', { | |||
| 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('NotEqual2', { | |||
| 'block': (P.NotEqual(), {'exception': ValueError, 'error_keywords': ['NotEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Greater0', { | |||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Greater1', { | |||
| 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('Greater2', { | |||
| 'block': (P.Greater(), {'exception': ValueError, 'error_keywords': ['Greater']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('GreaterEqual0', { | |||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('GreaterEqual1', { | |||
| 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('GreaterEqual2', { | |||
| 'block': (P.GreaterEqual(), {'exception': ValueError, 'error_keywords': ['GreaterEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Less0', { | |||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('Less1', { | |||
| 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('Less2', { | |||
| 'block': (P.Less(), {'exception': ValueError, 'error_keywords': ['Less']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('LessEqual0', { | |||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('LessEqual1', { | |||
| 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('LessEqual2', { | |||
| 'block': (P.LessEqual(), {'exception': ValueError, 'error_keywords': ['LessEqual']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input x is not Tensor(bool) | |||
| ('LogicalNot1', { | |||
| 'block': (P.LogicalNot(), | |||
| {'exception': TypeError, 'error_keywords': ['LogicalNot']}), | |||
| 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('LogicalAnd1', { | |||
| 'block': (P.LogicalAnd(), {'exception': TypeError, 'error_keywords': ['LogicalAnd']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('LogicalAnd2', { | |||
| 'block': (P.LogicalAnd(), {'exception': ValueError, 'error_keywords': ['LogicalAnd']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # type of x and y not match | |||
| ('LogicalOr1', { | |||
| 'block': (P.LogicalOr(), {'exception': TypeError, 'error_keywords': ['LogicalOr']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # shape of x and y not match | |||
| ('LogicalOr2', { | |||
| 'block': (P.LogicalOr(), {'exception': ValueError, 'error_keywords': ['LogicalOr']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_)), Tensor(np.ones([3, 2]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('NPUGetFloatStatus0', { | |||
| 'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(int32), not Tensor(float32) | |||
| ('NPUGetFloatStatus1', { | |||
| 'block': (P.NPUGetFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUGetFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # dims is not 1 | |||
| ('NPUGetFloatStatus2', { | |||
| 'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape[0] is not 8 | |||
| ('NPUGetFloatStatus3', { | |||
| 'block': (P.NPUGetFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUGetFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('NPUClearFloatStatus0', { | |||
| 'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(int32), not Tensor(float32) | |||
| ('NPUClearFloatStatus1', { | |||
| 'block': (P.NPUClearFloatStatus(), {'exception': TypeError, 'error_keywords': ['NPUClearFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # dims is not 1 | |||
| ('NPUClearFloatStatus2', { | |||
| 'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape[0] is not 8 | |||
| ('NPUClearFloatStatus3', { | |||
| 'block': (P.NPUClearFloatStatus(), {'exception': ValueError, 'error_keywords': ['NPUClearFloatStatus']}), | |||
| 'desc_inputs': [Tensor(np.ones([3]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Cos0', { | |||
| 'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Cos1', { | |||
| 'block': (P.Cos(), {'exception': TypeError, 'error_keywords': ['Cos']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('ACos0', { | |||
| 'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('ACos1', { | |||
| 'block': (P.ACos(), {'exception': TypeError, 'error_keywords': ['ACos']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Sin0', { | |||
| 'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Sin1', { | |||
| 'block': (P.Sin(), {'exception': TypeError, 'error_keywords': ['Sin']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('NMSWithMask0', { | |||
| 'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is not Tensor(float16) or Tensor(float32) | |||
| ('NMSWithMask1', { | |||
| 'block': (P.NMSWithMask(), {'exception': TypeError, 'error_keywords': ['NMSWithMask']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], | |||
| 'skip': ['backward']}), | |||
| # dims is not 2 | |||
| ('NMSWithMask2', { | |||
| 'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # shape[1] is not 5 | |||
| ('NMSWithMask3', { | |||
| 'block': (P.NMSWithMask(), {'exception': ValueError, 'error_keywords': ['NMSWithMask']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 2]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Abs0', { | |||
| 'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Abs1', { | |||
| 'block': (P.Abs(), {'exception': TypeError, 'error_keywords': ['Abs']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Sign0', { | |||
| 'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Sign1', { | |||
| 'block': (P.Sign(), {'exception': TypeError, 'error_keywords': ['Sign']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # input is not tensor | |||
| ('Round0', { | |||
| 'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}), | |||
| 'desc_inputs': [5.0], | |||
| 'skip': ['backward']}), | |||
| # input is Tensor(bool) | |||
| ('Round1', { | |||
| 'block': (P.Round(), {'exception': TypeError, 'error_keywords': ['Round']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], | |||
| 'skip': ['backward']}), | |||
| # one input is scalar, and another is Tensor(float32) | |||
| ('Atan20', { | |||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||
| 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, but element types are not same | |||
| ('Atan21', { | |||
| 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| # input two tensors, their shapes do not match | |||
| ('Atan22', { | |||
| 'block': (P.Atan2(), {'exception': ValueError, 'error_keywords': ['Atan2']}), | |||
| 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) | |||
| def test_check_exception(): | |||
| return raise_set | |||