| @@ -62,6 +62,16 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||||
| return kernel_with_index; | return kernel_with_index; | ||||
| } | } | ||||
| void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, | |||||
| const size_t input_index) { | |||||
| // record the ref_pair | |||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| session::AnfWithOutIndex final_pair = std::make_pair(cnode, output_index); | |||||
| session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0); | |||||
| kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); | |||||
| } | |||||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | ||||
| const AnfNodePtr &final_node, size_t final_index, | const AnfNodePtr &final_node, size_t final_index, | ||||
| const session::KernelWithIndex &origin_pair) { | const session::KernelWithIndex &origin_pair) { | ||||
| @@ -88,6 +98,7 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||||
| AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | ||||
| size_t input_index, const AnfNodePtr &get_item) { | size_t input_index, const AnfNodePtr &get_item) { | ||||
| AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | ||||
| bool need_refresh_ref_addr = false; | |||||
| size_t final_index = output_index; | size_t final_index = output_index; | ||||
| AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | ||||
| session::KernelWithIndex origin_pair; | session::KernelWithIndex origin_pair; | ||||
| @@ -109,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | ||||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); | RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); | ||||
| final_index = 0; | final_index = 0; | ||||
| need_refresh_ref_addr = true; | |||||
| MS_EXCEPTION_IF_NULL(final_node); | MS_EXCEPTION_IF_NULL(final_node); | ||||
| MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); | ||||
| } | } | ||||
| @@ -119,15 +131,19 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| MS_EXCEPTION_IF_NULL(final_node); | MS_EXCEPTION_IF_NULL(final_node); | ||||
| final_node->set_scope(cnode->scope()); | final_node->set_scope(cnode->scope()); | ||||
| final_index = 0; | final_index = 0; | ||||
| need_refresh_ref_addr = true; | |||||
| MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); | MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); | ||||
| } | } | ||||
| // add ref pair | // add ref pair | ||||
| AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); | AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); | ||||
| if (need_refresh_ref_addr) { | |||||
| AddRefNodePairToKernelGraph(func_graph, cnode, output_index, input_index); | |||||
| } | |||||
| // insert depend | // insert depend | ||||
| if (origin_format != cur_format || origin_type != cur_type) { | if (origin_format != cur_format || origin_type != cur_type) { | ||||
| std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; | std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; | ||||
| final_node = func_graph->NewCNode(depend_nodes); | final_node = func_graph->NewCNode(depend_nodes); | ||||
| MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); | |||||
| MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString(); | |||||
| } | } | ||||
| return final_node; | return final_node; | ||||