|
|
|
@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
 
     def infer_shape(self, x_shape, y_shape):
-        return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
+        return _get_broadcast_shape(x_shape, y_shape, self.name)
 
 
 class _MathBinaryOp(_BinaryOp):
@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
         return x_dtype
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
 
 
 class TensorAdd(_MathBinaryOp):
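Note (editorial, not part of the patch): the change above is the pattern repeated throughout this diff. Infer methods now pass self.name, the primitive's registered name used in validator error messages, where they previously called prim_name(). A minimal sketch of a primitive written against the post-patch convention follows; the class is hypothetical and the import paths are assumptions modelled on this file's own imports.

# Illustrative sketch only; IdentityExample is hypothetical and the import
# paths are assumptions, not part of the diff.
from mindspore.ops import PrimitiveWithInfer, prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype


class IdentityExample(PrimitiveWithInfer):
    """Passes its input through; shows where self.name is now used."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        # self.name (rather than the old prim_name()) identifies the op in error messages
        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
        return x_dtype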
|
|
|
@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
 
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
 
 
@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
 
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
 
 
@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, keep_dims=False):
         """init Reduce"""
-        validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
+        validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
         self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
 
     def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
         axis_v = axis['value']
         input_shp = input_x['shape']
         args = {'input_x': input_x['dtype']}
-        validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
+        validator.check_tensor_type_same(args, valid_dtype, self.name)
 
-        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
+        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
                 'value': None}
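Side note (not part of the patch): do_infer returns the standard infer dict of shape, dtype and value, and keep_dims controls whether the reduced axis survives as a size-1 dimension in the output shape. A quick NumPy sketch of the same shape rule, for orientation only:

import numpy as np

x = np.ones((2, 3, 4))
print(np.sum(x, axis=1).shape)                 # (2, 4)    -> keep_dims=False drops axis 1
print(np.sum(x, axis=1, keepdims=True).shape)  # (2, 1, 4) -> keep_dims=True keeps it as size 1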
|
|
|
@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
     """
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
-        cls_name = self.prim_name()
+        cls_name = self.name
         self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
         self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
 
@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, axis_type):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type
@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
 
@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
 
     def infer_shape(self, x, y):
         self.check_shape_size(x, y)
-        cls_name = self.prim_name()
+        cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
         for i in range(len(x) - 2):
             if x[i] != y[i]:
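Side note (not part of the patch): the loop above enforces the rule stated in the comment. Every dimension except the last two must match, and the last two follow the usual matrix-multiplication rule. A NumPy sketch of the batched case, for orientation only:

import numpy as np

a = np.ones((5, 2, 3))        # five 2x3 matrices
b = np.ones((5, 3, 4))        # leading batch dim matches; inner dims 3 == 3
print(np.matmul(a, b).shape)  # (5, 2, 4)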
|
|
|
@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
 
     def infer_dtype(self, x, y):
         args = {"x": x, "y": y}
-        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
+        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
         return x
 
 
@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
 
@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
         """init cumsum"""
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type('exclusive', exclusive, [bool], cls_name)
         validator.check_value_type('reverse', reverse, [bool], cls_name)
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
 
     def __infer__(self, x, axis):
-        cls_name = self.prim_name()
+        cls_name = self.name
         x_shp = x['shape']
         validator.check_value_type('axis', axis['value'], [int], cls_name)
         valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
 
     def infer_shape(self, inputs):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         self.add_prim_attr('n', len(inputs))
         shp0 = inputs[0]
@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
         return shp0
 
     def infer_dtype(self, inputs):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
         validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         args = {}
@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
         return input_x
 
 
@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
 
 
@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
 
 
@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
 
 
@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x, mstype.tensor, self.name)
         return x
 
 
-class Pow(PrimitiveWithInfer):
+class Pow(_MathBinaryOp):
     """
     Computes a tensor to the power of the second input.
 
+    The first input must be a tensor, and the second input should be a tensor or a number.
+    When the inputs are two tensors, the shapes of them could be broadcast,
+    and the data types of them should be the same.
+    When the inputs are one tensor and one scalar, the scalar could not be a parameter,
+    only could be a constant, and the type of the scalar is the same as the data type of the tensor.
+
+    Inputs:
+        - **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
+        - **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
+          a number.
+
+    Outputs:
+        Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
+
+
     Inputs:
         - **input_x** (Tensor) - The input tensor.
         - **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
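Usage sketch for the behaviour documented above (editorial, not part of the patch). It assumes PyNative mode and the mindspore.ops.operations import path, and the printed values are computed by hand rather than taken from the patch:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

pow_op = P.Pow()

x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
print(pow_op(x, 3.0))   # tensor ** constant scalar -> [1. 8. 64.]

y = Tensor(np.array([2.0]), mindspore.float32)  # shape (1,) broadcasts against (3,)
print(pow_op(x, y))     # -> [1. 4. 16.]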
|
|
|
@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
         [1.0, 16.0, 64.0]
     """
 
-    @prim_attr_register
-    def __init__(self):
-        """init Multiply"""
-
-    def infer_shape(self, x, power):
-        return x
-
-    def infer_dtype(self, x, power):
-        validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
-        return x
-
 
 
 class Exp(PrimitiveWithInfer):
@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x_type, mstype.tensor, self.name)
         return x_type
 
 
@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x, mstype.tensor, self.name)
         return x
 
 
@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
         return x_dtype
 
 
@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
 
 
@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
         return mstype.tensor_type(mstype.bool_)
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
 
 
 class Equal(_LogicBinaryOp):
@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
 
 
 class EqualCount(PrimitiveWithInfer):
@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype, y_dtype):
         args = {'x': x_dtype, 'y': y_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
         return x_dtype
 
 
@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
 
 
 class Greater(_LogicBinaryOp):
@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
+        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
         return mstype.tensor_type(mstype.bool_)
 
 
@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
 
 
 class LogicalOr(_LogicBinaryOp):
@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
 
 class IsNan(PrimitiveWithInfer):
     """
@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)
 
     def infer_shape(self, x_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
         validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
         return mstype.float32
 
 
@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)
 
     def infer_shape(self, x_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
         validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
         return mstype.float32
 
 
@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
 
 
@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
 
 
@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
 
 
@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, iou_threshold=0.5):
         """Init NMSWithMask"""
-        validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
+        validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
         self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
 
     def infer_shape(self, bboxes_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
         validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
         return (bboxes_shape, (num,), (num,))
 
     def infer_dtype(self, bboxes_dtype):
-        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
         return (bboxes_dtype, mstype.int32, mstype.bool_)
 
 
@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
         return x_type
 
     def infer_value(self, x):
@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
         return x_dtype
 
 
@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
         return x_type