|
|
@@ -225,7 +225,7 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, |
|
|
|
|
|
|
|
|
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of |
|
|
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of |
|
|
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] |
|
|
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] |
|
|
static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { |
|
|
|
|
|
|
|
|
static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) { |
|
|
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); |
|
|
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); |
|
|
for (auto &call_node : call_nodes) { |
|
|
for (auto &call_node : call_nodes) { |
|
|
MS_EXCEPTION_IF_NULL(call_node); |
|
|
MS_EXCEPTION_IF_NULL(call_node); |
|
|
@@ -236,7 +236,9 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { |
|
|
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); |
|
|
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); |
|
|
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); |
|
|
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); |
|
|
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get()); |
|
|
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get()); |
|
|
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); |
|
|
|
|
|
|
|
|
if (split_flag) { |
|
|
|
|
|
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); |
|
|
|
|
|
} |
|
|
} else if (child_graphs.size() == 2) { |
|
|
} else if (child_graphs.size() == 2) { |
|
|
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { |
|
|
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { |
|
|
auto switch_node = call_node->input(1); |
|
|
auto switch_node = call_node->input(1); |
|
|
@@ -248,8 +250,10 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) { |
|
|
auto partial_cnode = partial->cast<CNodePtr>(); |
|
|
auto partial_cnode = partial->cast<CNodePtr>(); |
|
|
MS_EXCEPTION_IF_NULL(partial_cnode); |
|
|
MS_EXCEPTION_IF_NULL(partial_cnode); |
|
|
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); |
|
|
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); |
|
|
partial_cnode->set_inputs( |
|
|
|
|
|
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); |
|
|
|
|
|
|
|
|
if (split_flag) { |
|
|
|
|
|
partial_cnode->set_inputs( |
|
|
|
|
|
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); |
|
|
|
|
|
} |
|
|
return ret; |
|
|
return ret; |
|
|
}; |
|
|
}; |
|
|
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); |
|
|
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); |
|
|
@@ -1678,6 +1682,7 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, |
|
|
|
|
|
|
|
|
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { |
|
|
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { |
|
|
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); |
|
|
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); |
|
|
|
|
|
bool split_flag = false; |
|
|
auto apply_list = GetCNodes(TopoSort(graph->get_return())); |
|
|
auto apply_list = GetCNodes(TopoSort(graph->get_return())); |
|
|
// update the root graph child graph order |
|
|
// update the root graph child graph order |
|
|
AscendControlParser::UpdateChildGraphOrder(graph); |
|
|
AscendControlParser::UpdateChildGraphOrder(graph); |
|
|
@@ -1710,9 +1715,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri |
|
|
AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); |
|
|
AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
split_flag = true; |
|
|
} |
|
|
} |
|
|
AscendControlParser::UpdateChildGraphOrder(graph); |
|
|
AscendControlParser::UpdateChildGraphOrder(graph); |
|
|
UpdateRealInput(graph); |
|
|
|
|
|
|
|
|
UpdateRealInput(graph, split_flag); |
|
|
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; |
|
|
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; |
|
|
// recurse to split child graph |
|
|
// recurse to split child graph |
|
|
} |
|
|
} |
|
|
|