Browse Source

[auto-monad] Prepare to support recursive calls

1. Find and mark recursive calls and graphs;
2. Do not eliminate argument Assign for recursive graphs;
3. Disable tail call optimization for recursive graphs;
4. Add attribute to output parameter Assign to disable stack push.
pull/13600/head
He Wei 4 years ago
parent
commit
dc8ec9d87f
1 changed files with 119 additions and 18 deletions
  1. +119
    -18
      mindspore/ccsrc/backend/session/ascend_auto_monad.cc

+ 119
- 18
mindspore/ccsrc/backend/session/ascend_auto_monad.cc View File

@@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0;
// Primitive attribute for argument link assign. // Primitive attribute for argument link assign.
const char LINK[] = "link"; const char LINK[] = "link";


// Primitive attribute marking an Assign node that must not be eliminated.
// It is set on argument Assign nodes created for graphs involved in
// recursive calls, and IsOptimizableAssign() rejects nodes that carry it.
const char KEEP[] = "keep";

// Primitive attribute marking an Assign node that writes an output parameter.
// NOTE(review): per the commit description this is meant to disable stack push
// for such assigns; the code that reads the attribute is not visible here.
const char OUTPUT[] = "output";

bool IsSaveGraph() { bool IsSaveGraph() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
@@ -152,6 +159,9 @@ struct CallSite {
// Label param to index map. // Label param to index map.
std::map<AnfNodePtr, uint32_t> label_indexes; std::map<AnfNodePtr, uint32_t> label_indexes;


// True if this is a recursive call.
bool recursive = false;

// True if this is a tail call. // True if this is a tail call.
bool tail = false; bool tail = false;
}; };
@@ -161,9 +171,18 @@ struct ReturnPoint {
}; };


struct CallInfo { struct CallInfo {
// Call sites in current graph.
std::vector<CallSite> call_sites; std::vector<CallSite> call_sites;

// Return points of current graph.
std::vector<ReturnPoint> return_points; std::vector<ReturnPoint> return_points;

// Parameter to store label index, if there are
// multi return points, this should be set.
AnfNodePtr label_param = nullptr; AnfNodePtr label_param = nullptr;

// True if current graph is involved with recursive calls.
bool recursive = false;
}; };


// //
@@ -296,6 +315,8 @@ class CallInfoFinder {


void Run() { void Run() {
FindCallSites(); FindCallSites();
FindRecursiveCalls();
DisableTailCalls();
FindCallReturns(); FindCallReturns();
} }


@@ -340,11 +361,30 @@ class CallInfoFinder {
} }
} }


// Find call-return pairs.
void FindCallReturns() {
// Scan every call-site of every known graph and, for each one that is not
// a tail call, search for a call path leading back to its owning graph.
// SearchRecursiveCall() records the result on the call-site and on the
// CallInfo of each graph along the path.
void FindRecursiveCalls() {
  for (auto &[owner_graph, info] : context_.call_info_map) {
    for (auto &site : info.call_sites) {
      // Tail calls are not searched here.
      if (site.tail) {
        continue;
      }
      SearchRecursiveCall(owner_graph, &site);
    }
  }
}

// Disable tail call optimization for recursive call graphs.
void DisableTailCalls() {
for (auto &entry : context_.call_info_map) { for (auto &entry : context_.call_info_map) {
auto &caller = entry.first;
auto &call_info = entry.second; auto &call_info = entry.second;
if (call_info.recursive && !call_info.call_sites.empty()) {
call_info.call_sites.back().tail = false;
}
}
}

// Find call-return pairs.
void FindCallReturns() {
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) { for (auto &call_site : call_info.call_sites) {
for (auto &callee : call_site.callees) { for (auto &callee : call_site.callees) {
MakeGraphLabel(callee.graph); MakeGraphLabel(callee.graph);
@@ -396,6 +436,54 @@ class CallInfoFinder {
} }
} }


// State shared by all steps of one recursive-call search started from a
// single call-site (see SearchRecursiveCall / DoSearchRecursiveCall).
struct SearchRecursiveContext {
// Graph that owns the call-site the search started from. Reaching this graph
// again as a callee means a recursive call path was found. Held by reference,
// so the context must not outlive the pointer it was built from.
const KernelGraphPtr &start_caller;
// Call-site the search started from; flagged recursive when a path is found.
CallSite *start_site;
// Callee graphs already expanded during this search; they are not expanded again.
std::set<KernelGraphPtr> visited;
// Graphs on the call path currently being explored (pushed on entry to a
// search step, popped on exit). All of them are marked recursive when the
// path leads back to start_caller.
std::vector<KernelGraphPtr> call_path;
};

// Search recursive call from a call-site.
void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
DoSearchRecursiveCall(start_caller, start_site, &context);
}

// One depth-first step of the recursive-call search.
//   graph     - graph that owns 'call_site'; it is on the call path while this step runs.
//   call_site - call-site whose callees are examined.
//   ctx       - search state shared across the whole search.
// When a callee is the graph the search started from, every graph currently
// on the call path and the starting call-site are marked recursive.
// NOTE(review): a callee already in ctx->visited is not expanded again, so a
// graph that reaches the start graph only through an already-visited graph is
// not marked by this search.
void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) {
  ctx->call_path.push_back(graph);
  for (auto &callee : call_site->callees) {
    auto &target = callee.graph;
    if (target == ctx->start_caller) {
      // The path closes back on the starting graph: flag everything on it.
      for (auto &on_path : ctx->call_path) {
        context_.call_info_map[on_path].recursive = true;
      }
      ctx->start_site->recursive = true;
    } else if (ctx->visited.find(target) == ctx->visited.end()) {
      // First time this callee is seen in the current search: expand it.
      ctx->visited.emplace(target);
      for (auto &next_site : context_.call_info_map[target].call_sites) {
        if (next_site.callees.empty()) {
          continue;
        }
        DoSearchRecursiveCall(target, &next_site, ctx);
      }
    }
  }
  // Leave the call path exactly as it was on entry.
  ctx->call_path.pop_back();
}

// Handle a call-return relation. // Handle a call-return relation.
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) { void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
// Create a label for the return point. // Create a label for the return point.
@@ -590,7 +678,7 @@ class AscendAutoMonadConverter {
// For multi-return call, assign result from temp parameter to // For multi-return call, assign result from temp parameter to
// output parameter; this prevents the result from being overwritten by the next call. // output parameter; this prevents the result from being overwritten by the next call.
auto tmp_param = context_.GetTempParameter(output->abstract()); auto tmp_param = context_.GetTempParameter(output->abstract());
output = AssignAll(output, tmp_param);
output = AssignAll(output, tmp_param, false, false, true);
monad_ = UpdateState(GetMonad(), output); monad_ = UpdateState(GetMonad(), output);
} }
// Replace the call/switch node with the output. // Replace the call/switch node with the output.
@@ -611,7 +699,7 @@ class AscendAutoMonadConverter {
void AssignLabelIndexes(const CallSite &call_site) { void AssignLabelIndexes(const CallSite &call_site) {
for (auto &[label_param, label_index] : call_site.label_indexes) { for (auto &[label_param, label_index] : call_site.label_indexes) {
auto index_value = GetIndexValueNode(label_index); auto index_value = GetIndexValueNode(label_index);
auto assign = Assign(label_param, index_value);
auto assign = Assign(label_param, index_value, false, false, false);
monad_ = UpdateState(GetMonad(), assign); monad_ = UpdateState(GetMonad(), assign);
} }
} }
@@ -708,7 +796,7 @@ class AscendAutoMonadConverter {
AnfNodePtr out_param = AnfNodePtr out_param =
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract())); (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
MS_EXCEPTION_IF_NULL(out_param); MS_EXCEPTION_IF_NULL(out_param);
auto assign_output = AssignAll(out_param, kernel_graph_->output());
auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
monad_ = UpdateState(GetMonad(), assign_output); monad_ = UpdateState(GetMonad(), assign_output);
} }


@@ -739,6 +827,8 @@ class AscendAutoMonadConverter {
if (args.empty()) { if (args.empty()) {
return nullptr; return nullptr;
} }
// We do not eliminate argument Assign for recursive graphs.
const bool keep = IsRecursive(graph);
// Single argument. // Single argument.
if (args.size() == 1) { if (args.size() == 1) {
auto &value = args.front(); auto &value = args.front();
@@ -746,7 +836,7 @@ class AscendAutoMonadConverter {
// No assign for single monad argument, return it. // No assign for single monad argument, return it.
return value; return value;
} }
return AssignAll(paras.front(), value, true);
return AssignAll(paras.front(), value, true, keep, false);
} }
// Multi arguments. // Multi arguments.
AnfNodePtrList tuple_inputs; AnfNodePtrList tuple_inputs;
@@ -764,11 +854,14 @@ class AscendAutoMonadConverter {
if (target == value) { if (target == value) {
continue; continue;
} }
tuple_inputs.emplace_back(AssignAll(target, value, true));
tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
} }
return kernel_graph_->NewCNode(tuple_inputs); return kernel_graph_->NewCNode(tuple_inputs);
} }


// Tell whether 'kg' was marked as taking part in a recursive call path.
// The lookup goes through operator[] on the call-info map.
bool IsRecursive(const KernelGraphPtr &kg) {
  auto &info = context_.call_info_map[kg];
  return info.recursive;
}

// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }


@@ -780,13 +873,21 @@ class AscendAutoMonadConverter {
} }


// Make an assign cnode. // Make an assign cnode.
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
auto monad = (is_link ? GetLinkMonad() : GetMonad());
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
auto monad = (link ? GetLinkMonad() : GetMonad());
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
if (is_link) {
if (link) {
// Mark this assign is to link real argument to formal argument. // Mark this assign is to link real argument to formal argument.
assign_prim->set_attr(LINK, prim::kValueOne); assign_prim->set_attr(LINK, prim::kValueOne);
} }
if (keep) {
// Mark this assign should not be eliminated.
assign_prim->set_attr(KEEP, prim::kValueOne);
}
if (output) {
// Mark this assign is used for output parameter.
assign_prim->set_attr(OUTPUT, prim::kValueOne);
}
auto assign = NewValueNode(assign_prim); auto assign = NewValueNode(assign_prim);
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad}); auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
cnode->set_abstract(target->abstract()); cnode->set_abstract(target->abstract());
@@ -794,10 +895,10 @@ class AscendAutoMonadConverter {
} }


// AssignAll supports tuple to tuple assign. // AssignAll supports tuple to tuple assign.
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) { if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
// Assign single value. // Assign single value.
return Assign(target, source, is_link);
return Assign(target, source, link, keep, output);
} }
// Assign tuple. // Assign tuple.
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
@@ -809,7 +910,7 @@ class AscendAutoMonadConverter {
tuple_inputs.reserve(targets.size() + 1); tuple_inputs.reserve(targets.size() + 1);
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t i = 0; i < targets.size(); ++i) { for (size_t i = 0; i < targets.size(); ++i) {
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link));
tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
} }
return kernel_graph_->NewCNode(tuple_inputs); return kernel_graph_->NewCNode(tuple_inputs);
} }
@@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator {
auto &node = *iter; auto &node = *iter;
// We only try to erase argument link assign nodes, // We only try to erase argument link assign nodes,
// other assign nodes are skipped. // other assign nodes are skipped.
if (IsLinkAssign(node)) {
if (IsOptimizableAssign(node)) {
auto &target = node->inputs().at(kAssignTargetIndex); auto &target = node->inputs().at(kAssignTargetIndex);
MS_EXCEPTION_IF_NULL(target); MS_EXCEPTION_IF_NULL(target);
auto para = param_write_times.find(target); auto para = param_write_times.find(target);
@@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator {
return param_write_times; return param_write_times;
} }


// Check if a node is an assign for argument link.
bool IsLinkAssign(const AnfNodePtr &node) {
// Check if a node is an assign for argument link and can be optimized.
bool IsOptimizableAssign(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node); auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) { if (cnode == nullptr) {
return false; return false;
@@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator {
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) { if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
return false; return false;
} }
return prim->GetAttr(LINK) == prim::kValueOne;
return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
} }


// Erase LabelGoto and LabelSet // Erase LabelGoto and LabelSet


Loading…
Cancel
Save