|
|
|
@@ -661,6 +661,20 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o |
|
|
|
// out = op(cell1(x, y)) |
|
|
|
// out = op(cell1(x, y)[0]) |
|
|
|
node = GetObjNode(obj); |
|
|
|
} else if (py::isinstance<py::tuple>(obj)) { |
|
|
|
// out = op((x, y)) |
|
|
|
// out = cell((x, y)) |
|
|
|
std::vector<AnfNodePtr> args; |
|
|
|
args.push_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
|
|
|
|
auto tuple = obj.cast<py::tuple>(); |
|
|
|
auto tuple_size = static_cast<int>(tuple.size()); |
|
|
|
for (int i = 0; i < tuple_size; i++) { |
|
|
|
args.push_back(GetInput(tuple[i], py::object())); |
|
|
|
} |
|
|
|
auto cnode = curr_g_->NewCNode(args); |
|
|
|
set_obj_node_map(curr_g_, GetId(obj), cnode); |
|
|
|
node = cnode; |
|
|
|
} else { |
|
|
|
// out = op(x, 1) |
|
|
|
ValuePtr converted_ret = nullptr; |
|
|
|
@@ -728,6 +742,13 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c |
|
|
|
} |
|
|
|
auto out_cnode = curr_g_->NewCNode(inputs); |
|
|
|
set_pyobj(curr_g_, GetId(cell)); |
|
|
|
if (py::isinstance<py::tuple>(out)) { |
|
|
|
auto out_list = py::cast<py::tuple>(out); |
|
|
|
auto out_size = static_cast<int>(out_list.size()); |
|
|
|
for (int i = 0; i < out_size; i++) { |
|
|
|
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); |
|
|
|
} |
|
|
|
} |
|
|
|
set_obj_node_map(curr_g_, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
parse::ResolveFuncGraph(newfg, resource_); |
|
|
|
|