Browse Source

fix tuple args issue in pynative

tags/v0.5.0-beta
kingfo 5 years ago
parent
commit
557a404282
3 changed files with 26 additions and 2 deletions
  1. +21
    -0
      mindspore/ccsrc/pynative/pynative_execute.cc
  2. +3
    -0
      mindspore/common/tensor.py
  3. +2
    -2
      mindspore/ops/operations/array_ops.py

+ 21
- 0
mindspore/ccsrc/pynative/pynative_execute.cc View File

@@ -661,6 +661,20 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
// out = op(cell1(x, y))
// out = op(cell1(x, y)[0])
node = GetObjNode(obj);
} else if (py::isinstance<py::tuple>(obj)) {
// out = op((x, y))
// out = cell((x, y))
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));

auto tuple = obj.cast<py::tuple>();
auto tuple_size = static_cast<int>(tuple.size());
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], py::object()));
}
auto cnode = curr_g_->NewCNode(args);
set_obj_node_map(curr_g_, GetId(obj), cnode);
node = cnode;
} else {
// out = op(x, 1)
ValuePtr converted_ret = nullptr;
@@ -728,6 +742,13 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
auto out_cnode = curr_g_->NewCNode(inputs);
set_pyobj(curr_g_, GetId(cell));
if (py::isinstance<py::tuple>(out)) {
auto out_list = py::cast<py::tuple>(out);
auto out_size = static_cast<int>(out_list.size());
for (int i = 0; i < out_size; i++) {
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
}
}
set_obj_node_map(curr_g_, GetId(out), out_cnode);
} else {
parse::ResolveFuncGraph(newfg, resource_);


+ 3
- 0
mindspore/common/tensor.py View File

@@ -19,6 +19,7 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor
from .._checkparam import check_type, check_typename
from . import dtype as mstype
from .. import context
from ._register_for_tensor import tensor_operator_registry

__all__ = ['Tensor', 'MetaTensor']
@@ -77,6 +78,8 @@ class Tensor(Tensor_):
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_:
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)

def __ne__(self, other):


+ 2
- 2
mindspore/ops/operations/array_ops.py View File

@@ -145,8 +145,8 @@ class SameTypeShape(PrimitiveWithInfer):

def __call__(self, x, y):
"""run in PyNative mode"""
validator.check_subclass('x', x.dtype(), mstype.tensor, self.name)
validator.check_subclass('y', y.dtype(), mstype.tensor, self.name)
validator.check_value_type("x", x, Tensor, self.name)
validator.check_value_type("y", y, Tensor, self.name)
validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
return x


Loading…
Cancel
Save