diff --git a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc index 9d7eec3233..21157c9224 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc @@ -75,20 +75,23 @@ void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector< group++; } -void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector &inplace_nodes) { +void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector &inplace_nodes, + const AnfNodePtr aggregate_node) { std::vector inputs1 = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), inplace_nodes[0].node, inplace_nodes[1].node}; auto control_depend_node = graph->NewCNode(inputs1); - auto return_node = graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - // mount the `depend` before make_tuple, otherwise the output of graph will be `(tensor, )` rather than `tensor` - auto return_input = return_node->input(kFirstDataInputIndex)->cast(); - MS_EXCEPTION_IF_NULL(return_input); std::vector inputs2 = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_input->input(kFirstDataInputIndex), control_depend_node}; + aggregate_node, control_depend_node}; auto depend_node = graph->NewCNode(inputs2); - return_node->set_input(kFirstDataInputIndex, depend_node); + + auto users = GetRealNodeUsedList(graph, aggregate_node); + if (users->size() == 0) { + MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString(); + } + auto mount_node = users->at(0).first->cast(); + MS_EXCEPTION_IF_NULL(mount_node); + mount_node->set_input(kFirstDataInputIndex, depend_node); } bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node, @@ -186,7 +189,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) { // 2. Set Node attr SetNodeAttr(aggregate_node, skip_node, &inplace_node); // 3. Set dependence for inplace nodes - InsertControlDependToGraph(graph, inplace_node); + InsertControlDependToGraph(graph, inplace_node, aggregate_node.node); } return true; diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 38091cf110..de3f788a96 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -108,6 +108,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_gra pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared("reduce_precision"));