| @@ -23,7 +23,7 @@ from .. import signature as sig | |||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | |||||
| from ...common.tensor import Tensor, MetaTensor | |||||
| from .._utils import get_broadcast_shape | from .._utils import get_broadcast_shape | ||||
| from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op | ||||
| @@ -2324,9 +2324,13 @@ class Equal(_LogicBinaryOp): | |||||
| return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name) | return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name) | ||||
| def infer_value(self, x, y): | def infer_value(self, x, y): | ||||
| if x is not None and y is not None: | |||||
| return Tensor(x.asnumpy() == y.asnumpy()) | |||||
| return None | |||||
| if x is None or y is None: | |||||
| return None | |||||
| if isinstance(x, MetaTensor): | |||||
| x = x.to_tensor() | |||||
| if isinstance(y, MetaTensor): | |||||
| y = y.to_tensor() | |||||
| return Tensor(x.asnumpy() == y.asnumpy()) | |||||
| class ApproximateEqual(_LogicBinaryOp): | class ApproximateEqual(_LogicBinaryOp): | ||||