@@ -118,62 +118,6 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) {
   fout.close();
 }
 
-//
-// ParameterPool caches parameters by their abstract, so that we can reuse a
-// parameter with the same abstract to store return values.
-//
-class ParameterPool {
- public:
-  explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
-  ~ParameterPool() = default;
-
-  // Create or get a parameter from the pool with the given abstract.
-  AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
-    // Find a parameter in the pool by the given abstract.
-    auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
-      auto para_abs = para->abstract();
-      // Reuse an output parameter with a compatible abstract.
-      return IsCompatible(abs, para_abs);
-    });
-    // Return the parameter if found.
-    if (iter != paras_.end()) {
-      return *iter;
-    }
-    // If no parameter with the given abstract is found, create a new one.
-    auto para = top_graph_->NewParameter(abs);
-    auto out_para = top_graph_->TransTupleToMakeTuple(para);
-    // This is required, so that device memory can be allocated for it.
-    top_graph_->AddChildGraphResult(out_para);
-    // Save the new para to the pool.
-    paras_.push_back(out_para);
-    return out_para;
-  }
-
- protected:
-  // Check if one abstract is compatible with another abstract.
-  static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
-    if (a1 == nullptr || a2 == nullptr) {
-      return false;
-    }
-    if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
-      // This makes AbstractRef compatible with AbstractTensor.
-      auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
-      auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
-      return t1 == t2;
-    }
-    return *a1 == *a2;
-  }
-
- private:
-  // The top graph.
-  const KernelGraphPtr &top_graph_;
-
-  // Cached parameters.
-  std::vector<AnfNodePtr> paras_;
-};
-
-using ParameterPoolPtr = std::shared_ptr<ParameterPool>;
-
 class BaseContext {
  public:
  void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
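
The ParameterPool removed above implemented a get-or-create-by-predicate cache:
scan the cached parameters for one whose abstract is compatible, and allocate a
new parameter only on a miss. A minimal stand-alone sketch of that pattern
(Pool and GetOrCreate are illustrative stand-ins, not MindSpore APIs):

#include <algorithm>
#include <functional>
#include <vector>

template <typename T>
class Pool {
 public:
  // Return a cached element accepted by `compatible`, or create one via `make`.
  T GetOrCreate(const std::function<bool(const T &)> &compatible, const std::function<T()> &make) {
    auto iter = std::find_if(items_.begin(), items_.end(), compatible);
    if (iter != items_.end()) {
      return *iter;  // Reuse, as ParameterPool reused compatible parameters.
    }
    auto item = make();
    items_.push_back(item);  // Cache the freshly created element.
    return item;
  }

 private:
  std::vector<T> items_;
};
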
@@ -200,13 +144,38 @@ class AscendAutoMonadContext : public BaseContext {
   // Current label id, also the number of label ids we have used so far.
   uint32_t CurrentLabel() const { return label_id_; }
 
-  // Create a new parameter pool.
-  ParameterPoolPtr NewParameterPool() { return std::make_shared<ParameterPool>(top_graph_); }
+  // Create or get a parameter for the output of the kernel graph.
+  AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) {
+    // Find the output parameter by kernel graph.
+    auto iter = kg_out_param_.find(kg);
+    if (iter != kg_out_param_.end()) {
+      // Return the output parameter if found.
+      return iter->second;
+    }
+    // Create a new one if not found.
+    // Output parameters are all created on the top graph.
+    auto para = top_graph_->NewParameter(kg->output()->abstract());
+    auto out_para = top_graph_->TransTupleToMakeTuple(para);
+    // This is required, so that device memory can be allocated for it.
+    top_graph_->AddChildGraphResult(out_para);
+    // Save the new para as the output parameter of the kg.
+    kg_out_param_.emplace(kg, out_para);
+    return out_para;
+  }
+
+  // Set the output parameter for a kernel graph.
+  void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) {
+    // Save the given para as the output parameter of the kg.
+    kg_out_param_.emplace(kg, out_para);
+  }
 
  private:
   // The top graph.
   const KernelGraphPtr &top_graph_;
 
+  // Map kernel_graph to its output parameter.
+  std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;
+
   // Current label id.
   uint32_t label_id_ = 1;
 };
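
With this change, output slots are keyed by the callee kernel graph itself
instead of by abstract compatibility, so every call or switch branch targeting
the same graph shares one output parameter. A sketch of keying such a cache by
shared_ptr identity (Graph, Node, and this GetOutputParameter are simplified
stand-ins, not the real MindSpore types):

#include <memory>
#include <string>
#include <unordered_map>

struct Node { std::string name; };
struct Graph { std::string name; };
using GraphPtr = std::shared_ptr<Graph>;
using NodePtr = std::shared_ptr<Node>;

// std::hash<std::shared_ptr<T>> hashes the stored pointer, so the map is
// keyed by graph identity, not by graph contents.
NodePtr GetOutputParameter(std::unordered_map<GraphPtr, NodePtr> *cache, const GraphPtr &kg) {
  auto iter = cache->find(kg);
  if (iter != cache->end()) {
    return iter->second;  // Every caller of `kg` receives the same slot.
  }
  auto para = std::make_shared<Node>(Node{"output_of_" + kg->name});
  cache->emplace(kg, para);
  return para;
}
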
@@ -254,6 +223,7 @@ class AscendAutoMonadConverter {
   // Prepare information for control flow processing.
   //
   void Prepare() {
+    recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive);
     AnfNodePtr last_monad = nullptr;
     auto nodes = TopoSort(kernel_graph_->output());
     for (auto &node : nodes) {
@@ -291,26 +261,25 @@ class AscendAutoMonadConverter {
     for (auto &cnode : call_switch_nodes_) {
       if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
         HandleCall(cnode);
-      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
+      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
+                 AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
         HandleSwitch(cnode);
-      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
-        HandleSwitchLayer(cnode);
       } else {
         MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
       }
     }
     // If no tail call, assign output value to output parameter,
     // and then goto the return label if set.
-    if (tail_call_node_ == nullptr) {
+    if (tail_call_node_ == nullptr || recursive_) {
       if (output_parameter_) {
         auto assign_output = AssignAll(output_parameter_, kernel_graph_->output());
         monad_ = UpdateState(GetMonad(), assign_output);
       }
       if (return_label_ != kNoLabel) {
         (void)LabelGoto(return_label_);
       } else {
-        // Clear end goto if return label not set.
-        kernel_graph_->set_end_goto(nullptr);
+        // Insert a label_goto for the return.
+        auto return_goto = LabelGoto(return_label_);
+        AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
+        kernel_graph_->set_end_goto(return_goto);
       }
     }
   }
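
The `|| recursive_` added above changes when a graph receives the explicit
return epilogue (assign the output, then goto the return label): previously
only graphs without a tail call did; now a recursive graph always does, since a
re-entered frame must jump back to its caller instead of simply ending
execution. The decision, reduced to a toy predicate (the name is illustrative):

#include <iostream>

// Decide whether a graph needs the assign-output + return-goto epilogue.
bool NeedsReturnEpilogue(bool has_tail_call, bool recursive) {
  // Old rule: !has_tail_call. New rule: recursive graphs always return
  // explicitly, even when they end in a tail call.
  return !has_tail_call || recursive;
}

int main() {
  // A recursive graph that ends in a tail call now still gets the epilogue.
  std::cout << std::boolalpha << NeedsReturnEpilogue(true, true) << "\n";  // true
  return 0;
}
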
@@ -348,33 +317,37 @@ class AscendAutoMonadConverter {
     // as 'select kernel' can handle sub graphs.
     SetChildGrapAttr(goto_node, {graph});
 
-    // Setup return label if this is not a tail call.
+    // Setup return label if this is not a tail call, or if it is a recursive call.
     const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = !is_tail_call;
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+    const bool need_return = (!is_tail_call || recursive_);
+    if (!need_return) {
+      // Set as end_goto if no return is required.
+      kernel_graph_->set_end_goto(goto_node);
+    }
+    auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return);
 
     // Handle sub-graph recursively.
-    HandleSubGraph(graph, para_pool, output_para, return_label);
+    HandleSubGraph(graph, output_para, return_label);
   }
 
   //
-  // Convert switch node:
+  // Convert switch/switchlayer node:
   //   branch1 = Partial(graph1, arg)
   //   branch2 = Partial(graph2, arg)
-  //   out = Switch(cond, branch1, branch2)
+  //   out = Switch/SwitchLayer(cond/index, branch1, branch2)
   // to:
   //   r = link_args(graph1, arg)
   //   c = UpdateState(c, r)
   //   r = link_args(graph2, arg)
   //   c = UpdateState(c, r)
-  //   c = LabelSwitch(cond, c) : L1, L2
+  //   c = LabelSwitch(cond/index, c) : L1, L2
   //   c = LabelSet(c) : <return label>
   //
   void HandleSwitch(const CNodePtr &cnode) {
     // Update last_monad_.
     last_monad_ = monad_map_[cnode];
 
-    // Get both branches of the switch, true branch first.
+    // Get the branches of the switch or switchlayer, with the true or 0th branch first.
     auto branches = GetSwitchBranches(cnode);
 
     // Link arguments and generate labels for branches.
@@ -394,10 +367,13 @@ class AscendAutoMonadConverter {
       labels.push_back(GetOrCreateGraphLabel(graph));
     }
 
-    // Since the true/false branches are reversed in kernel LabelSwitch,
-    // we reverse graphes and labels to make the false branch first.
-    std::reverse(graphes.begin(), graphes.end());
-    std::reverse(labels.begin(), labels.end());
+    const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch);
+    if (is_switch) {
+      // For Switch, we reverse the graphes and labels, so that the false branch
+      // is the first one, since for kernel LabelSwitch, false is the first branch.
+      std::reverse(graphes.begin(), graphes.end());
+      std::reverse(labels.begin(), labels.end());
+    }
 
     // Add LabelSwitch node.
     auto switch_node = LabelSwitch(cnode->input(1), labels);
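
The conditional reverse above reflects two selector conventions: per the
comment, kernel LabelSwitch expects the false branch first for a boolean
Switch, while SwitchLayer branches are already in index order and stay as they
are. A toy simulation of the resulting label selection (LabelSwitchTarget is a
made-up name, not a MindSpore API):

#include <cstdint>
#include <iostream>
#include <vector>

// LabelSwitch semantics: the selector indexes into the label list.
uint32_t LabelSwitchTarget(uint32_t selector, const std::vector<uint32_t> &labels) {
  return labels.at(selector);
}

int main() {
  // After the reverse, a boolean Switch carries labels {L_false, L_true}.
  std::vector<uint32_t> labels = {20, 10};
  std::cout << LabelSwitchTarget(0, labels) << "\n";  // cond == false -> label 20
  std::cout << LabelSwitchTarget(1, labels) << "\n";  // cond == true  -> label 10
  return 0;
}
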
@@ -405,95 +381,42 @@ class AscendAutoMonadConverter {
     // Set child graph attribute for switch node.
     SetChildGrapAttr(switch_node, graphes);
 
-    // Setup return label if required.
-    const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
-
-    // Handle sub-graphs recursively.
-    for (auto &graph : graphes) {
-      HandleSubGraph(graph, para_pool, output_para, return_label);
+    if (!is_switch) {
+      // Mark that the switch node is for 'switch_layer'.
+      AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node);
     }
-  }
 
-  //
-  // Convert switch node:
-  //   branch1 = Partial(graph1, arg)
-  //   branch2 = Partial(graph2, arg)
-  //   out = SwitchLayer(index, branch1, branch2)
-  // to:
-  //   r = link_args(graph1, arg)
-  //   c = UpdateState(c, r)
-  //   r = link_args(graph2, arg)
-  //   c = UpdateState(c, r)
-  //   c = LabelSwitch(index, c) : L1, L2
-  //   c = LabelSet(c) : <return label>
-  //
-  void HandleSwitchLayer(const CNodePtr &cnode) {
-    // Update last_monad_.
-    last_monad_ = monad_map_[cnode];
-
-    // Get both branches of the switch, true branch first.
-    auto branches = GetSwitchBranches(cnode);
-
-    // Link arguments and generate labels for branches.
-    std::vector<KernelGraphPtr> graphes;
-    std::vector<uint32_t> labels;
-    graphes.reserve(branches.size());
-    labels.reserve(graphes.size());
-    for (auto &[graph, args] : branches) {
-      if (graph == nullptr) {
-        MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
-      }
-      auto linked_args = LinkArguments(args, graph);
-      if (linked_args != nullptr) {
-        monad_ = UpdateState(GetMonad(), linked_args);
-      }
-      graphes.push_back(graph);
-      labels.push_back(GetOrCreateGraphLabel(graph));
-    }
-
-    // Add LabelSwitch node.
-    auto switch_node = LabelSwitch(cnode->input(1), labels);
-
-    // Set child graph attribute for switch node.
-    SetChildGrapAttr(switch_node, graphes);
-
     // Setup return label if required.
     const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+    const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_);
+    auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return);
 
     // Handle sub-graphs recursively.
     for (auto &graph : graphes) {
-      HandleSubGraph(graph, para_pool, output_para, return_label);
+      HandleSubGraph(graph, output_para, return_label);
     }
   }
 
-  ParameterPoolPtr GetParameterPool(bool is_last_call) {
-    if (!is_last_call) {
-      // There are multiple calls in this graph, use a new parameter pool
-      // for each of them except the last one.
-      return context_.NewParameterPool();
-    }
-    // For the last call, try to reuse the parameter pool from the caller.
-    if (para_pool_ == nullptr) {
-      para_pool_ = context_.NewParameterPool();
+  AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) {
+    const bool is_tail_call = (cnode == tail_call_node_);
+    if (is_tail_call && output_parameter_ != nullptr) {
+      return output_parameter_;
     }
-    return para_pool_;
+    return context_.GetOutputParameter(branches.front());
   }
 
   // Make return part of a call for the LabelGoto/LabelSwitch node.
-  std::tuple<ParameterPoolPtr, AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, bool need_return) {
-    // Find a parameter pool for the output parameter.
-    const bool is_last_call = (cnode == call_switch_nodes_.back());
-    auto para_pool = GetParameterPool(is_last_call);
-
-    // Prepare return label and output parameter.
+  std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches,
+                                              bool need_return) {
+    // Prepare the return label.
     uint32_t return_label = return_label_;
-    auto output_para = para_pool->GetParameter(cnode->abstract());
+    // Prepare the output parameter.
+    auto output_para = GetOutputParameter(cnode, branches);
+    // Use the same output parameter for all branches.
+    for (auto &branch : branches) {
+      context_.SetOutputParameter(branch, output_para);
+    }
    auto output = output_para;
 
     // Setup the return label if a return is required.
     if (need_return) {
       // Set a new label at the return point.
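
MakeReturn now picks a single output slot (the caller's output parameter for a
tail call, otherwise the context's per-graph parameter) and registers that same
slot for every branch graph, so all branches of a Switch/SwitchLayer write
their result to one place. A stand-alone sketch of that selection-plus-sharing
step, with plain strings standing in for graphs and parameters:

#include <map>
#include <string>
#include <vector>

using Slot = std::string;

// Pick a shared output slot and register it for each branch graph,
// mirroring GetOutputParameter plus the SetOutputParameter loop above.
Slot ShareOutputSlot(bool is_tail_call, const Slot &caller_slot,
                     const std::vector<std::string> &branches,
                     std::map<std::string, Slot> *graph_to_slot) {
  Slot slot = (is_tail_call && !caller_slot.empty())
                  ? caller_slot                     // Tail call: reuse the caller's slot.
                  : "slot_of_" + branches.front();  // Otherwise: one slot per callee graph.
  for (const auto &branch : branches) {
    (*graph_to_slot)[branch] = slot;  // Every branch writes its result here.
  }
  return slot;
}
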
@@ -504,16 +427,14 @@ class AscendAutoMonadConverter {
       output = MakeDepend(output, label_node);
     }
 
-    // Replace the switch node with the output.
+    // Replace the call/switch node with the output.
     kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
-    return {para_pool, output_para, return_label};
+    return {output_para, return_label};
   }
 
   // Handle sub-graphs recursively.
-  void HandleSubGraph(const KernelGraphPtr &graph, const ParameterPoolPtr &para_pool, const AnfNodePtr &out_para,
-                      uint32_t return_label) {
+  void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) {
     AscendAutoMonadConverter converter(&context_, graph);
-    converter.para_pool_ = para_pool;
     converter.output_parameter_ = out_para;
     converter.return_label_ = return_label;
     converter.Run();
@@ -717,7 +638,6 @@ class AscendAutoMonadConverter {
     auto cnode = kernel_graph_->NewCNode({label_goto, monad});
     AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
     cnode->set_abstract(monad->abstract());
-    kernel_graph_->set_end_goto(cnode);  // make 'goto' the last one in execute order.
     monad_ = cnode;
     return cnode;
   }
@@ -794,11 +714,11 @@ class AscendAutoMonadConverter {
   // Parameter to store the return value.
   AnfNodePtr output_parameter_;
 
-  // Parameter pool for output parameter allocation.
-  ParameterPoolPtr para_pool_;
-
   // The return label id.
   uint32_t return_label_ = kNoLabel;
 
+  // Whether this graph includes recursive calls.
+  bool recursive_ = false;
 };
 
 constexpr size_t kAssignTargetIndex = 1;
@@ -851,20 +771,22 @@ class ExecuteOrderGenerator {
 
     std::vector<CNodePtr> execution_order;
     const auto &cnodes = graph_->execution_order();
-    for (auto cnode : cnodes) {
+    for (auto &cnode : cnodes) {
       // Push the current node to the execution order list.
       execution_order.push_back(cnode);
       // For a cnode with sub-graphs, such as LabelSwitch or LabelGoto,
       // generate the execution order for these sub-graphs,
       // and then append them to the current execution order list.
       if (HasSubGraphs(cnode)) {
-        // We use reversed order to generate the sub-graphs' execution order,
-        // because the true branch of LabelSwitch is the second one, but
-        // we want the true branch ahead of the false branch in the
-        // generated execution order.
         auto sub_graphs = GetSubGraphs(cnode);
-        for (auto iter = sub_graphs.rbegin(); iter != sub_graphs.rend(); iter++) {
-          auto &sub_graph = *iter;
+        if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
+          // For Switch, we use reversed order to generate the sub-graphs'
+          // execution order, because the true branch of LabelSwitch is the
+          // second one, but we want the true branch ahead of the false branch
+          // in the generated execution order.
+          std::reverse(sub_graphs.begin(), sub_graphs.end());
+        }
+        for (auto &sub_graph : sub_graphs) {
          if (context_.IsVisited(sub_graph)) {
            // Skip visited sub-graphs.
            continue;
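
For reference, the ordering rule introduced in this last hunk, as a toy
function: the sub-graph list is reversed only when the node is not marked as a
switch_layer, so the true branch of a Switch is emitted ahead of the false
branch (ints stand in for graphs):

#include <algorithm>
#include <vector>

// Reverse only for Switch (not SwitchLayer), so the true branch comes
// first in the generated execution order.
std::vector<int> OrderSubGraphs(std::vector<int> sub_graphs, bool is_switch_layer) {
  if (!is_switch_layer) {
    std::reverse(sub_graphs.begin(), sub_graphs.end());
  }
  return sub_graphs;
}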