|
|
|
@@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) { |
|
|
|
auto &nodes = parent_graph->execution_order(); |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) { |
|
|
|
return node; |
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) && |
|
|
|
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || |
|
|
|
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) { |
|
|
|
return node; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, |
|
|
|
const NotNull<std::set<KernelGraphPtr> *> memo) { |
|
|
|
if (memo->find(kg.get()) != memo->end()) { |
|
|
|
@@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr |
|
|
|
if (target_graph_iter == graph_id_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; |
|
|
|
} |
|
|
|
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); |
|
|
|
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), |
|
|
|
NOT_NULL(parameter)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A |
|
|
|
return {partial_cnode, branch_kg}; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, |
|
|
|
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, |
|
|
|
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from, |
|
|
|
NotNull<AnfNodePtr> to) { |
|
|
|
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); |
|
|
|
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); |
|
|
|
@@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg |
|
|
|
<< to_outputs.size() << "]"; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < from_outputs.size(); i++) { |
|
|
|
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); |
|
|
|
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); |
|
|
|
if (assign_node != nullptr) { |
|
|
|
auto jump_node = GetJumpNode(from_graph, to_graph); |
|
|
|
if (jump_node != nullptr) { |
|
|
|
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, |
|
|
|
NotNull<AnfNodePtr> to) { |
|
|
|
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, |
|
|
|
NotNull<AnfNodePtr> to) { |
|
|
|
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && |
|
|
|
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { |
|
|
|
return; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (from.get() == to.get()) { |
|
|
|
return; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " |
|
|
|
<< to->DebugString(); |
|
|
|
@@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul |
|
|
|
assign_node->set_abstract(to->abstract()); |
|
|
|
// append the assign at the end of from graph |
|
|
|
InsertDependToGraph(kg, NOT_NULL(assign_node)); |
|
|
|
return assign_node; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, |
|
|
|
|