|
|
|
@@ -349,11 +349,10 @@ void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotN |
|
|
|
|
|
|
|
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
std::set<KernelGraphPtr> memo; |
|
|
|
(void)RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo)); |
|
|
|
(void)RecurseGraph(root_graph, NOT_NULL(&memo)); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, |
|
|
|
NotNull<KernelGraphPtr> graph, |
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, |
|
|
|
NotNull<std::set<KernelGraphPtr> *> memo) { |
|
|
|
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; |
|
|
|
auto print_vector = [&](std::vector<CNodePtr> vec) -> void { |
|
|
|
@@ -366,52 +365,38 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
return {}; |
|
|
|
} |
|
|
|
memo->insert(graph.get()); |
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); |
|
|
|
graph->SetExecOrderByDefault(); |
|
|
|
|
|
|
|
const std::vector<CNodePtr> &cnodes = graph->execution_order(); |
|
|
|
std::map<uint32_t, CNodePtr> label_map; |
|
|
|
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map; |
|
|
|
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); |
|
|
|
|
|
|
|
std::vector<CNodePtr> execution_order; |
|
|
|
uint32_t child_order_index = 0; |
|
|
|
|
|
|
|
for (auto &node : cnodes) { |
|
|
|
execution_order.push_back(node); |
|
|
|
if (node == graph->get_end_goto()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto label_iter = |
|
|
|
std::find_if(label_map.begin(), label_map.end(), |
|
|
|
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; }); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { |
|
|
|
if (!CheckLabelIndex(child_order_index, 0, node, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = child_graph_order[label_iter->first]; |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::map<uint32_t, CNodePtr> child_label_map; |
|
|
|
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); |
|
|
|
auto child_execution_order = |
|
|
|
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
std::vector<uint32_t> label_list = label_switch_map.find(node)->second; |
|
|
|
std::reverse(label_list.begin(), label_list.end()); |
|
|
|
for (size_t i = 0; i < label_list.size(); ++i) { |
|
|
|
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { |
|
|
|
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node); |
|
|
|
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { |
|
|
|
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = child_graph_order[label_iter->first + i]; |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::map<uint32_t, CNodePtr> child_label_map; |
|
|
|
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); |
|
|
|
auto child_execution_order = |
|
|
|
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -421,6 +406,15 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe |
|
|
|
return execution_order; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; |
|
|
|
} |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, |
|
|
|
NotNull<KernelGraphPtr> graph) { |
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); |
|
|
|
@@ -458,31 +452,6 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode( |
|
|
|
const std::vector<CNodePtr> &nodes) { |
|
|
|
std::map<uint32_t, CNodePtr> label_map; |
|
|
|
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map; |
|
|
|
// record child graph |
|
|
|
uint32_t index = 0; |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
label_map[index++] = node; |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; |
|
|
|
} |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); |
|
|
|
label_switch_map.insert({node, label_list}); |
|
|
|
for (size_t i = 0; i < label_list.size(); ++i) { |
|
|
|
label_map[index++] = node; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return {label_map, label_switch_map}; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) { |
|
|
|
MS_LOG(INFO) << "graph id:" << kg->graph_id(); |
|
|
|
kg->SetExecOrderByDefault(); |
|
|
|
|