|
|
|
@@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar |
|
|
|
// args: tuple of items, index |
|
|
|
const std::string op_name = std::string("TupleGetItemTensor"); |
|
|
|
abstract::CheckArgsSize(op_name, args_spec_list, 2); |
|
|
|
AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0); |
|
|
|
AbstractBasePtrList branches = branches_abs->elements(); |
|
|
|
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { |
|
|
|
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); |
|
|
|
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
AnfNodePtr functions = ret_graph->add_parameter(); |
|
|
|
auto index = ret_graph->add_parameter(); |
|
|
|
auto ret_graph = std::make_shared<FuncGraph>(); |
|
|
|
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); |
|
|
|
auto functions = ret_graph->add_parameter(); |
|
|
|
auto index = ret_graph->add_parameter(); |
|
|
|
|
|
|
|
ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); |
|
|
|
return ret_graph; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; |
|
|
|
ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); |
|
|
|
return ret_graph; |
|
|
|
} |
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { |
|
|
|
|