|
|
|
@@ -30,6 +30,7 @@ |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
#include "pipeline/jit/base.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "runtime/device/ascend/kernel_select_ascend.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
@@ -38,8 +39,8 @@ namespace { |
|
|
|
// Pair of graph and its actual arguments. |
|
|
|
using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>; |
|
|
|
|
|
|
|
// We start label id from 1, and use 0 to indicate label not set. |
|
|
|
constexpr uint32_t kNoLabel = 0; |
|
|
|
// We start label id from 0, and use 0xFFFFFFFF to indicate label not set. |
|
|
|
constexpr uint32_t kNoLabel = 0xFFFFFFFF; |
|
|
|
|
|
|
|
// Primitive attribute for argument link assign. |
|
|
|
const char LINK[] = "link"; |
|
|
|
@@ -296,7 +297,7 @@ class AscendAutoMonadContext : public BaseContext { |
|
|
|
ParameterPool param_pool_; |
|
|
|
|
|
|
|
// Current label id. |
|
|
|
uint32_t label_id_ = 1; |
|
|
|
uint32_t label_id_ = 0; |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
@@ -1052,6 +1053,7 @@ class ExecuteOrderGenerator { |
|
|
|
GenerateExecuteOrder(); |
|
|
|
EraseParameter(); |
|
|
|
EraseLabel(); |
|
|
|
UnfoldRepeatedLabels(); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
@@ -1060,6 +1062,101 @@ class ExecuteOrderGenerator { |
|
|
|
generator.GenerateExecuteOrder(); |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) { |
|
|
|
uint32_t max_label = 0; |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { |
|
|
|
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); |
|
|
|
max_label = std::max(label_id, max_label); |
|
|
|
} |
|
|
|
} |
|
|
|
return max_label; |
|
|
|
} |
|
|
|
|
|
|
|
void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels, |
|
|
|
std::multimap<uint32_t, uint32_t> *labels_multimap) { |
|
|
|
bool is_new_labels = false; |
|
|
|
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList); |
|
|
|
std::vector<uint32_t> new_labels; |
|
|
|
new_labels.reserve(label_list.size()); |
|
|
|
for (auto label_id : label_list) { |
|
|
|
auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; }); |
|
|
|
// Use new label if find repeated label. |
|
|
|
if (iter == labels->end()) { |
|
|
|
new_labels.emplace_back(label_id); |
|
|
|
continue; |
|
|
|
} |
|
|
|
new_labels.emplace_back(++max_label_); |
|
|
|
labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_)); |
|
|
|
is_new_labels = true; |
|
|
|
} |
|
|
|
labels->insert(labels->end(), new_labels.begin(), new_labels.end()); |
|
|
|
switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end()); |
|
|
|
if (is_new_labels) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels, |
|
|
|
std::multimap<uint32_t, uint32_t> *labels_multimap) { |
|
|
|
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); |
|
|
|
auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id); |
|
|
|
if (iter == switch_labels->end()) { |
|
|
|
labels->emplace_back(label_id); |
|
|
|
return; |
|
|
|
} |
|
|
|
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node); |
|
|
|
labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_)); |
|
|
|
labels->emplace_back(max_label_); |
|
|
|
} |
|
|
|
|
|
|
|
// Unfold Repeated Labels, avoid same label in labelswitches. |
|
|
|
void UnfoldRepeatedLabels() { |
|
|
|
auto nodes = graph_->execution_order(); |
|
|
|
std::vector<uint32_t> labels; |
|
|
|
std::vector<uint32_t> switch_labels; |
|
|
|
std::multimap<uint32_t, uint32_t> labels_multimap; |
|
|
|
max_label_ = FindMaxLabelId(nodes); |
|
|
|
for (auto &node : nodes) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { |
|
|
|
HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { |
|
|
|
HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
InsertLabelSet(&nodes, labels_multimap); |
|
|
|
graph_->set_label_num(max_label_ + 1); |
|
|
|
graph_->set_execution_order(nodes); |
|
|
|
} |
|
|
|
|
|
|
|
void InsertLabelSet(std::vector<CNodePtr> *nodes, const std::multimap<uint32_t, uint32_t> &labels_multimap) { |
|
|
|
for (auto labels : labels_multimap) { |
|
|
|
auto old_label = labels.first; |
|
|
|
auto new_label = labels.second; |
|
|
|
auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex); |
|
|
|
return label_id == old_label; |
|
|
|
}); |
|
|
|
if (iter == nodes->end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Not found labelset:" << old_label; |
|
|
|
} |
|
|
|
auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name())); |
|
|
|
auto cnode = graph_->NewCNode({label_set}); |
|
|
|
AnfAlgo::CopyNodeAttrs(*iter, cnode); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode); |
|
|
|
auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad); |
|
|
|
cnode->set_abstract(monad->abstract()); |
|
|
|
device::ascend::SelectKernelInfo(cnode); |
|
|
|
nodes->insert(iter, cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) { |
|
|
|
auto &order = graph->execution_order(); |
|
|
|
execution_order->insert(execution_order->end(), order.begin(), order.end()); |
|
|
|
@@ -1343,6 +1440,7 @@ class ExecuteOrderGenerator { |
|
|
|
|
|
|
|
Context &context_; |
|
|
|
const KernelGraphPtr graph_; |
|
|
|
uint32_t max_label_ = 0; |
|
|
|
}; |
|
|
|
|
|
|
|
} // namespace |
|
|
|
@@ -1353,7 +1451,7 @@ void AscendAutoMonad::Run() { |
|
|
|
AscendAutoMonadContext context(kg); |
|
|
|
CallInfoFinder::Run(&context); |
|
|
|
AscendAutoMonadConverter::Run(&context); |
|
|
|
kernel_graph_->set_label_num(context.CurrentLabel()); |
|
|
|
kernel_graph_->set_label_num(context.CurrentLabel() + 1); |
|
|
|
MS_LOG(DEBUG) << "Ascend auto-monad finish."; |
|
|
|
DumpGraphForDebug(kernel_graph_); |
|
|
|
} |
|
|
|
|