From a423320d58d6b33beb3ab891a5079da142dba9de Mon Sep 17 00:00:00 2001 From: huanghui Date: Tue, 17 Nov 2020 15:44:06 +0800 Subject: [PATCH] Support MetaTensor in Equal's infer_value --- mindspore/ops/operations/math_ops.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 27fd2f4a45..9eb09c2ef3 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -23,7 +23,7 @@ from .. import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype -from ...common.tensor import Tensor +from ...common.tensor import Tensor, MetaTensor from .._utils import get_broadcast_shape from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op @@ -2324,9 +2324,13 @@ class Equal(_LogicBinaryOp): return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name) def infer_value(self, x, y): - if x is not None and y is not None: - return Tensor(x.asnumpy() == y.asnumpy()) - return None + if x is None or y is None: + return None + if isinstance(x, MetaTensor): + x = x.to_tensor() + if isinstance(y, MetaTensor): + y = y.to_tensor() + return Tensor(x.asnumpy() == y.asnumpy()) class ApproximateEqual(_LogicBinaryOp):