| @@ -31,8 +31,9 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||
| session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); | |||
| AnfNodePtr cur_node = kernel_with_index.first; | |||
| size_t cur_out_index = kernel_with_index.second; | |||
| MS_EXCEPTION_IF_NULL(cur_node); | |||
| if (cur_node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto cnode = cur_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::string op_name = AnfAlgo::GetCNodeName(cnode); | |||
| auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); | |||
| @@ -88,7 +89,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| size_t input_index, const AnfNodePtr &get_item) { | |||
| AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||
| size_t final_index = output_index; | |||
| AnfNodePtr input_node = cnode->input(input_index + 1); | |||
| AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | |||
| session::KernelWithIndex origin_pair; | |||
| origin_pair = FindRefOriginNode(input_node); | |||
| MS_EXCEPTION_IF_NULL(origin_pair.first); | |||
| @@ -133,6 +134,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| } | |||
| AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| auto ref_infos = op_info->ref_infos(); | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| AbstractBasePtrList abstract_list; | |||
| @@ -144,9 +146,11 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| auto input_index = ref_infos.at(output_index); | |||
| final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(final_node); | |||
| abstract_list.push_back(final_node->abstract()); | |||
| make_tuple_inputs.push_back(final_node); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| @@ -155,6 +159,8 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| auto ref_infos = op_info->ref_infos(); | |||
| for (const auto &ref_info : ref_infos) { | |||
| if (ref_info.second > cnode->inputs().size()) { | |||
| @@ -206,7 +212,9 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A | |||
| return nullptr; | |||
| } | |||
| if (op_info->is_ref()) { | |||
| if (!cnode->Type()->isa<Tuple>()) { | |||
| auto type = cnode->Type(); | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| if (!type->isa<Tuple>()) { | |||
| return DealRefSigleOutput(graph, cnode, op_info); | |||
| } else { | |||
| return DealRefForMultipleOutput(graph, cnode, op_info); | |||