|
|
|
@@ -25,6 +25,7 @@ |
|
|
|
#include <memory> |
|
|
|
#include <algorithm> |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "utils/ordered_map.h" |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "debug/anf_ir_dump.h" |
|
|
|
#include "pipeline/jit/base.h" |
|
|
|
@@ -118,6 +119,53 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) { |
|
|
|
fout.close(); |
|
|
|
} |
|
|
|
|
|
|
|
// Return kNoLabel when label id attribute not set for the graph. |
|
|
|
uint32_t GetGraphLabel(const KernelGraphPtr &kg) { |
|
|
|
auto value = kg->get_attr(kAttrLabelIndex); |
|
|
|
if (value == nullptr) { |
|
|
|
return kNoLabel; |
|
|
|
} |
|
|
|
return GetValue<uint32_t>(value); |
|
|
|
} |
|
|
|
|
|
|
|
// One callee of a call site: the called kernel graph plus the argument nodes passed to it. |
// Field order matters: instances are built with positional/designated aggregate initialization. |
struct CallBranch { |
|
|
|
// Kernel graph called by this branch (may be null for an invalid switch input). |
KernelGraphPtr graph; |
|
|
|
// Argument nodes taken from the Call/Partial node inputs. |
std::vector<AnfNodePtr> args; |
|
|
|
}; |
|
|
|
|
|
|
|
// Describes one Call/Switch/SwitchLayer node of a kernel graph, together with |
// the data collected for jumping to its callees and returning from them. |
struct CallSite { |
|
|
|
// The Call/Switch/SwitchLayer node itself. |
|
|
|
CNodePtr cnode; |
|
|
|
|
|
|
|
// The last monad node found before the call; stays null when none was found. |
|
|
|
AnfNodePtr last_monad = nullptr; |
|
|
|
|
|
|
|
// Branch graphs called: a single branch for Call, one per branch for Switch/SwitchLayer. |
|
|
|
std::vector<CallBranch> callees; |
|
|
|
|
|
|
|
// Parameter that receives the return value; null until a call-return relation is recorded. |
|
|
|
AnfNodePtr out_param = nullptr; |
|
|
|
|
|
|
|
// Label id of the return point; kNoLabel until a call-return relation is recorded. |
|
|
|
uint32_t return_label = kNoLabel; |
|
|
|
|
|
|
|
// Maps a callee's label parameter to the index this call site must store in it. |
|
|
|
std::map<AnfNodePtr, uint32_t> label_indexes; |
|
|
|
|
|
|
|
// True if this is a tail call (no return point is searched for tail calls). |
|
|
|
bool tail = false; |
|
|
|
}; |
|
|
|
|
|
|
|
// A place a graph returns to, identified by the call site that execution resumes at. |
struct ReturnPoint { |
|
|
|
// Non-owning pointer to the call site; it refers to an element of some CallInfo::call_sites. |
// NOTE(review): presumably call_sites is no longer resized once return points exist - confirm. |
CallSite *call_site = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
// Call information of one kernel graph: the calls it makes and the places it returns to. |
struct CallInfo { |
|
|
|
// Call sites (Call/Switch/SwitchLayer nodes) found in the graph. |
std::vector<CallSite> call_sites; |
|
|
|
// Return points of the graph, one per recorded call-return relation. |
std::vector<ReturnPoint> return_points; |
|
|
|
// Parameter holding the label index; only created once there is more than one return point. |
AnfNodePtr label_param = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
class BaseContext { |
|
|
|
public: |
|
|
|
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } |
|
|
|
@@ -126,12 +174,14 @@ class BaseContext { |
|
|
|
|
|
|
|
const std::set<KernelGraphPtr> &visited_graphs() const { return visited_graphs_; } |
|
|
|
|
|
|
|
void ClearVisited() { visited_graphs_.clear(); } |
|
|
|
|
|
|
|
private: |
|
|
|
std::set<KernelGraphPtr> visited_graphs_; |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
// AscendAutoMonadContext holds some shared states during auto-moand. |
|
|
|
// AscendAutoMonadContext holds some shared states during auto-monad. |
|
|
|
// |
|
|
|
class AscendAutoMonadContext : public BaseContext { |
|
|
|
public: |
|
|
|
@@ -144,30 +194,20 @@ class AscendAutoMonadContext : public BaseContext { |
|
|
|
// Current label id, also the number of label ids we currently used. |
|
|
|
uint32_t CurrentLabel() const { return label_id_; } |
|
|
|
|
|
|
|
// Create or get a parameter for output of the kernel graph. |
|
|
|
AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) { |
|
|
|
// Find output parameter by kernel graph. |
|
|
|
auto iter = kg_out_param_.find(kg); |
|
|
|
if (iter != kg_out_param_.end()) { |
|
|
|
// Return output parameter if found. |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
// Create a new one if not found. |
|
|
|
// Output parameters are all created on top graph. |
|
|
|
auto para = top_graph_->NewParameter(kg->output()->abstract()); |
|
|
|
// Create a new parameter. |
|
|
|
// Output parameters are all created on top graph. |
|
|
|
AnfNodePtr CreateParameter(const AbstractBasePtr &abs) { |
|
|
|
auto para = top_graph_->NewParameter(abs); |
|
|
|
auto out_para = top_graph_->TransTupleToMakeTuple(para); |
|
|
|
// This is required, so that device memory can be allocated for it. |
|
|
|
top_graph_->AddChildGraphResult(out_para); |
|
|
|
// Save new para as the output parameter of the kg. |
|
|
|
kg_out_param_.emplace(kg, out_para); |
|
|
|
return out_para; |
|
|
|
} |
|
|
|
|
|
|
|
// Set output parameter for a kernel graph. |
// NOTE(review): if kg_out_param_ is a std::map-like container, emplace() keeps an entry that |
// already exists for kg instead of replacing it - confirm that is the intended behaviour. |
|
|
|
void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) { |
|
|
|
// Save new para as the output parameter of the kg. |
|
|
|
kg_out_param_.emplace(kg, out_para); |
|
|
|
} |
|
|
|
const KernelGraphPtr &TopGraph() const { return top_graph_; } |
|
|
|
|
|
|
|
// Map kernel_graph to its call info. |
|
|
|
OrderedMap<KernelGraphPtr, CallInfo> call_info_map; |
|
|
|
|
|
|
|
private: |
|
|
|
// The top graph. |
|
|
|
@@ -181,49 +221,34 @@ class AscendAutoMonadContext : public BaseContext { |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
// AscendAutoMonadConverter convert control flow to monad form |
|
|
|
// for a kernel graph and its children graphs recursively. |
|
|
|
// Call info finder finds graph call information. |
|
|
|
// |
|
|
|
class AscendAutoMonadConverter { |
|
|
|
class CallInfoFinder { |
|
|
|
public: |
|
|
|
AscendAutoMonadConverter(AscendAutoMonadContext *context, const KernelGraphPtr &kg) |
|
|
|
: context_(*context), kernel_graph_(kg) {} |
|
|
|
// Entry point: collect call information starting from the top graph of the context.
static void Run(AscendAutoMonadContext *context) {
  const KernelGraphPtr &top_graph = context->TopGraph();
  CallInfoFinder top_finder(top_graph, context);
  top_finder.Run();
}
|
|
|
|
|
|
|
~AscendAutoMonadConverter() = default; |
|
|
|
private: |
|
|
|
CallInfoFinder(const KernelGraphPtr &kg, AscendAutoMonadContext *context) : kernel_graph_(kg), context_(*context) {} |
|
|
|
~CallInfoFinder() = default; |
|
|
|
|
|
|
|
void Run() { |
|
|
|
// Skip if graph already visited. |
|
|
|
if (context_.IsVisited(kernel_graph_)) { |
|
|
|
FindCallSites(); |
|
|
|
FindCallReturns(); |
|
|
|
} |
|
|
|
|
|
|
|
// Find all call sites. |
|
|
|
void FindCallSites() { |
|
|
|
auto call_info = CreateCallInfo(); |
|
|
|
if (call_info == nullptr) { |
|
|
|
// Skip if call_info for this graph already existed. |
|
|
|
return; |
|
|
|
} |
|
|
|
context_.MarkVisited(kernel_graph_); |
|
|
|
|
|
|
|
// Update directly called sub-graphs. |
|
|
|
kernel_graph_->UpdateChildGraphOrder(); |
|
|
|
|
|
|
|
Prepare(); |
|
|
|
|
|
|
|
// Setup entry label if needed. |
|
|
|
auto entry_label = GetGraphLabel(kernel_graph_); |
|
|
|
if (entry_label != kNoLabel) { |
|
|
|
SetupEntryLabel(entry_label); |
|
|
|
} |
|
|
|
|
|
|
|
// Handle call and switch nodes. |
|
|
|
HandleCallSwitch(); |
|
|
|
|
|
|
|
// Let output depend on monad. |
|
|
|
if (monad_) { |
|
|
|
MakeMonadDepend(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
// |
|
|
|
// Prepare information for control flow processing. |
|
|
|
// |
|
|
|
void Prepare() { |
|
|
|
recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive); |
|
|
|
// Find Call/Switch/SwitchLayer nodes, and make CallSites for them. |
|
|
|
AnfNodePtr last_monad = nullptr; |
|
|
|
auto nodes = TopoSort(kernel_graph_->output()); |
|
|
|
for (auto &node : nodes) { |
|
|
|
@@ -231,243 +256,387 @@ class AscendAutoMonadConverter { |
|
|
|
if (HasAbstractUMonad(node)) { |
|
|
|
// Found a node with UMonad abstract, set it as the last monad. |
|
|
|
last_monad = node; |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { |
|
|
|
MakeCallSite(node->cast<CNodePtr>(), last_monad, call_info); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { |
|
|
|
MakeSwitchCallSite(node->cast<CNodePtr>(), last_monad, call_info); |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (cnode->size() < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl; |
|
|
|
} |
|
|
|
// Set the last call as tail call if it is the output node. |
|
|
|
// We don't set tail call for top graph because return is always required. |
|
|
|
if (kernel_graph_ != context_.TopGraph() && !call_info->call_sites.empty()) { |
|
|
|
auto real_output = GetRealNode(kernel_graph_->output()); |
|
|
|
if (real_output == call_info->call_sites.back().cnode) { |
|
|
|
call_info->call_sites.back().tail = true; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || |
|
|
|
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { |
|
|
|
// Found call/switch/switchlayer node, set it as the tail call node. |
|
|
|
tail_call_node_ = cnode; |
|
|
|
call_switch_nodes_.emplace_back(cnode); |
|
|
|
monad_map_.emplace(cnode, last_monad); |
|
|
|
} else if (tail_call_node_ != nullptr && AnfAlgo::IsRealKernel(cnode)) { |
|
|
|
// Set no tail call if we found real kernel cnode after call/switch. |
|
|
|
tail_call_node_ = nullptr; |
|
|
|
} |
|
|
|
// Recursively find CallSites from sub-graphs. |
|
|
|
for (auto &call_site : call_info->call_sites) { |
|
|
|
for (auto &callee : call_site.callees) { |
|
|
|
CallInfoFinder finder(callee.graph, &context_); |
|
|
|
finder.FindCallSites(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// |
|
|
|
// Handle call and switch node, return true if tail call found. |
|
|
|
// |
|
|
|
void HandleCallSwitch() { |
|
|
|
// Handle call switch nodes. |
|
|
|
for (auto &cnode : call_switch_nodes_) { |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { |
|
|
|
HandleCall(cnode); |
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || |
|
|
|
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { |
|
|
|
HandleSwitch(cnode); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString(); |
|
|
|
// Find call-return pairs. |
|
|
|
void FindCallReturns() { |
|
|
|
for (auto &entry : context_.call_info_map) { |
|
|
|
auto &caller = entry.first; |
|
|
|
auto &call_info = entry.second; |
|
|
|
for (auto &call_site : call_info.call_sites) { |
|
|
|
for (auto &callee : call_site.callees) { |
|
|
|
MakeGraphLabel(callee.graph); |
|
|
|
} |
|
|
|
if (!call_site.tail) { |
|
|
|
SearchCallReturns(caller, &call_site); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// If no tail call, assign output value to output parameter, |
|
|
|
// and then goto the return label if set. |
|
|
|
if (tail_call_node_ == nullptr || recursive_) { |
|
|
|
if (output_parameter_) { |
|
|
|
auto assign_output = AssignAll(output_parameter_, kernel_graph_->output()); |
|
|
|
monad_ = UpdateState(GetMonad(), assign_output); |
|
|
|
} |
|
|
|
|
|
|
|
// Create entry label for the given graph if not set. |
|
|
|
void MakeGraphLabel(const KernelGraphPtr &kg) { |
|
|
|
auto label = GetGraphLabel(kg); |
|
|
|
if (label == kNoLabel) { |
|
|
|
// Allocate a new label id and save it to the graph. |
|
|
|
label = context_.NewLabel(); |
|
|
|
kg->set_attr(kAttrLabelIndex, MakeValue(label)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Search return points for all non-tail calls. |
|
|
|
void SearchCallReturns(const KernelGraphPtr &caller, CallSite *call_site) { |
|
|
|
std::set<KernelGraphPtr> visited = {caller}; |
|
|
|
std::queue<CallSite *> call_sites; |
|
|
|
call_sites.push(call_site); |
|
|
|
while (!call_sites.empty()) { |
|
|
|
auto site = call_sites.front(); |
|
|
|
call_sites.pop(); |
|
|
|
for (auto &callee : site->callees) { |
|
|
|
auto &kg = callee.graph; |
|
|
|
if (visited.find(kg) != visited.end()) { |
|
|
|
// Skip visited graphs. |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Mark visited. |
|
|
|
visited.emplace(kg); |
|
|
|
// Check callee. |
|
|
|
auto &call_info = context_.call_info_map[kg]; |
|
|
|
auto &sites = call_info.call_sites; |
|
|
|
if (!sites.empty() && sites.back().tail) { |
|
|
|
// Follow tail call. |
|
|
|
call_sites.push(&sites.back()); |
|
|
|
} else { |
|
|
|
// Find a call-return relation. |
|
|
|
HandleCallReturn(caller, call_site, kg); |
|
|
|
} |
|
|
|
} |
|
|
|
if (return_label_ != kNoLabel) { |
|
|
|
// Insert label_goto for return. |
|
|
|
auto return_goto = LabelGoto(return_label_); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); |
|
|
|
kernel_graph_->set_end_goto(return_goto); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Handle a call-return relation. |
|
|
|
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) { |
|
|
|
// Create a label for the return point. |
|
|
|
if (call_site->return_label == kNoLabel) { |
|
|
|
call_site->return_label = context_.NewLabel(); |
|
|
|
} |
|
|
|
// Create a parameter for the return value. |
|
|
|
if (call_site->out_param == nullptr) { |
|
|
|
call_site->out_param = context_.CreateParameter(call_site->cnode->abstract()); |
|
|
|
} |
|
|
|
// Add a return point for the callee graph. |
|
|
|
auto &call_info = context_.call_info_map[callee]; |
|
|
|
auto &return_point = call_info.return_points.emplace_back(); |
|
|
|
return_point.call_site = call_site; |
|
|
|
|
|
|
|
// Setup label index if there are multi return points. |
|
|
|
const auto n_return_points = call_info.return_points.size(); |
|
|
|
if (n_return_points > 1) { |
|
|
|
if (n_return_points == 2) { |
|
|
|
// Create a parameter to store label index. |
|
|
|
const ShapeVector shape = {1}; |
|
|
|
auto abs = std::make_shared<abstract::AbstractTensor>(kInt32, shape); |
|
|
|
call_info.label_param = context_.CreateParameter(abs); |
|
|
|
// Add label index for the first call site. |
|
|
|
call_info.return_points.front().call_site->label_indexes.emplace(call_info.label_param, 0); |
|
|
|
} |
|
|
|
// Add label index for the current call site. |
|
|
|
auto label_index = static_cast<uint32_t>(call_info.return_points.size() - 1); |
|
|
|
call_site->label_indexes.emplace(call_info.label_param, label_index); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// |
|
|
|
// Convert call node: |
|
|
|
// out = Call(graph, arg) |
|
|
|
// to: |
|
|
|
// r = link_args(graph.para, arg, c) |
|
|
|
// c = UpdateState(c, r) |
|
|
|
// c = LabelGoto(c) : L1 |
|
|
|
// |
|
|
|
void HandleCall(const CNodePtr &cnode) { |
|
|
|
// Update last_monad_. |
|
|
|
last_monad_ = monad_map_[cnode]; |
|
|
|
// Create a CallInfo for current kernel graph, return null if it is already existed. |
|
|
|
CallInfo *CreateCallInfo() { |
|
|
|
auto [iter, ok] = context_.call_info_map.add(kernel_graph_); |
|
|
|
if (!ok) { |
|
|
|
// CallInfo already existed. |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return &(iter->second); |
|
|
|
} |
|
|
|
|
|
|
|
// The callee graph. |
|
|
|
auto graph = GetCallGraph(cnode); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
// Create CallSite for Call node. |
|
|
|
void MakeCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) { |
|
|
|
auto &call_site = call_info->call_sites.emplace_back(); |
|
|
|
call_site.cnode = cnode; |
|
|
|
call_site.last_monad = last_monad; |
|
|
|
call_site.callees.emplace_back(GetCallBranch(cnode)); |
|
|
|
} |
|
|
|
|
|
|
|
// Link arguments for the sub-graph. |
|
|
|
// Create CallSite for Switch/SwitchLayer node. |
|
|
|
void MakeSwitchCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) { |
|
|
|
auto &call_site = call_info->call_sites.emplace_back(); |
|
|
|
call_site.cnode = cnode; |
|
|
|
call_site.last_monad = last_monad; |
|
|
|
call_site.callees = GetSwitchBranches(cnode); |
|
|
|
} |
|
|
|
|
|
|
|
CallBranch GetCallBranch(const CNodePtr &cnode) { |
|
|
|
auto input_graph = cnode->input(kCallKernelGraphIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(input_graph); |
|
|
|
auto kg = GetValueNode<KernelGraphPtr>(input_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(kg); |
|
|
|
constexpr size_t call_arg_index = 2; |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
std::vector<AnfNodePtr> args(inputs.begin() + call_arg_index, inputs.end()); |
|
|
|
auto linked_args = LinkArguments(args, graph); |
|
|
|
if (linked_args != nullptr) { |
|
|
|
monad_ = UpdateState(GetMonad(), linked_args); |
|
|
|
std::vector<AnfNodePtr> args{inputs.begin() + call_arg_index, inputs.end()}; |
|
|
|
return {.graph = kg, .args = std::move(args)}; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<CallBranch> GetSwitchBranches(const CNodePtr &cnode) { |
|
|
|
constexpr size_t cond_start_index = 2; |
|
|
|
std::vector<CallBranch> branches; |
|
|
|
for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) { |
|
|
|
branches.emplace_back(GetSwitchBranch(cnode, index)); |
|
|
|
} |
|
|
|
return branches; |
|
|
|
} |
|
|
|
|
|
|
|
// Goto sub-graph label. |
|
|
|
uint32_t graph_label = GetOrCreateGraphLabel(graph); |
|
|
|
auto goto_node = LabelGoto(graph_label); |
|
|
|
// Build the CallBranch for input 'index' of a Switch/SwitchLayer node.
// The input is expected to be a Partial(graph, args...) node. An input that is
// not a CNode yields a branch with a null graph and no arguments; a CNode that
// is not a Partial raises an exception.
CallBranch GetSwitchBranch(const CNodePtr &cnode, size_t index) {
  auto partial = dyn_cast<CNode>(cnode->input(index));
  if (partial == nullptr) {
    return {nullptr, {}};
  }
  auto &partial_inputs = partial->inputs();
  const bool is_partial = IsPrimitive(partial_inputs.at(0), prim::kPrimPartial);
  if (!is_partial) {
    MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
  }
  // Input 1 holds the branch graph; everything after it is an argument.
  constexpr size_t graph_index = 1;
  constexpr size_t first_arg_index = 2;
  CallBranch branch;
  branch.graph = GetValueNode<KernelGraphPtr>(partial_inputs.at(graph_index));
  branch.args.assign(partial_inputs.begin() + first_arg_index, partial_inputs.end());
  return branch;
}
|
|
|
|
|
|
|
// Set child graph attribute, so that subsequent steps such |
|
|
|
// as 'select kernel' can handle sub graphs. |
|
|
|
SetChildGrapAttr(goto_node, {graph}); |
|
|
|
// Strip any chain of Depend nodes and return the underlying node
// (input 1 of the innermost Depend), or the node itself if it is not a Depend.
static AnfNodePtr GetRealNode(const AnfNodePtr &node) {
  AnfNodePtr current = node;
  while (IsPrimitiveCNode(current, prim::kPrimDepend)) {
    // Step to the real input of the Depend node.
    current = current->cast<CNodePtr>()->input(1);
  }
  return current;
}
|
|
|
|
|
|
|
// Setup return label if this is not a tail call or it is a recursive call. |
|
|
|
const bool is_tail_call = (cnode == tail_call_node_); |
|
|
|
const bool need_return = (!is_tail_call || recursive_); |
|
|
|
if (!need_return) { |
|
|
|
// Set as end_goto if no return required. |
|
|
|
kernel_graph_->set_end_goto(goto_node); |
|
|
|
private: |
|
|
|
const KernelGraphPtr &kernel_graph_; |
|
|
|
AscendAutoMonadContext &context_; |
|
|
|
}; |
|
|
|
|
|
|
|
// |
|
|
|
// AscendAutoMonadConverter convert control flow to monad form |
|
|
|
// for a kernel graph and its children graphs recursively. |
|
|
|
// |
|
|
|
class AscendAutoMonadConverter { |
|
|
|
public: |
|
|
|
static void Run(AscendAutoMonadContext *context) { |
|
|
|
for (auto &entry : context->call_info_map) { |
|
|
|
AscendAutoMonadConverter converter(entry.first, context, &entry.second); |
|
|
|
converter.Run(); |
|
|
|
} |
|
|
|
auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info) |
|
|
|
: kernel_graph_(kg), context_(*context), call_info_(*call_info) {} |
|
|
|
~AscendAutoMonadConverter() = default; |
|
|
|
|
|
|
|
void Run() { |
|
|
|
// Setup entry label if found. |
|
|
|
SetupEntryLabel(); |
|
|
|
|
|
|
|
// Handle sub-graph recursively. |
|
|
|
HandleSubGraph(graph, output_para, return_label); |
|
|
|
// Handle call sites. |
|
|
|
for (auto &call_site : call_info_.call_sites) { |
|
|
|
HandleCallSite(call_site); |
|
|
|
} |
|
|
|
// Handle return points. |
|
|
|
HandleReturnPoints(); |
|
|
|
// Let output depend on monad. |
|
|
|
if (monad_) { |
|
|
|
MakeMonadDepend(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// |
|
|
|
// Convert switch/switchlayer node: |
|
|
|
// branch1 = Partial(graph1, arg) |
|
|
|
// branch2 = Partial(graph2, arg) |
|
|
|
// out = Switch/SwitchLayer(cond/index, branch1, branch2) |
|
|
|
// to: |
|
|
|
// r = link_args(graph1, arg) |
|
|
|
// c = UpdateState(c, r) |
|
|
|
// r = link_args(graph2, arg) |
|
|
|
// c = UpdateState(c, r) |
|
|
|
// c = LabelSwitch(cond/index, c) : L1, L2 |
|
|
|
// c = LabelSet(c) : <return label> |
|
|
|
// |
|
|
|
void HandleSwitch(const CNodePtr &cnode) { |
|
|
|
void HandleCallSite(const CallSite &call_site) { |
|
|
|
// Update last_monad_. |
|
|
|
last_monad_ = monad_map_[cnode]; |
|
|
|
last_monad_ = call_site.last_monad; |
|
|
|
|
|
|
|
// Get branches of the switch or switchlayer, true or 0 branch first. |
|
|
|
auto branches = GetSwitchBranches(cnode); |
|
|
|
// The call/switch/switch_layer cnode. |
|
|
|
auto &cnode = call_site.cnode; |
|
|
|
|
|
|
|
// Link arguments and generate labels for branches. |
|
|
|
// Get branches of the call_site. |
|
|
|
// for call, there is one branch; |
|
|
|
// for switch, the first one is true branch; |
|
|
|
// for switch_layer, the first one is 0 branch. |
|
|
|
auto &branches = call_site.callees; |
|
|
|
|
|
|
|
// Link arguments and find labels for branches. |
|
|
|
std::vector<KernelGraphPtr> graphes; |
|
|
|
std::vector<uint32_t> labels; |
|
|
|
graphes.reserve(branches.size()); |
|
|
|
labels.reserve(graphes.size()); |
|
|
|
labels.reserve(branches.size()); |
|
|
|
for (auto &[graph, args] : branches) { |
|
|
|
if (graph == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString(); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto linked_args = LinkArguments(args, graph); |
|
|
|
if (linked_args != nullptr) { |
|
|
|
monad_ = UpdateState(GetMonad(), linked_args); |
|
|
|
} |
|
|
|
graphes.push_back(graph); |
|
|
|
labels.push_back(GetOrCreateGraphLabel(graph)); |
|
|
|
labels.push_back(GetGraphLabel(graph)); |
|
|
|
} |
|
|
|
|
|
|
|
const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch); |
|
|
|
if (is_switch) { |
|
|
|
// For Switch, we reverse the graphes and labels, so that the false branch |
|
|
|
// is the first one, since for kernel LabelSwitch, false is the first branch. |
|
|
|
// Assign label indexes if required. |
|
|
|
AssignLabelIndexes(call_site); |
|
|
|
|
|
|
|
// For Switch, we reverse the graphes and labels, so that the false branch |
|
|
|
// is the first one, since for kernel LabelSwitch, false is the first branch. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { |
|
|
|
std::reverse(graphes.begin(), graphes.end()); |
|
|
|
std::reverse(labels.begin(), labels.end()); |
|
|
|
} |
|
|
|
|
|
|
|
// Add LabelSwitch node. |
|
|
|
auto switch_node = LabelSwitch(cnode->input(1), labels); |
|
|
|
// Create LabelGoto or LabelSwitch node. |
|
|
|
auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels); |
|
|
|
|
|
|
|
// Set child graph attribute for switch node. |
|
|
|
SetChildGrapAttr(switch_node, graphes); |
|
|
|
|
|
|
|
if (!is_switch) { |
|
|
|
// Mark the switch node is for 'switch_layer'. |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node); |
|
|
|
// Setup return label and output if required. |
|
|
|
if (call_site.return_label != kNoLabel) { |
|
|
|
auto label_node = LabelSet(call_site.return_label); |
|
|
|
AnfNodePtr output = call_site.out_param; |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
// Let output depend on the label node, this ensures the |
|
|
|
// return label is set before output is used. |
|
|
|
output = MakeDepend(output, label_node); |
|
|
|
// Replace the call/switch node with the output. |
|
|
|
ReplaceNode(cnode, output); |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// Setup return label if required. |
|
|
|
const bool is_tail_call = (cnode == tail_call_node_); |
|
|
|
const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_); |
|
|
|
auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return); |
|
|
|
|
|
|
|
// Handle sub-graphs recursively. |
|
|
|
for (auto &graph : graphes) { |
|
|
|
HandleSubGraph(graph, output_para, return_label); |
|
|
|
// If no return label required, it should be a tail call. |
|
|
|
if (!call_site.tail) { |
|
|
|
MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString(); |
|
|
|
} |
|
|
|
// For tail calls, replace origin call node with label_goto/label_switch. |
|
|
|
ReplaceNode(cnode, label_goto_switch); |
|
|
|
kernel_graph_->set_end_goto(label_goto_switch); |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) { |
|
|
|
const bool is_tail_call = (cnode == tail_call_node_); |
|
|
|
if (is_tail_call && output_parameter_ != nullptr) { |
|
|
|
return output_parameter_; |
|
|
|
// Assign label indexes to label parameters for a call site. |
|
|
|
void AssignLabelIndexes(const CallSite &call_site) { |
|
|
|
for (auto &[label_param, label_index] : call_site.label_indexes) { |
|
|
|
auto index_value = GetIndexValueNode(label_index); |
|
|
|
auto assign = Assign(label_param, index_value); |
|
|
|
monad_ = UpdateState(GetMonad(), assign); |
|
|
|
} |
|
|
|
return context_.GetOutputParameter(branches.front()); |
|
|
|
} |
|
|
|
|
|
|
|
// Make return part of a call for the LabelGoto/LabelSwitch node. |
|
|
|
std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches, |
|
|
|
bool need_return) { |
|
|
|
// Prepare return label. |
|
|
|
uint32_t return_label = return_label_; |
|
|
|
// Prepare output parameter. |
|
|
|
auto output_para = GetOutputParameter(cnode, branches); |
|
|
|
// Use same output parameter for all branches. |
|
|
|
for (auto &branch : branches) { |
|
|
|
context_.SetOutputParameter(branch, output_para); |
|
|
|
} |
|
|
|
auto output = output_para; |
|
|
|
// Setup return label if return is required. |
|
|
|
if (need_return) { |
|
|
|
// Set a new label at return point. |
|
|
|
return_label = context_.NewLabel(); |
|
|
|
auto label_node = LabelSet(return_label); |
|
|
|
// Let output depend on the label node, this ensures the |
|
|
|
// return label is set before output is used. |
|
|
|
output = MakeDepend(output, label_node); |
|
|
|
// Create or reuse ValueNode for the index. |
|
|
|
ValueNodePtr GetIndexValueNode(uint32_t index) { |
|
|
|
auto iter = index_nodes_.find(index); |
|
|
|
if (iter != index_nodes_.end()) { |
|
|
|
// Reuse ValueNode for same index. |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
|
|
|
|
// Replace the call/switch node with the output. |
|
|
|
kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output)); |
|
|
|
return {output_para, return_label}; |
|
|
|
// Create a new ValueNode on top graph for the index. |
|
|
|
auto &top_graph = context_.TopGraph(); |
|
|
|
std::vector<int64_t> data = {static_cast<int64_t>(index)}; |
|
|
|
auto tensor = std::make_shared<tensor::Tensor>(data, kInt32); |
|
|
|
auto value_node = top_graph->NewValueNode(tensor->ToAbstract(), tensor); |
|
|
|
top_graph->AddValueNodeToGraph(value_node); |
|
|
|
index_nodes_.emplace(index, value_node); |
|
|
|
return value_node; |
|
|
|
} |
|
|
|
|
|
|
|
// Handle sub-graphs recursively. |
|
|
|
void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) { |
|
|
|
AscendAutoMonadConverter converter(&context_, graph); |
|
|
|
converter.output_parameter_ = out_para; |
|
|
|
converter.return_label_ = return_label; |
|
|
|
converter.Run(); |
|
|
|
// Replace a node with new node in current kernel graph. |
|
|
|
// We also replace the arguments used for sub-graph calls. |
|
|
|
void ReplaceNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { |
|
|
|
kernel_graph_->ReplaceNode(NOT_NULL(old_node), NOT_NULL(new_node)); |
|
|
|
for (auto &call_site : call_info_.call_sites) { |
|
|
|
for (auto &callee : call_site.callees) { |
|
|
|
std::replace(callee.args.begin(), callee.args.end(), old_node, new_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
KernelGraphPtr GetCallGraph(const CNodePtr &cnode) { |
|
|
|
auto input_graph = cnode->input(kCallKernelGraphIndex); |
|
|
|
MS_EXCEPTION_IF_NULL(input_graph); |
|
|
|
return GetValueNode<KernelGraphPtr>(input_graph); |
|
|
|
// Make a label_goto or label_switch for a Call/Switch/SwitchLayer node. |
|
|
|
CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &graphes, |
|
|
|
const std::vector<uint32_t> &labels) { |
|
|
|
// Create LabelGoto or LabelSwitch according the cnode type. |
|
|
|
const bool is_call = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall); |
|
|
|
auto label_goto_switch = (is_call ? LabelGoto(labels.front()) : LabelSwitch(cnode->input(1), labels)); |
|
|
|
|
|
|
|
// Set child graph attribute for the LabelGoto or LabelSwitch node. |
|
|
|
SetChildGrapAttr(label_goto_switch, graphes); |
|
|
|
|
|
|
|
// Mark the label_switch node is for 'switch_layer' if it is. |
|
|
|
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch); |
|
|
|
} |
|
|
|
return label_goto_switch; |
|
|
|
} |
|
|
|
|
|
|
|
GraphArgPair GetSwitchBranch(const CNodePtr &cnode, size_t index) { |
|
|
|
auto partial_cnode = dyn_cast<CNode>(cnode->input(index)); |
|
|
|
if (partial_cnode == nullptr) { |
|
|
|
return {nullptr, {}}; |
|
|
|
// |
|
|
|
// Handle return points. |
|
|
|
// use label_goto for single return point; |
|
|
|
// use label_switch for multi return points. |
|
|
|
// |
|
|
|
void HandleReturnPoints() { |
|
|
|
auto &return_points = call_info_.return_points; |
|
|
|
// No return points. |
|
|
|
if (return_points.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto &inputs = partial_cnode->inputs(); |
|
|
|
if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) { |
|
|
|
MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString(); |
|
|
|
// Single return point. |
|
|
|
if (return_points.size() == 1) { |
|
|
|
// Insert Assign for output parameter. |
|
|
|
auto &return_point = return_points.front(); |
|
|
|
AssignOutput(return_point); |
|
|
|
// Insert label_goto for return. |
|
|
|
auto return_goto = LabelGoto(return_point.call_site->return_label); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); |
|
|
|
kernel_graph_->set_end_goto(return_goto); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto graph = GetValueNode<KernelGraphPtr>(inputs.at(1)); |
|
|
|
constexpr size_t arg_index = 2; |
|
|
|
return {graph, {inputs.begin() + arg_index, inputs.end()}}; |
|
|
|
// Multi return points. |
|
|
|
std::vector<uint32_t> return_labels; |
|
|
|
return_labels.reserve(return_points.size()); |
|
|
|
for (auto &return_point : return_points) { |
|
|
|
// Assign output to out_params of each return point. |
|
|
|
AssignOutput(return_point); |
|
|
|
// Get return labels. |
|
|
|
return_labels.emplace_back(return_point.call_site->return_label); |
|
|
|
} |
|
|
|
// Insert label_switch for multi return points. |
|
|
|
auto &label_param = call_info_.label_param; |
|
|
|
MS_EXCEPTION_IF_NULL(label_param); |
|
|
|
auto return_switch = LabelSwitch(label_param, return_labels); |
|
|
|
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch); |
|
|
|
kernel_graph_->set_end_goto(return_switch); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<GraphArgPair> GetSwitchBranches(const CNodePtr &cnode) { |
|
|
|
constexpr size_t cond_start_index = 2; |
|
|
|
// switch branches |
|
|
|
std::vector<GraphArgPair> switch_branches; |
|
|
|
for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) { |
|
|
|
switch_branches.emplace_back(GetSwitchBranch(cnode, index)); |
|
|
|
} |
|
|
|
return switch_branches; |
|
|
|
// Assign graph output to the output parameter for a return point. |
|
|
|
void AssignOutput(const ReturnPoint &return_point) { |
|
|
|
auto call_site = return_point.call_site; |
|
|
|
MS_EXCEPTION_IF_NULL(call_site); |
|
|
|
auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output()); |
|
|
|
monad_ = UpdateState(GetMonad(), assign_output); |
|
|
|
} |
|
|
|
|
|
|
|
// |
|
|
|
@@ -572,6 +741,7 @@ class AscendAutoMonadConverter { |
|
|
|
return kernel_graph_->NewCNode(tuple_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
// Insert UpdateState after input node. |
|
|
|
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) { |
|
|
|
auto update_state = NewValueNode(prim::kPrimUpdateState); |
|
|
|
auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input}); |
|
|
|
@@ -589,11 +759,14 @@ class AscendAutoMonadConverter { |
|
|
|
// c = LabelSet(c) : entry_label |
|
|
|
// return add(x, y) |
|
|
|
// |
|
|
|
void SetupEntryLabel(uint32_t entry_label) { |
|
|
|
// Set entry label. |
|
|
|
auto label_node = LabelSet(entry_label); |
|
|
|
// Make start label the first one in execution order. |
|
|
|
kernel_graph_->set_start_label(label_node); |
|
|
|
void SetupEntryLabel() { |
|
|
|
auto entry_label = GetGraphLabel(kernel_graph_); |
|
|
|
if (entry_label != kNoLabel) { |
|
|
|
// Set entry label. |
|
|
|
auto label_node = LabelSet(entry_label); |
|
|
|
// Make start label the first one in execution order. |
|
|
|
kernel_graph_->set_start_label(label_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Make a Depend cnode. |
|
|
|
@@ -609,8 +782,10 @@ class AscendAutoMonadConverter { |
|
|
|
auto monad = GetMonad(); |
|
|
|
auto origin_output = kernel_graph_->output(); |
|
|
|
MS_EXCEPTION_IF_NULL(origin_output); |
|
|
|
auto depend_cnode = MakeDepend(origin_output, monad); |
|
|
|
kernel_graph_->set_output(depend_cnode); |
|
|
|
if (origin_output != monad) { |
|
|
|
auto depend_cnode = MakeDepend(origin_output, monad); |
|
|
|
kernel_graph_->set_output(depend_cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Gets the last monad node, we use a separated UMonad for control flow. |
|
|
|
@@ -665,42 +840,17 @@ class AscendAutoMonadConverter { |
|
|
|
return cnode; |
|
|
|
} |
|
|
|
|
|
|
|
// Return kNoLabel when label id attribute not set for the graph. |
|
|
|
uint32_t GetGraphLabel(const KernelGraphPtr &kg) { |
|
|
|
auto value = kg->get_attr(kAttrLabelIndex); |
|
|
|
if (value == nullptr) { |
|
|
|
return kNoLabel; |
|
|
|
} |
|
|
|
return GetValue<uint32_t>(value); |
|
|
|
} |
|
|
|
|
|
|
|
// Get or create entry label for the given graph. |
|
|
|
uint32_t GetOrCreateGraphLabel(const KernelGraphPtr &kg) { |
|
|
|
auto label = GetGraphLabel(kg); |
|
|
|
if (label == kNoLabel) { |
|
|
|
// Allocate a new label id and save it to the graph. |
|
|
|
label = context_.NewLabel(); |
|
|
|
kg->set_attr(kAttrLabelIndex, MakeValue(label)); |
|
|
|
} |
|
|
|
return label; |
|
|
|
} |
|
|
|
|
|
|
|
// Set child graph attribute for label_goto/label_switch node. |
|
|
|
void SetChildGrapAttr(const AnfNodePtr &node, const std::vector<KernelGraphPtr> &graphs) { |
|
|
|
AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
const KernelGraphPtr &kernel_graph_; |
|
|
|
AscendAutoMonadContext &context_; |
|
|
|
const KernelGraphPtr kernel_graph_; |
|
|
|
|
|
|
|
// Tail call node, null if not found. |
|
|
|
CNodePtr tail_call_node_; |
|
|
|
|
|
|
|
// Call/Switch nodes. |
|
|
|
std::vector<CNodePtr> call_switch_nodes_; |
|
|
|
|
|
|
|
// Call/Switch node to monad map. |
|
|
|
std::map<CNodePtr, AnfNodePtr> monad_map_; |
|
|
|
// Call info for current kernel graph. |
|
|
|
CallInfo &call_info_; |
|
|
|
|
|
|
|
// The last monad for Call/Switch node. |
|
|
|
AnfNodePtr last_monad_; |
|
|
|
@@ -711,14 +861,8 @@ class AscendAutoMonadConverter { |
|
|
|
// The control flow monad const value node. |
|
|
|
AnfNodePtr monad_value_; |
|
|
|
|
|
|
|
// Parameter to store the return value. |
|
|
|
AnfNodePtr output_parameter_; |
|
|
|
|
|
|
|
// The return label id. |
|
|
|
uint32_t return_label_ = kNoLabel; |
|
|
|
|
|
|
|
// Whether this graph includes recursive calls. |
|
|
|
bool recursive_ = false; |
|
|
|
// Index value node cache for reuse. |
|
|
|
std::map<uint32_t, ValueNodePtr> index_nodes_; |
|
|
|
}; |
|
|
|
|
|
|
|
constexpr size_t kAssignTargetIndex = 1; |
|
|
|
@@ -985,9 +1129,10 @@ class ExecuteOrderGenerator { |
|
|
|
|
|
|
|
void AscendAutoMonad::Run() { |
|
|
|
MS_LOG(DEBUG) << "Ascend auto-monad start."; |
|
|
|
AscendAutoMonadContext context(kernel_graph_.get()); |
|
|
|
AscendAutoMonadConverter converter(&context, kernel_graph_.get()); |
|
|
|
converter.Run(); |
|
|
|
auto kg = kernel_graph_.get(); |
|
|
|
AscendAutoMonadContext context(kg); |
|
|
|
CallInfoFinder::Run(&context); |
|
|
|
AscendAutoMonadConverter::Run(&context); |
|
|
|
kernel_graph_->set_label_num(context.CurrentLabel()); |
|
|
|
MS_LOG(DEBUG) << "Ascend auto-monad finish."; |
|
|
|
DumpGraphForDebug(kernel_graph_); |
|
|
|
|