Browse Source

!13196 [auto-monad] Support multi return points

From: @hwhewei
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
c35d2d0899
1 changed files with 419 additions and 274 deletions
  1. +419
    -274
      mindspore/ccsrc/backend/session/ascend_auto_monad.cc

+ 419
- 274
mindspore/ccsrc/backend/session/ascend_auto_monad.cc View File

@@ -25,6 +25,7 @@
#include <memory>
#include <algorithm>
#include "utils/ms_context.h"
#include "utils/ordered_map.h"
#include "base/core_ops.h"
#include "debug/anf_ir_dump.h"
#include "pipeline/jit/base.h"
@@ -118,6 +119,53 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) {
fout.close();
}

// Look up the entry label id stored on a kernel graph.
// The id is read from the kAttrLabelIndex attribute; kNoLabel is returned
// when that attribute has not been set on the graph.
uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
  auto label_attr = kg->get_attr(kAttrLabelIndex);
  if (label_attr != nullptr) {
    return GetValue<uint32_t>(label_attr);
  }
  return kNoLabel;
}

// One possible target of a call site: the callee graph together with the
// argument nodes passed to it.
// Member order matters: the struct is brace-initialized and destructured
// positionally (graph first, args second) elsewhere in this file.
struct CallBranch {
// The callee kernel graph. GetSwitchBranch() yields a null graph when the
// switch input is not a CNode, so consumers must check for null.
KernelGraphPtr graph;
// Argument nodes supplied to the callee graph.
std::vector<AnfNodePtr> args;
};

// Describes one Call/Switch/SwitchLayer node found in a kernel graph,
// together with the information collected for converting it.
struct CallSite {
// The Call/Switch/SwitchLayer cnode itself.
CNodePtr cnode;

// The last node with a UMonad abstract seen before this call in
// topological order; null if none was seen.
AnfNodePtr last_monad = nullptr;

// The branch graphs that may be called from this site
// (one entry for Call, one per branch for Switch/SwitchLayer).
std::vector<CallBranch> callees;

// Parameter that receives the return value; created on demand when a
// call-return relation is found, null otherwise.
AnfNodePtr out_param = nullptr;

// Label id of the return point; kNoLabel until a call-return relation
// is found for this site.
uint32_t return_label = kNoLabel;

// Maps a callee's label parameter to the index value this call site
// must assign to it, used when a callee has multiple return points.
std::map<AnfNodePtr, uint32_t> label_indexes;

// True if this is a tail call (the call is the real output of its graph).
bool tail = false;
};

// A place a graph may return to, identified by the call site that
// expects the return.
struct ReturnPoint {
// Non-owning pointer to the call site to return to; it points into the
// call_sites vector of the caller's CallInfo.
CallSite *call_site = nullptr;
};

// Call information collected for one kernel graph.
struct CallInfo {
// Call sites found inside this graph.
std::vector<CallSite> call_sites;
// Return points of this graph, i.e. the call sites it may return to.
std::vector<ReturnPoint> return_points;
// Parameter holding the label index that selects the return point;
// only created once the graph has more than one return point.
AnfNodePtr label_param = nullptr;
};

class BaseContext {
public:
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
@@ -126,12 +174,14 @@ class BaseContext {

const std::set<KernelGraphPtr> &visited_graphs() const { return visited_graphs_; }

void ClearVisited() { visited_graphs_.clear(); }

private:
std::set<KernelGraphPtr> visited_graphs_;
};

//
// AscendAutoMonadContext holds some shared states during auto-moand.
// AscendAutoMonadContext holds some shared states during auto-monad.
//
class AscendAutoMonadContext : public BaseContext {
public:
@@ -144,30 +194,20 @@ class AscendAutoMonadContext : public BaseContext {
// Current label id, also the number of label ids we currently used.
uint32_t CurrentLabel() const { return label_id_; }

// Create or get a parameter for output of the kernel graph.
AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) {
// Find output parameter by kernel graph.
auto iter = kg_out_param_.find(kg);
if (iter != kg_out_param_.end()) {
// Return output parameter if found.
return iter->second;
}
// Create a new one if not found.
// Output parameters are all created on top graph.
auto para = top_graph_->NewParameter(kg->output()->abstract());
// Create a new parameter.
// Output parameters are all created on top graph.
AnfNodePtr CreateParameter(const AbstractBasePtr &abs) {
auto para = top_graph_->NewParameter(abs);
auto out_para = top_graph_->TransTupleToMakeTuple(para);
// This is required, so that device memory can be allocated for it.
top_graph_->AddChildGraphResult(out_para);
// Save new para as the output parameter of the kg.
kg_out_param_.emplace(kg, out_para);
return out_para;
}

// Set output parameter for a kernel graph.
void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) {
// Save new para as the output parameter of the kg.
kg_out_param_.emplace(kg, out_para);
}
const KernelGraphPtr &TopGraph() const { return top_graph_; }

// Map kernel_graph to its call info.
OrderedMap<KernelGraphPtr, CallInfo> call_info_map;

private:
// The top graph.
@@ -181,49 +221,34 @@ class AscendAutoMonadContext : public BaseContext {
};

//
// AscendAutoMonadConverter convert control flow to monad form
// for a kernel graph and its children graphs recursively.
// Call info finder finds graph call information.
//
class AscendAutoMonadConverter {
class CallInfoFinder {
public:
AscendAutoMonadConverter(AscendAutoMonadContext *context, const KernelGraphPtr &kg)
: context_(*context), kernel_graph_(kg) {}
static void Run(AscendAutoMonadContext *context) {
CallInfoFinder finder(context->TopGraph(), context);
finder.Run();
}

~AscendAutoMonadConverter() = default;
private:
CallInfoFinder(const KernelGraphPtr &kg, AscendAutoMonadContext *context) : kernel_graph_(kg), context_(*context) {}
~CallInfoFinder() = default;

void Run() {
// Skip if graph already visited.
if (context_.IsVisited(kernel_graph_)) {
FindCallSites();
FindCallReturns();
}

// Find all call sites.
void FindCallSites() {
auto call_info = CreateCallInfo();
if (call_info == nullptr) {
// Skip if call_info for this graph already existed.
return;
}
context_.MarkVisited(kernel_graph_);

// Update directly called sub-graphs.
kernel_graph_->UpdateChildGraphOrder();

Prepare();

// Setup entry label if needed.
auto entry_label = GetGraphLabel(kernel_graph_);
if (entry_label != kNoLabel) {
SetupEntryLabel(entry_label);
}

// Handle call and switch nodes.
HandleCallSwitch();

// Let output depend on monad.
if (monad_) {
MakeMonadDepend();
}
}

private:
//
// Prepare information for control flow processing.
//
void Prepare() {
recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive);
// Find Call/Switch/SwitchLayer nodes, and make CallSites for them.
AnfNodePtr last_monad = nullptr;
auto nodes = TopoSort(kernel_graph_->output());
for (auto &node : nodes) {
@@ -231,243 +256,387 @@ class AscendAutoMonadConverter {
if (HasAbstractUMonad(node)) {
// Found a node with UMonad abstract, set it as the last monad.
last_monad = node;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
MakeCallSite(node->cast<CNodePtr>(), last_monad, call_info);
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
MakeSwitchCallSite(node->cast<CNodePtr>(), last_monad, call_info);
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
if (cnode->size() < 1) {
MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl;
}
// Set the last call as tail call if it is the output node.
// We don't set tail call for top graph because return is always required.
if (kernel_graph_ != context_.TopGraph() && !call_info->call_sites.empty()) {
auto real_output = GetRealNode(kernel_graph_->output());
if (real_output == call_info->call_sites.back().cnode) {
call_info->call_sites.back().tail = true;
}
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
// Found call/switch/switchlayer node, set it as the tail call node.
tail_call_node_ = cnode;
call_switch_nodes_.emplace_back(cnode);
monad_map_.emplace(cnode, last_monad);
} else if (tail_call_node_ != nullptr && AnfAlgo::IsRealKernel(cnode)) {
// Set no tail call if we found real kernel cnode after call/switch.
tail_call_node_ = nullptr;
}
// Recursively find CallSites from sub-graphs.
for (auto &call_site : call_info->call_sites) {
for (auto &callee : call_site.callees) {
CallInfoFinder finder(callee.graph, &context_);
finder.FindCallSites();
}
}
}

//
// Handle call and switch node, return true if tail call found.
//
void HandleCallSwitch() {
// Handle call switch nodes.
for (auto &cnode : call_switch_nodes_) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
HandleCall(cnode);
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
HandleSwitch(cnode);
} else {
MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
// Find call-return pairs.
void FindCallReturns() {
for (auto &entry : context_.call_info_map) {
auto &caller = entry.first;
auto &call_info = entry.second;
for (auto &call_site : call_info.call_sites) {
for (auto &callee : call_site.callees) {
MakeGraphLabel(callee.graph);
}
if (!call_site.tail) {
SearchCallReturns(caller, &call_site);
}
}
}
// If no tail call, assign output value to output parameter,
// and then goto the return label if set.
if (tail_call_node_ == nullptr || recursive_) {
if (output_parameter_) {
auto assign_output = AssignAll(output_parameter_, kernel_graph_->output());
monad_ = UpdateState(GetMonad(), assign_output);
}

// Make sure the given graph owns an entry label.
// A graph that already carries a label attribute is left untouched;
// otherwise a fresh label id is taken from the context and stored on it.
void MakeGraphLabel(const KernelGraphPtr &kg) {
  if (GetGraphLabel(kg) != kNoLabel) {
    // Entry label already assigned.
    return;
  }
  const uint32_t new_label = context_.NewLabel();
  kg->set_attr(kAttrLabelIndex, MakeValue(new_label));
}

// Search return points for all non-tail calls.
void SearchCallReturns(const KernelGraphPtr &caller, CallSite *call_site) {
std::set<KernelGraphPtr> visited = {caller};
std::queue<CallSite *> call_sites;
call_sites.push(call_site);
while (!call_sites.empty()) {
auto site = call_sites.front();
call_sites.pop();
for (auto &callee : site->callees) {
auto &kg = callee.graph;
if (visited.find(kg) != visited.end()) {
// Skip visited graphs.
continue;
}
// Mark visited.
visited.emplace(kg);
// Check callee.
auto &call_info = context_.call_info_map[kg];
auto &sites = call_info.call_sites;
if (!sites.empty() && sites.back().tail) {
// Follow tail call.
call_sites.push(&sites.back());
} else {
// Find a call-return relation.
HandleCallReturn(caller, call_site, kg);
}
}
if (return_label_ != kNoLabel) {
// Insert label_goto for return.
auto return_goto = LabelGoto(return_label_);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
kernel_graph_->set_end_goto(return_goto);
}
}

// Handle a call-return relation.
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
// Create a label for the return point.
if (call_site->return_label == kNoLabel) {
call_site->return_label = context_.NewLabel();
}
// Create a parameter for the return value.
if (call_site->out_param == nullptr) {
call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
}
// Add a return point for the callee graph.
auto &call_info = context_.call_info_map[callee];
auto &return_point = call_info.return_points.emplace_back();
return_point.call_site = call_site;

// Setup label index if there are multi return points.
const auto n_return_points = call_info.return_points.size();
if (n_return_points > 1) {
if (n_return_points == 2) {
// Create a parameter to store label index.
const ShapeVector shape = {1};
auto abs = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
call_info.label_param = context_.CreateParameter(abs);
// Add label index for the first call site.
call_info.return_points.front().call_site->label_indexes.emplace(call_info.label_param, 0);
}
// Add label index for the current call site.
auto label_index = static_cast<uint32_t>(call_info.return_points.size() - 1);
call_site->label_indexes.emplace(call_info.label_param, label_index);
}
}

//
// Convert call node:
// out = Call(graph, arg)
// to:
// r = link_args(graph.para, arg, c)
// c = UpdateState(c, r)
// c = LabelGoto(c) : L1
//
void HandleCall(const CNodePtr &cnode) {
// Update last_monad_.
last_monad_ = monad_map_[cnode];
// Add a CallInfo entry for the current kernel graph to the context.
// Returns a pointer to the new entry, or nullptr when the graph already
// had one.
CallInfo *CreateCallInfo() {
  auto [pos, inserted] = context_.call_info_map.add(kernel_graph_);
  return inserted ? &(pos->second) : nullptr;
}

// The callee graph.
auto graph = GetCallGraph(cnode);
MS_EXCEPTION_IF_NULL(graph);
// Append a CallSite describing a Call node to `call_info`.
// A Call node has exactly one callee branch.
void MakeCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
  auto &sites = call_info->call_sites;
  sites.emplace_back();
  auto &site = sites.back();
  site.cnode = cnode;
  site.last_monad = last_monad;
  site.callees.emplace_back(GetCallBranch(cnode));
}

// Link arguments for the sub-graph.
// Append a CallSite describing a Switch/SwitchLayer node to `call_info`.
// All branches of the node become callees of the call site.
void MakeSwitchCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
  auto &sites = call_info->call_sites;
  sites.emplace_back();
  auto &site = sites.back();
  site.cnode = cnode;
  site.last_monad = last_monad;
  site.callees = GetSwitchBranches(cnode);
}

CallBranch GetCallBranch(const CNodePtr &cnode) {
auto input_graph = cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(input_graph);
auto kg = GetValueNode<KernelGraphPtr>(input_graph);
MS_EXCEPTION_IF_NULL(kg);
constexpr size_t call_arg_index = 2;
auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> args(inputs.begin() + call_arg_index, inputs.end());
auto linked_args = LinkArguments(args, graph);
if (linked_args != nullptr) {
monad_ = UpdateState(GetMonad(), linked_args);
std::vector<AnfNodePtr> args{inputs.begin() + call_arg_index, inputs.end()};
return {.graph = kg, .args = std::move(args)};
}

std::vector<CallBranch> GetSwitchBranches(const CNodePtr &cnode) {
constexpr size_t cond_start_index = 2;
std::vector<CallBranch> branches;
for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
branches.emplace_back(GetSwitchBranch(cnode, index));
}
return branches;
}

// Goto sub-graph label.
uint32_t graph_label = GetOrCreateGraphLabel(graph);
auto goto_node = LabelGoto(graph_label);
// Extract the branch at input `index` of a Switch/SwitchLayer node.
// The input is expected to be a Partial(graph, args...) cnode. A branch
// with a null graph is returned when the input is not a CNode; an
// exception is raised when it is a CNode other than Partial.
CallBranch GetSwitchBranch(const CNodePtr &cnode, size_t index) {
  auto partial = dyn_cast<CNode>(cnode->input(index));
  if (partial == nullptr) {
    return {nullptr, {}};
  }
  auto &partial_inputs = partial->inputs();
  if (!IsPrimitive(partial_inputs.at(0), prim::kPrimPartial)) {
    MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
  }
  // Partial inputs: [0] primitive, [1] graph, [2...] arguments.
  constexpr size_t first_arg_index = 2;
  CallBranch branch;
  branch.graph = GetValueNode<KernelGraphPtr>(partial_inputs.at(1));
  branch.args.assign(partial_inputs.begin() + first_arg_index, partial_inputs.end());
  return branch;
}

// Set child graph attribute, so that subsequence steps such
// as 'select kernel' can handle sub graphs.
SetChildGrapAttr(goto_node, {graph});
// Strip any chain of Depend nodes and return the underlying node,
// following input(1) of each Depend cnode.
static AnfNodePtr GetRealNode(const AnfNodePtr &node) {
  AnfNodePtr current = node;
  while (IsPrimitiveCNode(current, prim::kPrimDepend)) {
    current = current->cast<CNodePtr>()->input(1);
  }
  return current;
}

// Setup return label if this is not a tail call or it is a recursive call.
const bool is_tail_call = (cnode == tail_call_node_);
const bool need_return = (!is_tail_call || recursive_);
if (!need_return) {
// Set as end_goto if no return required.
kernel_graph_->set_end_goto(goto_node);
private:
const KernelGraphPtr &kernel_graph_;
AscendAutoMonadContext &context_;
};

//
// AscendAutoMonadConverter convert control flow to monad form
// for a kernel graph and its children graphs recursively.
//
class AscendAutoMonadConverter {
public:
static void Run(AscendAutoMonadContext *context) {
for (auto &entry : context->call_info_map) {
AscendAutoMonadConverter converter(entry.first, context, &entry.second);
converter.Run();
}
auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return);
}

private:
AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info)
: kernel_graph_(kg), context_(*context), call_info_(*call_info) {}
~AscendAutoMonadConverter() = default;

void Run() {
// Setup entry label if found.
SetupEntryLabel();

// Handle sub-graph recursively.
HandleSubGraph(graph, output_para, return_label);
// Handle call sites.
for (auto &call_site : call_info_.call_sites) {
HandleCallSite(call_site);
}
// Handle return points.
HandleReturnPoints();
// Let output depend on monad.
if (monad_) {
MakeMonadDepend();
}
}

//
// Convert switch/switchlayer node:
// branch1 = Partial(graph1, arg)
// branch2 = Partial(graph2, arg)
// out = Switch/SwitchLayer(cond/index, branch1, branch2)
// to:
// r = link_args(graph1, arg)
// c = UpdateState(c, r)
// r = link_args(graph2, arg)
// c = UpdateState(c, r)
// c = LabelSwitch(cond/index, c) : L1, L2
// c = LabelSet(c) : <return label>
//
void HandleSwitch(const CNodePtr &cnode) {
void HandleCallSite(const CallSite &call_site) {
// Update last_monad_.
last_monad_ = monad_map_[cnode];
last_monad_ = call_site.last_monad;

// Get branches of the switch or switchlayer, true or 0 branch first.
auto branches = GetSwitchBranches(cnode);
// The call/switch/switch_layer cnode.
auto &cnode = call_site.cnode;

// Link arguments and generate labels for branches.
// Get branches of the call_site.
// for call, there is one branch;
// for switch, the first one is true branch;
// for switch_layer, the first one is 0 branch.
auto &branches = call_site.callees;

// Link arguments and find labels for branches.
std::vector<KernelGraphPtr> graphes;
std::vector<uint32_t> labels;
graphes.reserve(branches.size());
labels.reserve(graphes.size());
labels.reserve(branches.size());
for (auto &[graph, args] : branches) {
if (graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
}
MS_EXCEPTION_IF_NULL(graph);
auto linked_args = LinkArguments(args, graph);
if (linked_args != nullptr) {
monad_ = UpdateState(GetMonad(), linked_args);
}
graphes.push_back(graph);
labels.push_back(GetOrCreateGraphLabel(graph));
labels.push_back(GetGraphLabel(graph));
}

const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch);
if (is_switch) {
// For Switch, we reverse the graphes and labels, so that the false branch
// is the first one, since for kernel LabelSwitch, false is the first branch.
// Assign label indexes if required.
AssignLabelIndexes(call_site);

// For Switch, we reverse the graphes and labels, so that the false branch
// is the first one, since for kernel LabelSwitch, false is the first branch.
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
std::reverse(graphes.begin(), graphes.end());
std::reverse(labels.begin(), labels.end());
}

// Add LabelSwith node.
auto switch_node = LabelSwitch(cnode->input(1), labels);
// Create LabelGoto or LabelSwitch node.
auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);

// Set child graph attribute for switch node.
SetChildGrapAttr(switch_node, graphes);

if (!is_switch) {
// Mark the switch node is for 'switch_layer'.
AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node);
// Setup return label and output if required.
if (call_site.return_label != kNoLabel) {
auto label_node = LabelSet(call_site.return_label);
AnfNodePtr output = call_site.out_param;
MS_EXCEPTION_IF_NULL(output);
// Let output depend on the label node, this ensures the
// return label is set before output is used.
output = MakeDepend(output, label_node);
// Replace the the call/switch node with the output.
ReplaceNode(cnode, output);
return;
}

// Setup return label if required.
const bool is_tail_call = (cnode == tail_call_node_);
const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_);
auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return);

// Handle sub-graphs recursively.
for (auto &graph : graphes) {
HandleSubGraph(graph, output_para, return_label);
// If no return label required, it should be a tail call.
if (!call_site.tail) {
MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString();
}
// For tail calls, replace origin call node with label_goto/label_switch.
ReplaceNode(cnode, label_goto_switch);
kernel_graph_->set_end_goto(label_goto_switch);
}

AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) {
const bool is_tail_call = (cnode == tail_call_node_);
if (is_tail_call && output_parameter_ != nullptr) {
return output_parameter_;
// Assign label indexes to label parameters for a call site.
void AssignLabelIndexes(const CallSite &call_site) {
for (auto &[label_param, label_index] : call_site.label_indexes) {
auto index_value = GetIndexValueNode(label_index);
auto assign = Assign(label_param, index_value);
monad_ = UpdateState(GetMonad(), assign);
}
return context_.GetOutputParameter(branches.front());
}

// Make return part of a call for the LabelGoto/LabelSwitch node.
std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches,
bool need_return) {
// Prepare return label.
uint32_t return_label = return_label_;
// Prepare output parameter.
auto output_para = GetOutputParameter(cnode, branches);
// Use same output parameter for all branches.
for (auto &branch : branches) {
context_.SetOutputParameter(branch, output_para);
}
auto output = output_para;
// Setup return label if return is required.
if (need_return) {
// Set a new label at return point.
return_label = context_.NewLabel();
auto label_node = LabelSet(return_label);
// Let output depend on the label node, this ensures the
// return label is set before output is used.
output = MakeDepend(output, label_node);
// Create or reuse ValueNode for the index.
ValueNodePtr GetIndexValueNode(uint32_t index) {
auto iter = index_nodes_.find(index);
if (iter != index_nodes_.end()) {
// Reuse ValueNode for same index.
return iter->second;
}

// Replace the the call/switch node with the output.
kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
return {output_para, return_label};
// Create a new ValueNode on top graph for the index.
auto &top_graph = context_.TopGraph();
std::vector<int64_t> data = {static_cast<int64_t>(index)};
auto tensor = std::make_shared<tensor::Tensor>(data, kInt32);
auto value_node = top_graph->NewValueNode(tensor->ToAbstract(), tensor);
top_graph->AddValueNodeToGraph(value_node);
index_nodes_.emplace(index, value_node);
return value_node;
}

// Handle sub-graphs recursively.
void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) {
AscendAutoMonadConverter converter(&context_, graph);
converter.output_parameter_ = out_para;
converter.return_label_ = return_label;
converter.Run();
// Substitute `new_node` for `old_node` in the current kernel graph.
// The recorded call arguments of every call site in this graph are
// patched too, so later argument linking uses the replacement node.
void ReplaceNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  kernel_graph_->ReplaceNode(NOT_NULL(old_node), NOT_NULL(new_node));
  for (auto &site : call_info_.call_sites) {
    for (auto &branch : site.callees) {
      for (auto &arg : branch.args) {
        if (arg == old_node) {
          arg = new_node;
        }
      }
    }
  }
}

KernelGraphPtr GetCallGraph(const CNodePtr &cnode) {
auto input_graph = cnode->input(kCallKernelGraphIndex);
MS_EXCEPTION_IF_NULL(input_graph);
return GetValueNode<KernelGraphPtr>(input_graph);
// Build the jump node for a Call/Switch/SwitchLayer node:
// a LabelGoto to the first label for Call, otherwise a LabelSwitch over
// all labels driven by input(1) of the cnode.
CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &graphes,
                             const std::vector<uint32_t> &labels) {
  CNodePtr jump_node;
  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
    jump_node = LabelGoto(labels.front());
  } else {
    jump_node = LabelSwitch(cnode->input(1), labels);
  }

  // Attach the target graphs as child graph attribute of the jump node.
  SetChildGrapAttr(jump_node, graphes);

  // Tag jump nodes that originate from a 'switch_layer'.
  const bool is_switch_layer = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer);
  if (is_switch_layer) {
    AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, jump_node);
  }
  return jump_node;
}

GraphArgPair GetSwitchBranch(const CNodePtr &cnode, size_t index) {
auto partial_cnode = dyn_cast<CNode>(cnode->input(index));
if (partial_cnode == nullptr) {
return {nullptr, {}};
//
// Handle return points.
// use label_goto for single return point;
// use label_switch for multi return points.
//
void HandleReturnPoints() {
auto &return_points = call_info_.return_points;
// No return points.
if (return_points.empty()) {
return;
}
auto &inputs = partial_cnode->inputs();
if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) {
MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
// Single return point.
if (return_points.size() == 1) {
// Insert Assign for output parameter.
auto &return_point = return_points.front();
AssignOutput(return_point);
// Insert label_goto for return.
auto return_goto = LabelGoto(return_point.call_site->return_label);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
kernel_graph_->set_end_goto(return_goto);
return;
}
auto graph = GetValueNode<KernelGraphPtr>(inputs.at(1));
constexpr size_t arg_index = 2;
return {graph, {inputs.begin() + arg_index, inputs.end()}};
// Multi return points.
std::vector<uint32_t> return_labels;
return_labels.reserve(return_points.size());
for (auto &return_point : return_points) {
// Assign output to out_params of each return point.
AssignOutput(return_point);
// Get return labels.
return_labels.emplace_back(return_point.call_site->return_label);
}
// Insert label_switch for multi return points.
auto &label_param = call_info_.label_param;
MS_EXCEPTION_IF_NULL(label_param);
auto return_switch = LabelSwitch(label_param, return_labels);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
kernel_graph_->set_end_goto(return_switch);
}

std::vector<GraphArgPair> GetSwitchBranches(const CNodePtr &cnode) {
constexpr size_t cond_start_index = 2;
// switch branches
std::vector<GraphArgPair> switch_branches;
for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
switch_branches.emplace_back(GetSwitchBranch(cnode, index));
}
return switch_branches;
// Copy the output of the current graph into the output parameter of the
// call site behind `return_point`, and chain the assignment onto the
// control-flow monad.
void AssignOutput(const ReturnPoint &return_point) {
  auto target_site = return_point.call_site;
  MS_EXCEPTION_IF_NULL(target_site);
  auto output_assign = AssignAll(target_site->out_param, kernel_graph_->output());
  monad_ = UpdateState(GetMonad(), output_assign);
}

//
@@ -572,6 +741,7 @@ class AscendAutoMonadConverter {
return kernel_graph_->NewCNode(tuple_inputs);
}

// Insert UpdateState after input node.
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) {
auto update_state = NewValueNode(prim::kPrimUpdateState);
auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input});
@@ -589,11 +759,14 @@ class AscendAutoMonadConverter {
// c = LabelSet(c) : entry_label
// return add(x, y)
//
void SetupEntryLabel(uint32_t entry_label) {
// Set entry label.
auto label_node = LabelSet(entry_label);
// Make start label the first one in execution order.
kernel_graph_->set_start_label(label_node);
void SetupEntryLabel() {
  const auto entry_label = GetGraphLabel(kernel_graph_);
  if (entry_label == kNoLabel) {
    // Graph has no entry label, nothing to set up.
    return;
  }
  // Create the LabelSet node for the entry label and register it as the
  // start label, so it comes first in execution order.
  auto entry_label_node = LabelSet(entry_label);
  kernel_graph_->set_start_label(entry_label_node);
}

// Make a Depend cnode.
@@ -609,8 +782,10 @@ class AscendAutoMonadConverter {
auto monad = GetMonad();
auto origin_output = kernel_graph_->output();
MS_EXCEPTION_IF_NULL(origin_output);
auto depend_cnode = MakeDepend(origin_output, monad);
kernel_graph_->set_output(depend_cnode);
if (origin_output != monad) {
auto depend_cnode = MakeDepend(origin_output, monad);
kernel_graph_->set_output(depend_cnode);
}
}

// Gets the last monad node, we use a separated UMonad for control flow.
@@ -665,42 +840,17 @@ class AscendAutoMonadConverter {
return cnode;
}

// Return kNoLabel when label id attribute not set for the graph.
uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
auto value = kg->get_attr(kAttrLabelIndex);
if (value == nullptr) {
return kNoLabel;
}
return GetValue<uint32_t>(value);
}

// Get or create entry label for the given graph.
uint32_t GetOrCreateGraphLabel(const KernelGraphPtr &kg) {
auto label = GetGraphLabel(kg);
if (label == kNoLabel) {
// Allocate a new label id and save it to the graph.
label = context_.NewLabel();
kg->set_attr(kAttrLabelIndex, MakeValue(label));
}
return label;
}

// Store the given graphs as the kAttrChildGraph attribute of a
// label_goto/label_switch node.
void SetChildGrapAttr(const AnfNodePtr &node, const std::vector<KernelGraphPtr> &graphs) {
  const auto child_graphs_value = MakeValue(graphs);
  AnfAlgo::SetNodeAttr(kAttrChildGraph, child_graphs_value, node);
}

private:
const KernelGraphPtr &kernel_graph_;
AscendAutoMonadContext &context_;
const KernelGraphPtr kernel_graph_;

// Tail call node, null if not found.
CNodePtr tail_call_node_;

// Call/Switch nodes.
std::vector<CNodePtr> call_switch_nodes_;

// Call/Switch node to monad map.
std::map<CNodePtr, AnfNodePtr> monad_map_;
// Call info for current kernel graph.
CallInfo &call_info_;

// The last monad for Call/Switch node.
AnfNodePtr last_monad_;
@@ -711,14 +861,8 @@ class AscendAutoMonadConverter {
// The control flow monad const value node.
AnfNodePtr monad_value_;

// Parameter to store the return value.
AnfNodePtr output_parameter_;

// The return label id.
uint32_t return_label_ = kNoLabel;

// Is this graph include recursive calls.
bool recursive_ = false;
// Index value node cache for reuse.
std::map<uint32_t, ValueNodePtr> index_nodes_;
};

constexpr size_t kAssignTargetIndex = 1;
@@ -985,9 +1129,10 @@ class ExecuteOrderGenerator {

void AscendAutoMonad::Run() {
MS_LOG(DEBUG) << "Ascend auto-monad start.";
AscendAutoMonadContext context(kernel_graph_.get());
AscendAutoMonadConverter converter(&context, kernel_graph_.get());
converter.Run();
auto kg = kernel_graph_.get();
AscendAutoMonadContext context(kg);
CallInfoFinder::Run(&context);
AscendAutoMonadConverter::Run(&context);
kernel_graph_->set_label_num(context.CurrentLabel());
MS_LOG(DEBUG) << "Ascend auto-monad finish.";
DumpGraphForDebug(kernel_graph_);


Loading…
Cancel
Save