Browse Source

!11622 【GraphKernel】Moved ShapeOpsSplitter before GraphKernelSplitter

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @gaoxiong1
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
9efbef72fc
2 changed files with 35 additions and 16 deletions
  1. +31
    -12
      mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
  2. +4
    -4
      mindspore/ccsrc/backend/session/gpu_session.cc

+ 31
- 12
mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -34,12 +34,12 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
auto func_graph = anf_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info())); TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
CNodePtr node = func_graph->NewCNode(cnode->inputs());
node->set_abstract(cnode->abstract()); node->set_abstract(cnode->abstract());
node->set_forward(cnode->forward().first, cnode->forward().second); node->set_forward(cnode->forward().first, cnode->forward().second);
node->set_inputs_value(cnode->inputs_value()); node->set_inputs_value(cnode->inputs_value());
@@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
changed = true; changed = true;
} }
} }

mng->RemoveRoots();
mng->KeepRoots({func_graph});
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed; return changed;
} }


bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}

auto todos = TopoSort(func_graph->get_return());
bool result = false; bool result = false;
bool changed;
do {
changed = Process(func_graph);
result |= changed;
} while (changed);
for (const auto &anf_node : todos) {
if (AnfAlgo::IsGraphKernel(anf_node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
bool changed = false;
do {
changed = Process(sub_graph);
result = result || changed;
} while (changed);
}
}

if (result) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return result; return result;
} }
} // namespace opt } // namespace opt


+ 4
- 4
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -185,14 +185,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
pm->AddPass(std::make_shared<opt::DependFormater>()); // Make more fusion opportunity. pm->AddPass(std::make_shared<opt::DependFormater>()); // Make more fusion opportunity.
pm->AddPass(std::make_shared<opt::GraphKernelExpander>()); pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::BasicOpsFusion>()); pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>()); pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>()); pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>()); pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::TensorPromotion>()); pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>()); pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>()); pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
// The CSE may output a graph with repeated outputs. // The CSE may output a graph with repeated outputs.


Loading…
Cancel
Save