@@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0;
// Primitive attribute for argument link assign.
// Primitive attribute for argument link assign.
const char LINK[] = "link";
const char LINK[] = "link";
// Attribute to indicate that the node should not be eliminated.
// Used to keep argument Assign nodes for recursive graphs.
const char KEEP[] = "keep";
// Attribute to indicate that this is an assign for output.
const char OUTPUT[] = "output";
bool IsSaveGraph() {
bool IsSaveGraph() {
auto context_ptr = MsContext::GetInstance();
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
MS_EXCEPTION_IF_NULL(context_ptr);
@@ -152,6 +159,9 @@ struct CallSite {
// Label param to index map.
// Label param to index map.
std::map<AnfNodePtr, uint32_t> label_indexes;
std::map<AnfNodePtr, uint32_t> label_indexes;
// True if this is a recursive call.
bool recursive = false;
// True if this is a tail call.
// True if this is a tail call.
bool tail = false;
bool tail = false;
};
};
@@ -161,9 +171,18 @@ struct ReturnPoint {
};
};
struct CallInfo {
struct CallInfo {
// Call sites in current graph.
std::vector<CallSite> call_sites;
std::vector<CallSite> call_sites;
// Return points of current graph.
std::vector<ReturnPoint> return_points;
std::vector<ReturnPoint> return_points;
// Parameter to store label index, if there are
// multi return points, this should be set.
AnfNodePtr label_param = nullptr;
AnfNodePtr label_param = nullptr;
// True if current graph is involved with recursive calls.
bool recursive = false;
};
};
//
//
@@ -296,6 +315,8 @@ class CallInfoFinder {
void Run() {
void Run() {
FindCallSites();
FindCallSites();
FindRecursiveCalls();
DisableTailCalls();
FindCallReturns();
FindCallReturns();
}
}
@@ -340,11 +361,30 @@ class CallInfoFinder {
}
}
}
}
// Find call-return pairs.
void FindCallReturns() {
// Find recursive non-tail calls.
void FindRecursiveCalls() {
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) {
if (!call_site.tail) {
SearchRecursiveCall(caller, &call_site);
}
}
}
}
// Disable tail call optimization for recursive call graphs.
void DisableTailCalls() {
for (auto &entry : context_.call_info_map) {
for (auto &entry : context_.call_info_map) {
auto &caller = entry.first;
auto &call_info = entry.second;
auto &call_info = entry.second;
if (call_info.recursive && !call_info.call_sites.empty()) {
call_info.call_sites.back().tail = false;
}
}
}
// Find call-return pairs.
void FindCallReturns() {
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) {
for (auto &call_site : call_info.call_sites) {
for (auto &callee : call_site.callees) {
for (auto &callee : call_site.callees) {
MakeGraphLabel(callee.graph);
MakeGraphLabel(callee.graph);
@@ -396,6 +436,54 @@ class CallInfoFinder {
}
}
}
}
struct SearchRecursiveContext {
const KernelGraphPtr &start_caller;
CallSite *start_site;
std::set<KernelGraphPtr> visited;
std::vector<KernelGraphPtr> call_path;
};
// Search recursive call from a call-site.
void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
DoSearchRecursiveCall(start_caller, start_site, &context);
}
void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) {
// Record call path.
ctx->call_path.push_back(graph);
// Handle callee graphs.
for (auto &callee : call_site->callees) {
auto &sub_graph = callee.graph;
if (sub_graph == ctx->start_caller) {
// Find a recursive call path.
for (auto &g : ctx->call_path) {
// Mark recursive for all graphs in call path.
context_.call_info_map[g].recursive = true;
}
// Mark recursive for the start call-site.
ctx->start_site->recursive = true;
continue;
}
if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
// Skip visited graphs.
continue;
}
// Mark visited.
ctx->visited.emplace(sub_graph);
// Check call sites in the sub-graph.
auto &call_info = context_.call_info_map[sub_graph];
auto &sites = call_info.call_sites;
for (auto &site : sites) {
if (!site.callees.empty()) {
DoSearchRecursiveCall(sub_graph, &site, ctx);
}
}
}
// Don't forget this.
ctx->call_path.pop_back();
}
// Handle a call-return relation.
// Handle a call-return relation.
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
// Create a label for the return point.
// Create a label for the return point.
@@ -590,7 +678,7 @@ class AscendAutoMonadConverter {
// For multi-return call, assign result from temp parameter to
// For multi-return call, assign result from temp parameter to
// output parameter, this prevent result be overwritten by next call.
// output parameter, this prevent result be overwritten by next call.
auto tmp_param = context_.GetTempParameter(output->abstract());
auto tmp_param = context_.GetTempParameter(output->abstract());
output = AssignAll(output, tmp_param);
output = AssignAll(output, tmp_param, false, false, true );
monad_ = UpdateState(GetMonad(), output);
monad_ = UpdateState(GetMonad(), output);
}
}
// Replace the the call/switch node with the output.
// Replace the the call/switch node with the output.
@@ -611,7 +699,7 @@ class AscendAutoMonadConverter {
void AssignLabelIndexes(const CallSite &call_site) {
void AssignLabelIndexes(const CallSite &call_site) {
for (auto &[label_param, label_index] : call_site.label_indexes) {
for (auto &[label_param, label_index] : call_site.label_indexes) {
auto index_value = GetIndexValueNode(label_index);
auto index_value = GetIndexValueNode(label_index);
auto assign = Assign(label_param, index_value);
auto assign = Assign(label_param, index_value, false, false, false );
monad_ = UpdateState(GetMonad(), assign);
monad_ = UpdateState(GetMonad(), assign);
}
}
}
}
@@ -708,7 +796,7 @@ class AscendAutoMonadConverter {
AnfNodePtr out_param =
AnfNodePtr out_param =
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
MS_EXCEPTION_IF_NULL(out_param);
MS_EXCEPTION_IF_NULL(out_param);
auto assign_output = AssignAll(out_param, kernel_graph_->output());
auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true );
monad_ = UpdateState(GetMonad(), assign_output);
monad_ = UpdateState(GetMonad(), assign_output);
}
}
@@ -739,6 +827,8 @@ class AscendAutoMonadConverter {
if (args.empty()) {
if (args.empty()) {
return nullptr;
return nullptr;
}
}
// We do not eliminate argument Assign for recursive graphs.
const bool keep = IsRecursive(graph);
// Single argument.
// Single argument.
if (args.size() == 1) {
if (args.size() == 1) {
auto &value = args.front();
auto &value = args.front();
@@ -746,7 +836,7 @@ class AscendAutoMonadConverter {
// No assign for single monad argument, return it.
// No assign for single monad argument, return it.
return value;
return value;
}
}
return AssignAll(paras.front(), value, true);
return AssignAll(paras.front(), value, true, keep, false );
}
}
// Multi arguments.
// Multi arguments.
AnfNodePtrList tuple_inputs;
AnfNodePtrList tuple_inputs;
@@ -764,11 +854,14 @@ class AscendAutoMonadConverter {
if (target == value) {
if (target == value) {
continue;
continue;
}
}
tuple_inputs.emplace_back(AssignAll(target, value, true));
tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false ));
}
}
return kernel_graph_->NewCNode(tuple_inputs);
return kernel_graph_->NewCNode(tuple_inputs);
}
}
// Return true if the graph is involved with recursive calls.
bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
@@ -780,13 +873,21 @@ class AscendAutoMonadConverter {
}
}
// Make a assign cnode.
// Make a assign cnode.
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false ) {
auto monad = (is_ link ? GetLinkMonad() : GetMonad());
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output ) {
auto monad = (link ? GetLinkMonad() : GetMonad());
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
if (is_ link) {
if (link) {
// Mark this assign is to link real argument to formal argument.
// Mark this assign is to link real argument to formal argument.
assign_prim->set_attr(LINK, prim::kValueOne);
assign_prim->set_attr(LINK, prim::kValueOne);
}
}
if (keep) {
// Mark this assign should not be eliminated.
assign_prim->set_attr(KEEP, prim::kValueOne);
}
if (output) {
// Mark this assign is used for output parameter.
assign_prim->set_attr(OUTPUT, prim::kValueOne);
}
auto assign = NewValueNode(assign_prim);
auto assign = NewValueNode(assign_prim);
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
cnode->set_abstract(target->abstract());
cnode->set_abstract(target->abstract());
@@ -794,10 +895,10 @@ class AscendAutoMonadConverter {
}
}
// AissgnAll support tuple to tuple assign.
// AissgnAll support tuple to tuple assign.
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false ) {
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output ) {
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
// Assign single value.
// Assign single value.
return Assign(target, source, is_ link);
return Assign(target, source, link, keep, output );
}
}
// Assign tuple.
// Assign tuple.
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
@@ -809,7 +910,7 @@ class AscendAutoMonadConverter {
tuple_inputs.reserve(targets.size() + 1);
tuple_inputs.reserve(targets.size() + 1);
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t i = 0; i < targets.size(); ++i) {
for (size_t i = 0; i < targets.size(); ++i) {
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_ link));
tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output ));
}
}
return kernel_graph_->NewCNode(tuple_inputs);
return kernel_graph_->NewCNode(tuple_inputs);
}
}
@@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator {
auto &node = *iter;
auto &node = *iter;
// We only try to erase argument link assign nodes,
// We only try to erase argument link assign nodes,
// other assign nodes are skipped.
// other assign nodes are skipped.
if (IsLink Assign(node)) {
if (IsOptimizable Assign(node)) {
auto &target = node->inputs().at(kAssignTargetIndex);
auto &target = node->inputs().at(kAssignTargetIndex);
MS_EXCEPTION_IF_NULL(target);
MS_EXCEPTION_IF_NULL(target);
auto para = param_write_times.find(target);
auto para = param_write_times.find(target);
@@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator {
return param_write_times;
return param_write_times;
}
}
// Check if a node is an assign for argument link.
bool IsLink Assign(const AnfNodePtr &node) {
// Check if a node is an assign for argument link and can be optimized .
bool IsOptimizable Assign(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
if (cnode == nullptr) {
return false;
return false;
@@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator {
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
return false;
return false;
}
}
return prim->GetAttr(LINK) == prim::kValueOne;
return ( prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne) ;
}
}
// Erase LabelGoto and LabelSet
// Erase LabelGoto and LabelSet