diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc index 4323ac582c..f669792d12 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,12 +34,12 @@ namespace mindspore { namespace opt { namespace { AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { - auto kernel_graph = std::dynamic_pointer_cast(anf_node->func_graph()); - MS_EXCEPTION_IF_NULL(kernel_graph); + auto func_graph = anf_node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); auto cnode = anf_node->cast(); MS_EXCEPTION_IF_NULL(cnode); TraceGuard guard(std::make_shared(cnode->debug_info())); - CNodePtr node = kernel_graph->NewCNode(cnode->inputs()); + CNodePtr node = func_graph->NewCNode(cnode->inputs()); node->set_abstract(cnode->abstract()); node->set_forward(cnode->forward().first, cnode->forward().second); node->set_inputs_value(cnode->inputs_value()); @@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { changed = true; } } - - mng->RemoveRoots(); - mng->KeepRoots({func_graph}); + if (changed) { + mng->RemoveRoots(); + mng->KeepRoots({func_graph}); + } return changed; } bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + + auto todos = TopoSort(func_graph->get_return()); bool result = false; - bool changed; - do { - changed = Process(func_graph); - result |= changed; - } while (changed); + for (const auto &anf_node : todos) { + if (AnfAlgo::IsGraphKernel(anf_node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node); + bool changed = false; + do { + changed = Process(sub_graph); + result = result || changed; + } while (changed); + } + } + + if (result) { + mng->RemoveRoots(); + mng->KeepRoots({func_graph}); + } return result; } } // namespace opt diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index d843e00e96..236c70f197 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -185,14 +185,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ std::vector duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; pm->AddPass(std::make_shared()); // Make more fusion opportunity. pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared(duplicated_ops)); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared(duplicated_ops)); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared(duplicated_ops)); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); // The CSE may output a graph with repeated outputs.