Browse Source

Support MetaTensor in Equal's infer_value

tags/v1.1.0
huanghui 5 years ago
parent
commit
a423320d58
1 changed files with 8 additions and 4 deletions
  1. +8
    -4
      mindspore/ops/operations/math_ops.py

+ 8
- 4
mindspore/ops/operations/math_ops.py View File

@@ -23,7 +23,7 @@ from .. import signature as sig
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ...common.tensor import Tensor, MetaTensor
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op

@@ -2324,9 +2324,13 @@ class Equal(_LogicBinaryOp):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)

def infer_value(self, x, y):
if x is not None and y is not None:
return Tensor(x.asnumpy() == y.asnumpy())
return None
if x is None or y is None:
return None
if isinstance(x, MetaTensor):
x = x.to_tensor()
if isinstance(y, MetaTensor):
y = y.to_tensor()
return Tensor(x.asnumpy() == y.asnumpy())


class ApproximateEqual(_LogicBinaryOp):


Loading…
Cancel
Save