|
|
@@ -1,5 +1,5 @@ |
|
|
/** |
|
|
/** |
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
|
|
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
* |
|
|
* |
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
* you may not use this file except in compliance with the License. |
|
|
* you may not use this file except in compliance with the License. |
|
|
@@ -34,12 +34,12 @@ namespace mindspore { |
|
|
namespace opt { |
|
|
namespace opt { |
|
|
namespace { |
|
|
namespace { |
|
|
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { |
|
|
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { |
|
|
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph()); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
|
|
|
|
|
auto func_graph = anf_node->func_graph(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info())); |
|
|
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info())); |
|
|
CNodePtr node = kernel_graph->NewCNode(cnode->inputs()); |
|
|
|
|
|
|
|
|
CNodePtr node = func_graph->NewCNode(cnode->inputs()); |
|
|
node->set_abstract(cnode->abstract()); |
|
|
node->set_abstract(cnode->abstract()); |
|
|
node->set_forward(cnode->forward().first, cnode->forward().second); |
|
|
node->set_forward(cnode->forward().first, cnode->forward().second); |
|
|
node->set_inputs_value(cnode->inputs_value()); |
|
|
node->set_inputs_value(cnode->inputs_value()); |
|
|
@@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) { |
|
|
changed = true; |
|
|
changed = true; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
mng->RemoveRoots(); |
|
|
|
|
|
mng->KeepRoots({func_graph}); |
|
|
|
|
|
|
|
|
if (changed) { |
|
|
|
|
|
mng->RemoveRoots(); |
|
|
|
|
|
mng->KeepRoots({func_graph}); |
|
|
|
|
|
} |
|
|
return changed; |
|
|
return changed; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { |
|
|
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
|
|
auto mng = func_graph->manager(); |
|
|
|
|
|
if (mng == nullptr) { |
|
|
|
|
|
mng = Manage(func_graph, true); |
|
|
|
|
|
func_graph->set_manager(mng); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto todos = TopoSort(func_graph->get_return()); |
|
|
bool result = false; |
|
|
bool result = false; |
|
|
bool changed; |
|
|
|
|
|
do { |
|
|
|
|
|
changed = Process(func_graph); |
|
|
|
|
|
result |= changed; |
|
|
|
|
|
} while (changed); |
|
|
|
|
|
|
|
|
for (const auto &anf_node : todos) { |
|
|
|
|
|
if (AnfAlgo::IsGraphKernel(anf_node)) { |
|
|
|
|
|
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node); |
|
|
|
|
|
bool changed = false; |
|
|
|
|
|
do { |
|
|
|
|
|
changed = Process(sub_graph); |
|
|
|
|
|
result = result || changed; |
|
|
|
|
|
} while (changed); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (result) { |
|
|
|
|
|
mng->RemoveRoots(); |
|
|
|
|
|
mng->KeepRoots({func_graph}); |
|
|
|
|
|
} |
|
|
return result; |
|
|
return result; |
|
|
} |
|
|
} |
|
|
} // namespace opt |
|
|
} // namespace opt |
|
|
|