| @@ -30,6 +30,7 @@ | |||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "pipeline/jit/base.h" | #include "pipeline/jit/base.h" | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "runtime/device/ascend/kernel_select_ascend.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -38,8 +39,8 @@ namespace { | |||||
| // Pair of graph and its actual arguments. | // Pair of graph and its actual arguments. | ||||
| using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>; | using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>; | ||||
| // We start label id from 1, and use 0 to indicate label not set. | |||||
| constexpr uint32_t kNoLabel = 0; | |||||
| // We start label id from 0, and use 0xFFFFFFFF to indicate label not set. | |||||
| constexpr uint32_t kNoLabel = 0xFFFFFFFF; | |||||
| // Primitive attribute for argument link assign. | // Primitive attribute for argument link assign. | ||||
| const char LINK[] = "link"; | const char LINK[] = "link"; | ||||
| @@ -277,7 +278,7 @@ class AscendAutoMonadContext : public BaseContext { | |||||
| ParameterPool param_pool_; | ParameterPool param_pool_; | ||||
| // Current label id. | // Current label id. | ||||
| uint32_t label_id_ = 1; | |||||
| uint32_t label_id_ = 0; | |||||
| }; | }; | ||||
| // | // | ||||
| @@ -951,6 +952,7 @@ class ExecuteOrderGenerator { | |||||
| GenerateExecuteOrder(); | GenerateExecuteOrder(); | ||||
| EraseParameter(); | EraseParameter(); | ||||
| EraseLabel(); | EraseLabel(); | ||||
| UnfoldRepeatedLabels(); | |||||
| } | } | ||||
| private: | private: | ||||
| @@ -959,6 +961,101 @@ class ExecuteOrderGenerator { | |||||
| generator.GenerateExecuteOrder(); | generator.GenerateExecuteOrder(); | ||||
| } | } | ||||
| uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) { | |||||
| uint32_t max_label = 0; | |||||
| for (auto &node : nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | |||||
| max_label = std::max(label_id, max_label); | |||||
| } | |||||
| } | |||||
| return max_label; | |||||
| } | |||||
| void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels, | |||||
| std::multimap<uint32_t, uint32_t> *labels_multimap) { | |||||
| bool is_new_labels = false; | |||||
| auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); | |||||
| std::vector<uint32_t> new_labels; | |||||
| new_labels.reserve(label_list.size()); | |||||
| for (auto label_id : label_list) { | |||||
| auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; }); | |||||
| // Use new label if find repeated label. | |||||
| if (iter == labels->end()) { | |||||
| new_labels.emplace_back(label_id); | |||||
| continue; | |||||
| } | |||||
| new_labels.emplace_back(++max_label_); | |||||
| labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_)); | |||||
| is_new_labels = true; | |||||
| } | |||||
| labels->insert(labels->end(), new_labels.begin(), new_labels.end()); | |||||
| switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end()); | |||||
| if (is_new_labels) { | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node); | |||||
| } | |||||
| } | |||||
| void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels, | |||||
| std::multimap<uint32_t, uint32_t> *labels_multimap) { | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | |||||
| auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id); | |||||
| if (iter == switch_labels->end()) { | |||||
| labels->emplace_back(label_id); | |||||
| return; | |||||
| } | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node); | |||||
| labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_)); | |||||
| labels->emplace_back(max_label_); | |||||
| } | |||||
| // Unfold Repeated Labels, avoid same label in labelswitches. | |||||
| void UnfoldRepeatedLabels() { | |||||
| auto nodes = graph_->execution_order(); | |||||
| std::vector<uint32_t> labels; | |||||
| std::vector<uint32_t> switch_labels; | |||||
| std::multimap<uint32_t, uint32_t> labels_multimap; | |||||
| max_label_ = FindMaxLabelId(nodes); | |||||
| for (auto &node : nodes) { | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { | |||||
| HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap); | |||||
| continue; | |||||
| } | |||||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { | |||||
| HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| InsertLabelSet(&nodes, labels_multimap); | |||||
| graph_->set_label_num(max_label_ + 1); | |||||
| graph_->set_execution_order(nodes); | |||||
| } | |||||
| void InsertLabelSet(std::vector<CNodePtr> *nodes, const std::multimap<uint32_t, uint32_t> &labels_multimap) { | |||||
| for (auto labels : labels_multimap) { | |||||
| auto old_label = labels.first; | |||||
| auto new_label = labels.second; | |||||
| auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) { | |||||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { | |||||
| return false; | |||||
| } | |||||
| auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); | |||||
| return label_id == old_label; | |||||
| }); | |||||
| if (iter == nodes->end()) { | |||||
| MS_LOG(EXCEPTION) << "Not found labelset:" << old_label; | |||||
| } | |||||
| auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name())); | |||||
| auto cnode = graph_->NewCNode({label_set}); | |||||
| AnfAlgo::CopyNodeAttrs(*iter, cnode); | |||||
| AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode); | |||||
| auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad); | |||||
| cnode->set_abstract(monad->abstract()); | |||||
| device::ascend::SelectKernelInfo(cnode); | |||||
| nodes->insert(iter, cnode); | |||||
| } | |||||
| } | |||||
| void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) { | void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) { | ||||
| auto &order = graph->execution_order(); | auto &order = graph->execution_order(); | ||||
| execution_order->insert(execution_order->end(), order.begin(), order.end()); | execution_order->insert(execution_order->end(), order.begin(), order.end()); | ||||
| @@ -1231,6 +1328,7 @@ class ExecuteOrderGenerator { | |||||
| Context &context_; | Context &context_; | ||||
| const KernelGraphPtr graph_; | const KernelGraphPtr graph_; | ||||
| uint32_t max_label_ = 0; | |||||
| }; | }; | ||||
| } // namespace | } // namespace | ||||
| @@ -1241,7 +1339,7 @@ void AscendAutoMonad::Run() { | |||||
| AscendAutoMonadContext context(kg); | AscendAutoMonadContext context(kg); | ||||
| CallInfoFinder::Run(&context); | CallInfoFinder::Run(&context); | ||||
| AscendAutoMonadConverter::Run(&context); | AscendAutoMonadConverter::Run(&context); | ||||
| kernel_graph_->set_label_num(context.CurrentLabel()); | |||||
| kernel_graph_->set_label_num(context.CurrentLabel() + 1); | |||||
| MS_LOG(DEBUG) << "Ascend auto-monad finish."; | MS_LOG(DEBUG) << "Ascend auto-monad finish."; | ||||
| DumpGraphForDebug(kernel_graph_); | DumpGraphForDebug(kernel_graph_); | ||||
| } | } | ||||