|
|
|
@@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0; |
|
|
|
// Primitive attribute for argument link assign. |
|
|
|
const char LINK[] = "link"; |
|
|
|
|
|
|
|
// Attribute to indicate that the node should not be eliminated. |
|
|
|
// Used to keep argument Assign nodes for recursive graphs. |
|
|
|
const char KEEP[] = "keep"; |
|
|
|
|
|
|
|
// Attribute to indicate that this is an assign for output. |
|
|
|
const char OUTPUT[] = "output"; |
|
|
|
|
|
|
|
bool IsSaveGraph() { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
@@ -152,6 +159,9 @@ struct CallSite { |
|
|
|
// Label param to index map. |
|
|
|
std::map<AnfNodePtr, uint32_t> label_indexes; |
|
|
|
|
|
|
|
// True if this is a recursive call. |
|
|
|
bool recursive = false; |
|
|
|
|
|
|
|
// True if this is a tail call. |
|
|
|
bool tail = false; |
|
|
|
}; |
|
|
|
@@ -161,9 +171,18 @@ struct ReturnPoint { |
|
|
|
}; |
|
|
|
|
|
|
|
struct CallInfo { |
|
|
|
// Call sites in current graph. |
|
|
|
std::vector<CallSite> call_sites; |
|
|
|
|
|
|
|
// Return points of current graph. |
|
|
|
std::vector<ReturnPoint> return_points; |
|
|
|
|
|
|
|
// Parameter to store label index, if there are |
|
|
|
// multi return points, this should be set. |
|
|
|
AnfNodePtr label_param = nullptr; |
|
|
|
|
|
|
|
// True if current graph is involved with recursive calls. |
|
|
|
bool recursive = false; |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
@@ -296,6 +315,8 @@ class CallInfoFinder { |
|
|
|
|
|
|
|
void Run() { |
|
|
|
FindCallSites(); |
|
|
|
FindRecursiveCalls(); |
|
|
|
DisableTailCalls(); |
|
|
|
FindCallReturns(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -340,11 +361,30 @@ class CallInfoFinder { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Find call-return pairs. |
|
|
|
void FindCallReturns() { |
|
|
|
// Find recursive non-tail calls. |
|
|
|
void FindRecursiveCalls() { |
|
|
|
for (auto &[caller, call_info] : context_.call_info_map) { |
|
|
|
for (auto &call_site : call_info.call_sites) { |
|
|
|
if (!call_site.tail) { |
|
|
|
SearchRecursiveCall(caller, &call_site); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Disable tail call optimization for recursive call graphs. |
|
|
|
void DisableTailCalls() { |
|
|
|
for (auto &entry : context_.call_info_map) { |
|
|
|
auto &caller = entry.first; |
|
|
|
auto &call_info = entry.second; |
|
|
|
if (call_info.recursive && !call_info.call_sites.empty()) { |
|
|
|
call_info.call_sites.back().tail = false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Find call-return pairs. |
|
|
|
void FindCallReturns() { |
|
|
|
for (auto &[caller, call_info] : context_.call_info_map) { |
|
|
|
for (auto &call_site : call_info.call_sites) { |
|
|
|
for (auto &callee : call_site.callees) { |
|
|
|
MakeGraphLabel(callee.graph); |
|
|
|
@@ -396,6 +436,54 @@ class CallInfoFinder { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
struct SearchRecursiveContext { |
|
|
|
const KernelGraphPtr &start_caller; |
|
|
|
CallSite *start_site; |
|
|
|
std::set<KernelGraphPtr> visited; |
|
|
|
std::vector<KernelGraphPtr> call_path; |
|
|
|
}; |
|
|
|
|
|
|
|
// Search recursive call from a call-site. |
|
|
|
void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) { |
|
|
|
SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site}; |
|
|
|
DoSearchRecursiveCall(start_caller, start_site, &context); |
|
|
|
} |
|
|
|
|
|
|
|
void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) { |
|
|
|
// Record call path. |
|
|
|
ctx->call_path.push_back(graph); |
|
|
|
// Handle callee graphs. |
|
|
|
for (auto &callee : call_site->callees) { |
|
|
|
auto &sub_graph = callee.graph; |
|
|
|
if (sub_graph == ctx->start_caller) { |
|
|
|
// Find a recursive call path. |
|
|
|
for (auto &g : ctx->call_path) { |
|
|
|
// Mark recursive for all graphs in call path. |
|
|
|
context_.call_info_map[g].recursive = true; |
|
|
|
} |
|
|
|
// Mark recursive for the start call-site. |
|
|
|
ctx->start_site->recursive = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (ctx->visited.find(sub_graph) != ctx->visited.end()) { |
|
|
|
// Skip visited graphs. |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Mark visited. |
|
|
|
ctx->visited.emplace(sub_graph); |
|
|
|
// Check call sites in the sub-graph. |
|
|
|
auto &call_info = context_.call_info_map[sub_graph]; |
|
|
|
auto &sites = call_info.call_sites; |
|
|
|
for (auto &site : sites) { |
|
|
|
if (!site.callees.empty()) { |
|
|
|
DoSearchRecursiveCall(sub_graph, &site, ctx); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// Don't forget this. |
|
|
|
ctx->call_path.pop_back(); |
|
|
|
} |
|
|
|
|
|
|
|
// Handle a call-return relation. |
|
|
|
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) { |
|
|
|
// Create a label for the return point. |
|
|
|
@@ -590,7 +678,7 @@ class AscendAutoMonadConverter { |
|
|
|
// For multi-return call, assign result from temp parameter to |
|
|
|
// output parameter, this prevent result be overwritten by next call. |
|
|
|
auto tmp_param = context_.GetTempParameter(output->abstract()); |
|
|
|
output = AssignAll(output, tmp_param); |
|
|
|
output = AssignAll(output, tmp_param, false, false, true); |
|
|
|
monad_ = UpdateState(GetMonad(), output); |
|
|
|
} |
|
|
|
// Replace the the call/switch node with the output. |
|
|
|
@@ -611,7 +699,7 @@ class AscendAutoMonadConverter { |
|
|
|
void AssignLabelIndexes(const CallSite &call_site) { |
|
|
|
for (auto &[label_param, label_index] : call_site.label_indexes) { |
|
|
|
auto index_value = GetIndexValueNode(label_index); |
|
|
|
auto assign = Assign(label_param, index_value); |
|
|
|
auto assign = Assign(label_param, index_value, false, false, false); |
|
|
|
monad_ = UpdateState(GetMonad(), assign); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -708,7 +796,7 @@ class AscendAutoMonadConverter { |
|
|
|
AnfNodePtr out_param = |
|
|
|
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract())); |
|
|
|
MS_EXCEPTION_IF_NULL(out_param); |
|
|
|
auto assign_output = AssignAll(out_param, kernel_graph_->output()); |
|
|
|
auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true); |
|
|
|
monad_ = UpdateState(GetMonad(), assign_output); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -739,6 +827,8 @@ class AscendAutoMonadConverter { |
|
|
|
if (args.empty()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// We do not eliminate argument Assign for recursive graphs. |
|
|
|
const bool keep = IsRecursive(graph); |
|
|
|
// Single argument. |
|
|
|
if (args.size() == 1) { |
|
|
|
auto &value = args.front(); |
|
|
|
@@ -746,7 +836,7 @@ class AscendAutoMonadConverter { |
|
|
|
// No assign for single monad argument, return it. |
|
|
|
return value; |
|
|
|
} |
|
|
|
return AssignAll(paras.front(), value, true); |
|
|
|
return AssignAll(paras.front(), value, true, keep, false); |
|
|
|
} |
|
|
|
// Multi arguments. |
|
|
|
AnfNodePtrList tuple_inputs; |
|
|
|
@@ -764,11 +854,14 @@ class AscendAutoMonadConverter { |
|
|
|
if (target == value) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
tuple_inputs.emplace_back(AssignAll(target, value, true)); |
|
|
|
tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false)); |
|
|
|
} |
|
|
|
return kernel_graph_->NewCNode(tuple_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Return true if the graph is involved with recursive calls. |
|
|
|
bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; } |
|
|
|
|
|
|
|
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. |
|
|
|
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } |
|
|
|
|
|
|
|
@@ -780,13 +873,21 @@ class AscendAutoMonadConverter { |
|
|
|
} |
|
|
|
|
|
|
|
// Make a assign cnode. |
|
|
|
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { |
|
|
|
auto monad = (is_link ? GetLinkMonad() : GetMonad()); |
|
|
|
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) { |
|
|
|
auto monad = (link ? GetLinkMonad() : GetMonad()); |
|
|
|
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); |
|
|
|
if (is_link) { |
|
|
|
if (link) { |
|
|
|
// Mark this assign is to link real argument to formal argument. |
|
|
|
assign_prim->set_attr(LINK, prim::kValueOne); |
|
|
|
} |
|
|
|
if (keep) { |
|
|
|
// Mark this assign should not be eliminated. |
|
|
|
assign_prim->set_attr(KEEP, prim::kValueOne); |
|
|
|
} |
|
|
|
if (output) { |
|
|
|
// Mark this assign is used for output parameter. |
|
|
|
assign_prim->set_attr(OUTPUT, prim::kValueOne); |
|
|
|
} |
|
|
|
auto assign = NewValueNode(assign_prim); |
|
|
|
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad}); |
|
|
|
cnode->set_abstract(target->abstract()); |
|
|
|
@@ -794,10 +895,10 @@ class AscendAutoMonadConverter { |
|
|
|
} |
|
|
|
|
|
|
|
// AissgnAll support tuple to tuple assign. |
|
|
|
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { |
|
|
|
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) { |
|
|
|
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) { |
|
|
|
// Assign single value. |
|
|
|
return Assign(target, source, is_link); |
|
|
|
return Assign(target, source, link, keep, output); |
|
|
|
} |
|
|
|
// Assign tuple. |
|
|
|
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem}); |
|
|
|
@@ -809,7 +910,7 @@ class AscendAutoMonadConverter { |
|
|
|
tuple_inputs.reserve(targets.size() + 1); |
|
|
|
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); |
|
|
|
for (size_t i = 0; i < targets.size(); ++i) { |
|
|
|
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link)); |
|
|
|
tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output)); |
|
|
|
} |
|
|
|
return kernel_graph_->NewCNode(tuple_inputs); |
|
|
|
} |
|
|
|
@@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator { |
|
|
|
auto &node = *iter; |
|
|
|
// We only try to erase argument link assign nodes, |
|
|
|
// other assign nodes are skipped. |
|
|
|
if (IsLinkAssign(node)) { |
|
|
|
if (IsOptimizableAssign(node)) { |
|
|
|
auto &target = node->inputs().at(kAssignTargetIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(target); |
|
|
|
auto para = param_write_times.find(target); |
|
|
|
@@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator { |
|
|
|
return param_write_times; |
|
|
|
} |
|
|
|
|
|
|
|
// Check if a node is an assign for argument link. |
|
|
|
bool IsLinkAssign(const AnfNodePtr &node) { |
|
|
|
// Check if a node is an assign for argument link and can be optimized. |
|
|
|
bool IsOptimizableAssign(const AnfNodePtr &node) { |
|
|
|
auto cnode = dyn_cast<CNode>(node); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return false; |
|
|
|
@@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator { |
|
|
|
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return prim->GetAttr(LINK) == prim::kValueOne; |
|
|
|
return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne); |
|
|
|
} |
|
|
|
|
|
|
|
// Erase LabelGoto and LabelSet |
|
|
|
|