| @@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_ | |||||
| from ..._c_expression import signature_rw as sig_rw | from ..._c_expression import signature_rw as sig_rw | ||||
| from ..._c_expression import signature_kind as sig_kind | from ..._c_expression import signature_kind as sig_kind | ||||
| from ..._c_expression import signature_dtype as sig_dtype | from ..._c_expression import signature_dtype as sig_dtype | ||||
| from ..._c_expression import typing | |||||
| def _check_infer_attr_reduce(axis, keep_dims, prim_name): | def _check_infer_attr_reduce(axis, keep_dims, prim_name): | ||||
| validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) | validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) | ||||
| @@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer): | |||||
| data = x.default_input | data = x.default_input | ||||
| if data.dtype == dtype: | if data.dtype == dtype: | ||||
| return (True, x) | return (True, x) | ||||
| return (False, None) | |||||
| raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") | |||||
| return (False, None) | |||||
| def __infer__(self, x, t): | def __infer__(self, x, t): | ||||
| src_type = x['dtype'] | src_type = x['dtype'] | ||||
| @@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer): | |||||
| def check_elim(self, base_tensor, multiplier): | def check_elim(self, base_tensor, multiplier): | ||||
| if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): | if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): | ||||
| raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) | |||||
| def is_all_zeros(v_tuple): | |||||
| return all(v == 1 for v in v_tuple) | |||||
| if is_all_zeros(multiplier): | |||||
| raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) | |||||
| if all(v == 1 for v in multiplier): | |||||
| return (True, base_tensor) | return (True, base_tensor) | ||||
| return (False, None) | return (False, None) | ||||
| @@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer): | |||||
| validator.check_value_type("shape", multiples_v, [tuple], self.name) | validator.check_value_type("shape", multiples_v, [tuple], self.name) | ||||
| for i, multiple in enumerate(multiples_v): | for i, multiple in enumerate(multiples_v): | ||||
| validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) | validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) | ||||
| valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32] | |||||
| validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name) | |||||
| validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name) | |||||
| len_sub = len(multiples_v) - len(x_shp) | len_sub = len(multiples_v) - len(x_shp) | ||||
| multiples_w = None | multiples_w = None | ||||
| if len_sub == 0: | if len_sub == 0: | ||||