|
|
|
@@ -33,31 +33,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { |
|
|
|
for (auto &iter : graph_id_map) { |
|
|
|
auto &kg = iter.second; |
|
|
|
MS_EXCEPTION_IF_NULL(kg); |
|
|
|
auto real_inputs = kg->real_inputs(); |
|
|
|
for (auto &it : real_inputs) { |
|
|
|
auto ¶meter = it.first; |
|
|
|
auto &args = it.second; |
|
|
|
for (auto &arg : args) { |
|
|
|
MS_EXCEPTION_IF_NULL(arg); |
|
|
|
if (arg->isa<Parameter>()) { |
|
|
|
MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() |
|
|
|
<< ", arg:" << arg->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); |
|
|
|
if (target_graph_iter == graph_id_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; |
|
|
|
} |
|
|
|
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, |
|
|
|
const NotNull<std::set<KernelGraphPtr> *> memo) { |
|
|
|
if (memo->find(kg.get()) != memo->end()) { |
|
|
|
@@ -89,6 +64,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union |
|
|
|
return; |
|
|
|
} |
|
|
|
memo->insert(kg.get()); |
|
|
|
|
|
|
|
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs(); |
|
|
|
for (auto &iter : real_inputs) { |
|
|
|
auto ¶ = iter.first; |
|
|
|
@@ -150,11 +126,10 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet |
|
|
|
const auto &root_inputs_vector = root_kg->inputs(); |
|
|
|
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); |
|
|
|
for (auto &node : parameter_reuse_set) { |
|
|
|
if (root_inputs_set.find(node) == root_inputs_set.end()) { |
|
|
|
continue; |
|
|
|
if (root_inputs_set.find(node) != root_inputs_set.end()) { |
|
|
|
main_parameter = node; |
|
|
|
break; |
|
|
|
} |
|
|
|
|
|
|
|
main_parameter = node; |
|
|
|
} |
|
|
|
|
|
|
|
std::set<KernelGraphPtr> memo; |
|
|
|
@@ -162,9 +137,18 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { |
|
|
|
for (size_t i = start; i < list.size() - 1; ++i) { |
|
|
|
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { |
|
|
|
return list[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { |
|
|
|
std::set<KernelGraphPtr> memo; |
|
|
|
ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); |
|
|
|
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); |
|
|
|
std::map<uint32_t, KernelGraphPtr> graph_id_map; |
|
|
|
for (auto &g : memo) { |
|
|
|
if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { |
|
|
|
@@ -181,13 +165,34 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { |
|
|
|
ChildGraphDataAssign(graph_id_map); |
|
|
|
} |
|
|
|
|
|
|
|
CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) { |
|
|
|
for (size_t i = start; i < list.size() - 1; ++i) { |
|
|
|
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { |
|
|
|
return list[i]; |
|
|
|
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
std::set<KernelGraphPtr> memo; |
|
|
|
(void)RecurseGraph(root_graph, NOT_NULL(&memo)); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) { |
|
|
|
for (auto &iter : graph_id_map) { |
|
|
|
auto &kg = iter.second; |
|
|
|
MS_EXCEPTION_IF_NULL(kg); |
|
|
|
auto real_inputs = kg->real_inputs(); |
|
|
|
for (auto &it : real_inputs) { |
|
|
|
auto ¶meter = it.first; |
|
|
|
auto &args = it.second; |
|
|
|
for (auto &arg : args) { |
|
|
|
MS_EXCEPTION_IF_NULL(arg); |
|
|
|
if (arg->isa<Parameter>()) { |
|
|
|
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() |
|
|
|
<< ", arg:" << arg->DebugString(); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); |
|
|
|
if (target_graph_iter == graph_id_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; |
|
|
|
} |
|
|
|
InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, |
|
|
|
@@ -212,9 +217,16 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr |
|
|
|
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; |
|
|
|
} |
|
|
|
// 4. insert first_label |
|
|
|
auto start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); |
|
|
|
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); |
|
|
|
kg->set_start_label(start_label); |
|
|
|
CNodePtr start_label; |
|
|
|
if (last_node != nullptr && last_label != nullptr) { |
|
|
|
start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))}); |
|
|
|
MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); |
|
|
|
kg->set_start_label(start_label); |
|
|
|
} else { |
|
|
|
// no goto node will jump to start label of root graph, so return a fake label |
|
|
|
start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr)); |
|
|
|
} |
|
|
|
|
|
|
|
// 5. traverse |
|
|
|
for (size_t i = 0; i < nodes.size(); ++i) { |
|
|
|
auto &cnode = nodes[i]; |
|
|
|
@@ -249,11 +261,10 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) { |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("depend"))}; |
|
|
|
auto return_node = kg->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(return_node); |
|
|
|
inputs.push_back(return_node->input(1)); |
|
|
|
inputs.push_back(attch_node.get()); |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), |
|
|
|
return_node->input(1), attch_node.get()}; |
|
|
|
auto depend_node = kg->NewCNode(inputs); |
|
|
|
return_node->set_input(1, depend_node); |
|
|
|
} |
|
|
|
@@ -407,9 +418,9 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A |
|
|
|
if (partial_cnode->size() < kCNodePartialLength) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; |
|
|
|
} |
|
|
|
|
|
|
|
auto partial_inputs = partial_cnode->inputs(); |
|
|
|
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); |
|
|
|
|
|
|
|
return {partial_cnode, branch_kg}; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -425,7 +436,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul |
|
|
|
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " |
|
|
|
<< to->DebugString(); |
|
|
|
// config inputs of assign node |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>("Assign")), to, from}; |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimAssign->name())), to, from}; |
|
|
|
// generate a new cnode |
|
|
|
auto assign_node = kg->NewCNode(inputs); |
|
|
|
MS_EXCEPTION_IF_NULL(assign_node); |
|
|
|
@@ -434,11 +445,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul |
|
|
|
InsertDependToGraph(kg, NOT_NULL(assign_node)); |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { |
|
|
|
std::set<KernelGraphPtr> memo; |
|
|
|
(void)RecurseGraph(root_graph, NOT_NULL(&memo)); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, |
|
|
|
const NotNull<std::set<KernelGraphPtr> *> memo) { |
|
|
|
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; |
|
|
|
@@ -457,29 +463,24 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> |
|
|
|
if (node == graph->get_end_goto()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
if (!CheckLabelIndex(child_order_index, 0, node, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node); |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
std::vector<uint32_t> label_switch_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); |
|
|
|
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { |
|
|
|
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
if (child_graph == graph->parent_graph()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
} |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
uint32_t label_index = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); |
|
|
|
if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { |
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail"; |
|
|
|
} |
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++]; |
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); |
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); |
|
|
|
} |
|
|
|
} |
|
|
|
graph->set_execution_order(execution_order); |
|
|
|
@@ -487,15 +488,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> |
|
|
|
return execution_order; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) { |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; |
|
|
|
} |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); |
|
|
|
} |
|
|
|
|
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, |
|
|
|
NotNull<KernelGraphPtr> graph) { |
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order(); |
|
|
|
@@ -504,33 +496,19 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i |
|
|
|
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " |
|
|
|
<< child_graph_order.size() << " goto index " << order_index; |
|
|
|
} |
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { |
|
|
|
// check label_goto and start_label in child graph |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; |
|
|
|
} |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cur_label); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
uint32_t label_goto_index = GetValue<uint32_t>(primitive->GetAttr(kAttrLabelIndex)); |
|
|
|
label_index = label_goto_index; |
|
|
|
} |
|
|
|
// get start_label_set_index of child graph |
|
|
|
auto child_graph = child_graph_order[order_index]; |
|
|
|
MS_EXCEPTION_IF_NULL(child_graph); |
|
|
|
|
|
|
|
// get start_label_set_index of child graph |
|
|
|
auto start_label_set = child_graph->get_start_label(); |
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { |
|
|
|
MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; |
|
|
|
} |
|
|
|
auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set); |
|
|
|
MS_EXCEPTION_IF_NULL(start_primitive); |
|
|
|
uint32_t start_label_set_index = GetValue<uint32_t>(start_primitive->GetAttr(kAttrLabelIndex)); |
|
|
|
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex); |
|
|
|
if (label_index != start_label_set_index) { |
|
|
|
MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() |
|
|
|
<< " index " << start_label_set_index << " current child graph order : " << order_index; |
|
|
|
return false; |
|
|
|
} else { |
|
|
|
return true; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) { |
|
|
|
|