|
|
|
@@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_ |
|
|
|
from ..._c_expression import signature_rw as sig_rw |
|
|
|
from ..._c_expression import signature_kind as sig_kind |
|
|
|
from ..._c_expression import signature_dtype as sig_dtype |
|
|
|
from ..._c_expression import typing |
|
|
|
|
|
|
|
def _check_infer_attr_reduce(axis, keep_dims, prim_name): |
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) |
|
|
|
@@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer): |
|
|
|
data = x.default_input |
|
|
|
if data.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
return (False, None) |
|
|
|
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") |
|
|
|
return (False, None) |
|
|
|
|
|
|
|
def __infer__(self, x, t): |
|
|
|
src_type = x['dtype'] |
|
|
|
@@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def check_elim(self, base_tensor, multiplier): |
|
|
|
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): |
|
|
|
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) |
|
|
|
def is_all_zeros(v_tuple): |
|
|
|
return all(v == 1 for v in v_tuple) |
|
|
|
if is_all_zeros(multiplier): |
|
|
|
raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) |
|
|
|
if all(v == 1 for v in multiplier): |
|
|
|
return (True, base_tensor) |
|
|
|
return (False, None) |
|
|
|
|
|
|
|
@@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer): |
|
|
|
validator.check_value_type("shape", multiples_v, [tuple], self.name) |
|
|
|
for i, multiple in enumerate(multiples_v): |
|
|
|
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) |
|
|
|
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32] |
|
|
|
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name) |
|
|
|
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name) |
|
|
|
len_sub = len(multiples_v) - len(x_shp) |
|
|
|
multiples_w = None |
|
|
|
if len_sub == 0: |
|
|
|
|