|
|
|
@@ -27,6 +27,69 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
constexpr auto kSingleInputIndex = 1; |
|
|
|
namespace { |
|
|
|
AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
string op_name = AnfAlgo::GetCNodeName(cnode); |
|
|
|
// Currently we only eliminate transdata or cast nodes. |
|
|
|
if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
// Check whether the node has only one output node. |
|
|
|
if (manager->node_users().find(cnode) == manager->node_users().end()) { |
|
|
|
MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; |
|
|
|
} |
|
|
|
if (manager->node_users()[cnode].size() > 1) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
CheckCNodeInputSize(cnode, kSingleInputIndex + 1); |
|
|
|
return cnode->input(kSingleInputIndex); |
|
|
|
} |
|
|
|
|
|
|
|
bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> new_make_tuple_inputs; |
|
|
|
bool need_update = false; |
|
|
|
for (const auto &input : cnode->inputs()) { |
|
|
|
AnfNodePtr replace_input = GetReplaceNode(func_graph, input); |
|
|
|
// If replace input is not null, it will be the input of the TransData or Cast. |
|
|
|
if (replace_input == nullptr) { |
|
|
|
new_make_tuple_inputs.push_back(input); |
|
|
|
continue; |
|
|
|
} |
|
|
|
new_make_tuple_inputs.push_back(replace_input); |
|
|
|
need_update = true; |
|
|
|
} |
|
|
|
if (need_update) { |
|
|
|
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); |
|
|
|
CNodePtr new_make_tuple = nullptr; |
|
|
|
if (kernel_graph == nullptr) { |
|
|
|
new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs); |
|
|
|
} else { |
|
|
|
new_make_tuple = kernel_graph->NewCNode(cnode); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(new_make_tuple); |
|
|
|
new_make_tuple->set_inputs(new_make_tuple_inputs); |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
manager->Replace(cnode, new_make_tuple); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef OptimizeDependence::DefinePattern() const { |
|
|
|
VarPtr X = std::make_shared<Var>("X"); |
|
|
|
MS_EXCEPTION_IF_NULL(X); |
|
|
|
@@ -43,9 +106,8 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto depend_cnode = node->cast<CNodePtr>(); |
|
|
|
if (depend_cnode->inputs().size() < kDependInputNum) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(depend_cnode); |
|
|
|
CheckCNodeInputSize(depend_cnode, kDependInputNum); |
|
|
|
auto replacing_node = depend_cnode->input(kDependInputNum - 1); |
|
|
|
MS_EXCEPTION_IF_NULL(replacing_node); |
|
|
|
if (!replacing_node->isa<CNode>()) { |
|
|
|
@@ -53,36 +115,29 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con |
|
|
|
} |
|
|
|
auto replacing_cnode = replacing_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(replacing_cnode); |
|
|
|
// Currently we only optimize transdata or cast nodes. |
|
|
|
string replacing_cnode_op_name = AnfAlgo::GetCNodeName(replacing_cnode); |
|
|
|
if (replacing_cnode_op_name != kTransDataOpName && replacing_cnode_op_name != prim::kPrimCast->name()) { |
|
|
|
// Deal with the make_tuple with TransData or Cast inputs. |
|
|
|
if (ReplaceMakeTuple(func_graph, replacing_cnode)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
// Check whether the replacing node has only one input and one output. |
|
|
|
if (replacing_cnode->inputs().size() != kSingleInputIndex + 1) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (manager->node_users().find(replacing_node) == manager->node_users().end()) { |
|
|
|
MS_LOG(EXCEPTION) << "The node should be used by at least another node input"; |
|
|
|
} |
|
|
|
if (manager->node_users()[replacing_node].size() > 1) { |
|
|
|
AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); |
|
|
|
if (replace_node == nullptr) { |
|
|
|
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex), |
|
|
|
depend_cnode->input(kRealInputIndexInDepend), |
|
|
|
replacing_cnode->input(kSingleInputIndex)}; |
|
|
|
depend_cnode->input(kRealInputIndexInDepend), replace_node}; |
|
|
|
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); |
|
|
|
CNodePtr new_depend; |
|
|
|
if (kernel_graph == nullptr) { |
|
|
|
new_depend = func_graph->NewCNode(new_depend_inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(new_depend); |
|
|
|
new_depend->set_abstract(node->abstract()); |
|
|
|
new_depend->set_scope(node->scope()); |
|
|
|
} else { |
|
|
|
new_depend = kernel_graph->NewCNode(depend_cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(new_depend); |
|
|
|
new_depend->set_inputs(new_depend_inputs); |
|
|
|
} |
|
|
|
new_depend->set_abstract(node->abstract()); |
|
|
|
return new_depend; |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
|