diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index f03133e53c..080d0de7bb 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -25,6 +25,7 @@ #include #include #include "utils/ms_context.h" +#include "utils/ordered_map.h" #include "base/core_ops.h" #include "debug/anf_ir_dump.h" #include "pipeline/jit/base.h" @@ -118,6 +119,53 @@ void DumpExecuteOrder(NotNull kg) { fout.close(); } +// Return kNoLabel when label id attribute not set for the graph. +uint32_t GetGraphLabel(const KernelGraphPtr &kg) { + auto value = kg->get_attr(kAttrLabelIndex); + if (value == nullptr) { + return kNoLabel; + } + return GetValue(value); +} + +struct CallBranch { + KernelGraphPtr graph; + std::vector args; +}; + +struct CallSite { + // Call/Switch/SwitchLayer + CNodePtr cnode; + + // The last monad before call. + AnfNodePtr last_monad = nullptr; + + // Branch graph called. + std::vector callees; + + // Parameter for return value. + AnfNodePtr out_param = nullptr; + + // Label id for return. + uint32_t return_label = kNoLabel; + + // Label param to index map. + std::map label_indexes; + + // True if this is a tail call. + bool tail = false; +}; + +struct ReturnPoint { + CallSite *call_site = nullptr; +}; + +struct CallInfo { + std::vector call_sites; + std::vector return_points; + AnfNodePtr label_param = nullptr; +}; + class BaseContext { public: void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } @@ -126,12 +174,14 @@ class BaseContext { const std::set &visited_graphs() const { return visited_graphs_; } + void ClearVisited() { visited_graphs_.clear(); } + private: std::set visited_graphs_; }; // -// AscendAutoMonadContext holds some shared states during auto-moand. +// AscendAutoMonadContext holds some shared states during auto-monad. // class AscendAutoMonadContext : public BaseContext { public: @@ -144,30 +194,20 @@ class AscendAutoMonadContext : public BaseContext { // Current label id, also the number of label ids we currently used. uint32_t CurrentLabel() const { return label_id_; } - // Create or get a parameter for output of the kernel graph. - AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) { - // Find output parameter by kernel graph. - auto iter = kg_out_param_.find(kg); - if (iter != kg_out_param_.end()) { - // Return output parameter if found. - return iter->second; - } - // Create a new one if not found. - // Output parameters are all created on top graph. - auto para = top_graph_->NewParameter(kg->output()->abstract()); + // Create a new parameter. + // Output parameters are all created on top graph. + AnfNodePtr CreateParameter(const AbstractBasePtr &abs) { + auto para = top_graph_->NewParameter(abs); auto out_para = top_graph_->TransTupleToMakeTuple(para); // This is required, so that device memory can be allocated for it. top_graph_->AddChildGraphResult(out_para); - // Save new para as the output parameter of the kg. - kg_out_param_.emplace(kg, out_para); return out_para; } - // Set output parameter for a kernel graph. - void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) { - // Save new para as the output parameter of the kg. - kg_out_param_.emplace(kg, out_para); - } + const KernelGraphPtr &TopGraph() const { return top_graph_; } + + // Map kernel_graph to its call info. + OrderedMap call_info_map; private: // The top graph. @@ -181,49 +221,34 @@ class AscendAutoMonadContext : public BaseContext { }; // -// AscendAutoMonadConverter convert control flow to monad form -// for a kernel graph and its children graphs recursively. +// Call info finder finds graph call information. // -class AscendAutoMonadConverter { +class CallInfoFinder { public: - AscendAutoMonadConverter(AscendAutoMonadContext *context, const KernelGraphPtr &kg) - : context_(*context), kernel_graph_(kg) {} + static void Run(AscendAutoMonadContext *context) { + CallInfoFinder finder(context->TopGraph(), context); + finder.Run(); + } - ~AscendAutoMonadConverter() = default; + private: + CallInfoFinder(const KernelGraphPtr &kg, AscendAutoMonadContext *context) : kernel_graph_(kg), context_(*context) {} + ~CallInfoFinder() = default; void Run() { - // Skip if graph already visited. - if (context_.IsVisited(kernel_graph_)) { + FindCallSites(); + FindCallReturns(); + } + + // Find all call sites. + void FindCallSites() { + auto call_info = CreateCallInfo(); + if (call_info == nullptr) { + // Skip if call_info for this graph already existed. return; } - context_.MarkVisited(kernel_graph_); - // Update directly called sub-graphs. kernel_graph_->UpdateChildGraphOrder(); - - Prepare(); - - // Setup entry label if needed. - auto entry_label = GetGraphLabel(kernel_graph_); - if (entry_label != kNoLabel) { - SetupEntryLabel(entry_label); - } - - // Handle call and switch nodes. - HandleCallSwitch(); - - // Let output depend on monad. - if (monad_) { - MakeMonadDepend(); - } - } - - private: - // - // Prepare information for control flow processing. - // - void Prepare() { - recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive); + // Find Call/Switch/SwitchLayer nodes, and make CallSites for them. AnfNodePtr last_monad = nullptr; auto nodes = TopoSort(kernel_graph_->output()); for (auto &node : nodes) { @@ -231,243 +256,387 @@ class AscendAutoMonadConverter { if (HasAbstractUMonad(node)) { // Found a node with UMonad abstract, set it as the last monad. last_monad = node; + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + MakeCallSite(node->cast(), last_monad, call_info); + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) { + MakeSwitchCallSite(node->cast(), last_monad, call_info); } - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - if (cnode->size() < 1) { - MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl; + } + // Set the last call as tail call if it is the output node. + // We don't set tail call for top graph because return is always required. + if (kernel_graph_ != context_.TopGraph() && !call_info->call_sites.empty()) { + auto real_output = GetRealNode(kernel_graph_->output()); + if (real_output == call_info->call_sites.back().cnode) { + call_info->call_sites.back().tail = true; } - if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { - // Found call/switch/switchlayer node, set it as the tail call node. - tail_call_node_ = cnode; - call_switch_nodes_.emplace_back(cnode); - monad_map_.emplace(cnode, last_monad); - } else if (tail_call_node_ != nullptr && AnfAlgo::IsRealKernel(cnode)) { - // Set no tail call if we found real kernel cnode after call/switch. - tail_call_node_ = nullptr; + } + // Recursively find CallSites from sub-graphs. + for (auto &call_site : call_info->call_sites) { + for (auto &callee : call_site.callees) { + CallInfoFinder finder(callee.graph, &context_); + finder.FindCallSites(); } } } - // - // Handle call and switch node, return true if tail call found. - // - void HandleCallSwitch() { - // Handle call switch nodes. - for (auto &cnode : call_switch_nodes_) { - if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { - HandleCall(cnode); - } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) || - AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { - HandleSwitch(cnode); - } else { - MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString(); + // Find call-return pairs. + void FindCallReturns() { + for (auto &entry : context_.call_info_map) { + auto &caller = entry.first; + auto &call_info = entry.second; + for (auto &call_site : call_info.call_sites) { + for (auto &callee : call_site.callees) { + MakeGraphLabel(callee.graph); + } + if (!call_site.tail) { + SearchCallReturns(caller, &call_site); + } } } - // If no tail call, assign output value to output parameter, - // and then goto the return label if set. - if (tail_call_node_ == nullptr || recursive_) { - if (output_parameter_) { - auto assign_output = AssignAll(output_parameter_, kernel_graph_->output()); - monad_ = UpdateState(GetMonad(), assign_output); + } + + // Create entry label for the given graph if not set. + void MakeGraphLabel(const KernelGraphPtr &kg) { + auto label = GetGraphLabel(kg); + if (label == kNoLabel) { + // Allocate a new label id and save it to the graph. + label = context_.NewLabel(); + kg->set_attr(kAttrLabelIndex, MakeValue(label)); + } + } + + // Search return points for all non-tail calls. + void SearchCallReturns(const KernelGraphPtr &caller, CallSite *call_site) { + std::set visited = {caller}; + std::queue call_sites; + call_sites.push(call_site); + while (!call_sites.empty()) { + auto site = call_sites.front(); + call_sites.pop(); + for (auto &callee : site->callees) { + auto &kg = callee.graph; + if (visited.find(kg) != visited.end()) { + // Skip visited graphs. + continue; + } + // Mark visited. + visited.emplace(kg); + // Check callee. + auto &call_info = context_.call_info_map[kg]; + auto &sites = call_info.call_sites; + if (!sites.empty() && sites.back().tail) { + // Follow tail call. + call_sites.push(&sites.back()); + } else { + // Find a call-return relation. + HandleCallReturn(caller, call_site, kg); + } } - if (return_label_ != kNoLabel) { - // Insert label_goto for return. - auto return_goto = LabelGoto(return_label_); - AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); - kernel_graph_->set_end_goto(return_goto); + } + } + + // Handle a call-return relation. + void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) { + // Create a label for the return point. + if (call_site->return_label == kNoLabel) { + call_site->return_label = context_.NewLabel(); + } + // Create a parameter for the return value. + if (call_site->out_param == nullptr) { + call_site->out_param = context_.CreateParameter(call_site->cnode->abstract()); + } + // Add a return point for the callee graph. + auto &call_info = context_.call_info_map[callee]; + auto &return_point = call_info.return_points.emplace_back(); + return_point.call_site = call_site; + + // Setup label index if there are multi return points. + const auto n_return_points = call_info.return_points.size(); + if (n_return_points > 1) { + if (n_return_points == 2) { + // Create a parameter to store label index. + const ShapeVector shape = {1}; + auto abs = std::make_shared(kInt32, shape); + call_info.label_param = context_.CreateParameter(abs); + // Add label index for the first call site. + call_info.return_points.front().call_site->label_indexes.emplace(call_info.label_param, 0); } + // Add label index for the current call site. + auto label_index = static_cast(call_info.return_points.size() - 1); + call_site->label_indexes.emplace(call_info.label_param, label_index); } } - // - // Convert call node: - // out = Call(graph, arg) - // to: - // r = link_args(graph.para, arg, c) - // c = UpdateState(c, r) - // c = LabelGoto(c) : L1 - // - void HandleCall(const CNodePtr &cnode) { - // Update last_monad_. - last_monad_ = monad_map_[cnode]; + // Create a CallInfo for current kernel graph, return null if it is already existed. + CallInfo *CreateCallInfo() { + auto [iter, ok] = context_.call_info_map.add(kernel_graph_); + if (!ok) { + // CallInfo already existed. + return nullptr; + } + return &(iter->second); + } - // The callee graph. - auto graph = GetCallGraph(cnode); - MS_EXCEPTION_IF_NULL(graph); + // Create CallSite for Call node. + void MakeCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) { + auto &call_site = call_info->call_sites.emplace_back(); + call_site.cnode = cnode; + call_site.last_monad = last_monad; + call_site.callees.emplace_back(GetCallBranch(cnode)); + } - // Link arguments for the sub-graph. + // Create CallSite for Switch/SwitchLayer node. + void MakeSwitchCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) { + auto &call_site = call_info->call_sites.emplace_back(); + call_site.cnode = cnode; + call_site.last_monad = last_monad; + call_site.callees = GetSwitchBranches(cnode); + } + + CallBranch GetCallBranch(const CNodePtr &cnode) { + auto input_graph = cnode->input(kCallKernelGraphIndex); + MS_EXCEPTION_IF_NULL(input_graph); + auto kg = GetValueNode(input_graph); + MS_EXCEPTION_IF_NULL(kg); constexpr size_t call_arg_index = 2; auto &inputs = cnode->inputs(); - std::vector args(inputs.begin() + call_arg_index, inputs.end()); - auto linked_args = LinkArguments(args, graph); - if (linked_args != nullptr) { - monad_ = UpdateState(GetMonad(), linked_args); + std::vector args{inputs.begin() + call_arg_index, inputs.end()}; + return {.graph = kg, .args = std::move(args)}; + } + + std::vector GetSwitchBranches(const CNodePtr &cnode) { + constexpr size_t cond_start_index = 2; + std::vector branches; + for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) { + branches.emplace_back(GetSwitchBranch(cnode, index)); } + return branches; + } - // Goto sub-graph label. - uint32_t graph_label = GetOrCreateGraphLabel(graph); - auto goto_node = LabelGoto(graph_label); + CallBranch GetSwitchBranch(const CNodePtr &cnode, size_t index) { + auto partial_cnode = dyn_cast(cnode->input(index)); + if (partial_cnode == nullptr) { + return {nullptr, {}}; + } + auto &inputs = partial_cnode->inputs(); + if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) { + MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString(); + } + auto graph = GetValueNode(inputs.at(1)); + constexpr size_t arg_index = 2; + std::vector args{inputs.begin() + arg_index, inputs.end()}; + return {.graph = graph, .args = std::move(args)}; + } - // Set child graph attribute, so that subsequence steps such - // as 'select kernel' can handle sub graphs. - SetChildGrapAttr(goto_node, {graph}); + static AnfNodePtr GetRealNode(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { + return node; + } + return GetRealNode(node->cast()->input(1)); + } - // Setup return label if this is not a tail call or it is a recursive call. - const bool is_tail_call = (cnode == tail_call_node_); - const bool need_return = (!is_tail_call || recursive_); - if (!need_return) { - // Set as end_goto if no return required. - kernel_graph_->set_end_goto(goto_node); + private: + const KernelGraphPtr &kernel_graph_; + AscendAutoMonadContext &context_; +}; + +// +// AscendAutoMonadConverter convert control flow to monad form +// for a kernel graph and its children graphs recursively. +// +class AscendAutoMonadConverter { + public: + static void Run(AscendAutoMonadContext *context) { + for (auto &entry : context->call_info_map) { + AscendAutoMonadConverter converter(entry.first, context, &entry.second); + converter.Run(); } - auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return); + } + + private: + AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info) + : kernel_graph_(kg), context_(*context), call_info_(*call_info) {} + ~AscendAutoMonadConverter() = default; + + void Run() { + // Setup entry label if found. + SetupEntryLabel(); - // Handle sub-graph recursively. - HandleSubGraph(graph, output_para, return_label); + // Handle call sites. + for (auto &call_site : call_info_.call_sites) { + HandleCallSite(call_site); + } + // Handle return points. + HandleReturnPoints(); + // Let output depend on monad. + if (monad_) { + MakeMonadDepend(); + } } - // - // Convert switch/switchlayer node: - // branch1 = Partial(graph1, arg) - // branch2 = Partial(graph2, arg) - // out = Switch/SwitchLayer(cond/index, branch1, branch2) - // to: - // r = link_args(graph1, arg) - // c = UpdateState(c, r) - // r = link_args(graph2, arg) - // c = UpdateState(c, r) - // c = LabelSwitch(cond/index, c) : L1, L2 - // c = LabelSet(c) : - // - void HandleSwitch(const CNodePtr &cnode) { + void HandleCallSite(const CallSite &call_site) { // Update last_monad_. - last_monad_ = monad_map_[cnode]; + last_monad_ = call_site.last_monad; - // Get branches of the switch or switchlayer, true or 0 branch first. - auto branches = GetSwitchBranches(cnode); + // The call/switch/switch_layer cnode. + auto &cnode = call_site.cnode; - // Link arguments and generate labels for branches. + // Get branches of the call_site. + // for call, there is one branch; + // for switch, the first one is true branch; + // for switch_layer, the first one is 0 branch. + auto &branches = call_site.callees; + + // Link arguments and find labels for branches. std::vector graphes; std::vector labels; graphes.reserve(branches.size()); - labels.reserve(graphes.size()); + labels.reserve(branches.size()); for (auto &[graph, args] : branches) { - if (graph == nullptr) { - MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString(); - } + MS_EXCEPTION_IF_NULL(graph); auto linked_args = LinkArguments(args, graph); if (linked_args != nullptr) { monad_ = UpdateState(GetMonad(), linked_args); } graphes.push_back(graph); - labels.push_back(GetOrCreateGraphLabel(graph)); + labels.push_back(GetGraphLabel(graph)); } - const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch); - if (is_switch) { - // For Switch, we reverse the graphes and labels, so that the false branch - // is the first one, since for kernel LabelSwitch, false is the first branch. + // Assign label indexes if required. + AssignLabelIndexes(call_site); + + // For Switch, we reverse the graphes and labels, so that the false branch + // is the first one, since for kernel LabelSwitch, false is the first branch. + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { std::reverse(graphes.begin(), graphes.end()); std::reverse(labels.begin(), labels.end()); } - // Add LabelSwith node. - auto switch_node = LabelSwitch(cnode->input(1), labels); + // Create LabelGoto or LabelSwitch node. + auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels); - // Set child graph attribute for switch node. - SetChildGrapAttr(switch_node, graphes); - - if (!is_switch) { - // Mark the switch node is for 'switch_layer'. - AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node); + // Setup return label and output if required. + if (call_site.return_label != kNoLabel) { + auto label_node = LabelSet(call_site.return_label); + AnfNodePtr output = call_site.out_param; + MS_EXCEPTION_IF_NULL(output); + // Let output depend on the label node, this ensures the + // return label is set before output is used. + output = MakeDepend(output, label_node); + // Replace the the call/switch node with the output. + ReplaceNode(cnode, output); + return; } - // Setup return label if required. - const bool is_tail_call = (cnode == tail_call_node_); - const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_); - auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return); - - // Handle sub-graphs recursively. - for (auto &graph : graphes) { - HandleSubGraph(graph, output_para, return_label); + // If no return label required, it should be a tail call. + if (!call_site.tail) { + MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString(); } + // For tail calls, replace origin call node with label_goto/label_switch. + ReplaceNode(cnode, label_goto_switch); + kernel_graph_->set_end_goto(label_goto_switch); } - AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector &branches) { - const bool is_tail_call = (cnode == tail_call_node_); - if (is_tail_call && output_parameter_ != nullptr) { - return output_parameter_; + // Assign label indexes to label parameters for a call site. + void AssignLabelIndexes(const CallSite &call_site) { + for (auto &[label_param, label_index] : call_site.label_indexes) { + auto index_value = GetIndexValueNode(label_index); + auto assign = Assign(label_param, index_value); + monad_ = UpdateState(GetMonad(), assign); } - return context_.GetOutputParameter(branches.front()); } - // Make return part of a call for the LabelGoto/LabelSwitch node. - std::tuple MakeReturn(const CNodePtr &cnode, const std::vector &branches, - bool need_return) { - // Prepare return label. - uint32_t return_label = return_label_; - // Prepare output parameter. - auto output_para = GetOutputParameter(cnode, branches); - // Use same output parameter for all branches. - for (auto &branch : branches) { - context_.SetOutputParameter(branch, output_para); - } - auto output = output_para; - // Setup return label if return is required. - if (need_return) { - // Set a new label at return point. - return_label = context_.NewLabel(); - auto label_node = LabelSet(return_label); - // Let output depend on the label node, this ensures the - // return label is set before output is used. - output = MakeDepend(output, label_node); + // Create or reuse ValueNode for the index. + ValueNodePtr GetIndexValueNode(uint32_t index) { + auto iter = index_nodes_.find(index); + if (iter != index_nodes_.end()) { + // Reuse ValueNode for same index. + return iter->second; } - - // Replace the the call/switch node with the output. - kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output)); - return {output_para, return_label}; + // Create a new ValueNode on top graph for the index. + auto &top_graph = context_.TopGraph(); + std::vector data = {static_cast(index)}; + auto tensor = std::make_shared(data, kInt32); + auto value_node = top_graph->NewValueNode(tensor->ToAbstract(), tensor); + top_graph->AddValueNodeToGraph(value_node); + index_nodes_.emplace(index, value_node); + return value_node; } - // Handle sub-graphs recursively. - void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) { - AscendAutoMonadConverter converter(&context_, graph); - converter.output_parameter_ = out_para; - converter.return_label_ = return_label; - converter.Run(); + // Replace a node with new node in current kernel graph. + // We also replace the arguments used for sub-graph calls. + void ReplaceNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + kernel_graph_->ReplaceNode(NOT_NULL(old_node), NOT_NULL(new_node)); + for (auto &call_site : call_info_.call_sites) { + for (auto &callee : call_site.callees) { + std::replace(callee.args.begin(), callee.args.end(), old_node, new_node); + } + } } - KernelGraphPtr GetCallGraph(const CNodePtr &cnode) { - auto input_graph = cnode->input(kCallKernelGraphIndex); - MS_EXCEPTION_IF_NULL(input_graph); - return GetValueNode(input_graph); + // Make a label_goto or label_switch for a Call/Switch/SwitchLayer node. + CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector &graphes, + const std::vector &labels) { + // Create LabelGoto or LabelSwitch according the cnode type. + const bool is_call = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall); + auto label_goto_switch = (is_call ? LabelGoto(labels.front()) : LabelSwitch(cnode->input(1), labels)); + + // Set child graph attribute for the LabelGoto or LabelSwitch node. + SetChildGrapAttr(label_goto_switch, graphes); + + // Mark the label_switch node is for 'switch_layer' if it is. + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { + AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch); + } + return label_goto_switch; } - GraphArgPair GetSwitchBranch(const CNodePtr &cnode, size_t index) { - auto partial_cnode = dyn_cast(cnode->input(index)); - if (partial_cnode == nullptr) { - return {nullptr, {}}; + // + // Handle return points. + // use label_goto for single return point; + // use label_switch for multi return points. + // + void HandleReturnPoints() { + auto &return_points = call_info_.return_points; + // No return points. + if (return_points.empty()) { + return; } - auto &inputs = partial_cnode->inputs(); - if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) { - MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString(); + // Single return point. + if (return_points.size() == 1) { + // Insert Assign for output parameter. + auto &return_point = return_points.front(); + AssignOutput(return_point); + // Insert label_goto for return. + auto return_goto = LabelGoto(return_point.call_site->return_label); + AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); + kernel_graph_->set_end_goto(return_goto); + return; } - auto graph = GetValueNode(inputs.at(1)); - constexpr size_t arg_index = 2; - return {graph, {inputs.begin() + arg_index, inputs.end()}}; + // Multi return points. + std::vector return_labels; + return_labels.reserve(return_points.size()); + for (auto &return_point : return_points) { + // Assign output to out_params of each return point. + AssignOutput(return_point); + // Get return labels. + return_labels.emplace_back(return_point.call_site->return_label); + } + // Insert label_switch for multi return points. + auto &label_param = call_info_.label_param; + MS_EXCEPTION_IF_NULL(label_param); + auto return_switch = LabelSwitch(label_param, return_labels); + AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch); + kernel_graph_->set_end_goto(return_switch); } - std::vector GetSwitchBranches(const CNodePtr &cnode) { - constexpr size_t cond_start_index = 2; - // switch branches - std::vector switch_branches; - for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) { - switch_branches.emplace_back(GetSwitchBranch(cnode, index)); - } - return switch_branches; + // Assign graph output to the output parameter for a return point. + void AssignOutput(const ReturnPoint &return_point) { + auto call_site = return_point.call_site; + MS_EXCEPTION_IF_NULL(call_site); + auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output()); + monad_ = UpdateState(GetMonad(), assign_output); } // @@ -572,6 +741,7 @@ class AscendAutoMonadConverter { return kernel_graph_->NewCNode(tuple_inputs); } + // Insert UpdateState after input node. AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) { auto update_state = NewValueNode(prim::kPrimUpdateState); auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input}); @@ -589,11 +759,14 @@ class AscendAutoMonadConverter { // c = LabelSet(c) : entry_label // return add(x, y) // - void SetupEntryLabel(uint32_t entry_label) { - // Set entry label. - auto label_node = LabelSet(entry_label); - // Make start label the first one in execution order. - kernel_graph_->set_start_label(label_node); + void SetupEntryLabel() { + auto entry_label = GetGraphLabel(kernel_graph_); + if (entry_label != kNoLabel) { + // Set entry label. + auto label_node = LabelSet(entry_label); + // Make start label the first one in execution order. + kernel_graph_->set_start_label(label_node); + } } // Make a Depend cnode. @@ -609,8 +782,10 @@ class AscendAutoMonadConverter { auto monad = GetMonad(); auto origin_output = kernel_graph_->output(); MS_EXCEPTION_IF_NULL(origin_output); - auto depend_cnode = MakeDepend(origin_output, monad); - kernel_graph_->set_output(depend_cnode); + if (origin_output != monad) { + auto depend_cnode = MakeDepend(origin_output, monad); + kernel_graph_->set_output(depend_cnode); + } } // Gets the last monad node, we use a separated UMonad for control flow. @@ -665,42 +840,17 @@ class AscendAutoMonadConverter { return cnode; } - // Return kNoLabel when label id attribute not set for the graph. - uint32_t GetGraphLabel(const KernelGraphPtr &kg) { - auto value = kg->get_attr(kAttrLabelIndex); - if (value == nullptr) { - return kNoLabel; - } - return GetValue(value); - } - - // Get or create entry label for the given graph. - uint32_t GetOrCreateGraphLabel(const KernelGraphPtr &kg) { - auto label = GetGraphLabel(kg); - if (label == kNoLabel) { - // Allocate a new label id and save it to the graph. - label = context_.NewLabel(); - kg->set_attr(kAttrLabelIndex, MakeValue(label)); - } - return label; - } - + // Set child graph attribute for label_goto/label_switch node. void SetChildGrapAttr(const AnfNodePtr &node, const std::vector &graphs) { AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node); } private: + const KernelGraphPtr &kernel_graph_; AscendAutoMonadContext &context_; - const KernelGraphPtr kernel_graph_; - - // Tail call node, null if not found. - CNodePtr tail_call_node_; - - // Call/Switch nodes. - std::vector call_switch_nodes_; - // Call/Switch node to monad map. - std::map monad_map_; + // Call info for current kernel graph. + CallInfo &call_info_; // The last monad for Call/Switch node. AnfNodePtr last_monad_; @@ -711,14 +861,8 @@ class AscendAutoMonadConverter { // The control flow monad const value node. AnfNodePtr monad_value_; - // Parameter to store the return value. - AnfNodePtr output_parameter_; - - // The return label id. - uint32_t return_label_ = kNoLabel; - - // Is this graph include recursive calls. - bool recursive_ = false; + // Index value node cache for reuse. + std::map index_nodes_; }; constexpr size_t kAssignTargetIndex = 1; @@ -985,9 +1129,10 @@ class ExecuteOrderGenerator { void AscendAutoMonad::Run() { MS_LOG(DEBUG) << "Ascend auto-monad start."; - AscendAutoMonadContext context(kernel_graph_.get()); - AscendAutoMonadConverter converter(&context, kernel_graph_.get()); - converter.Run(); + auto kg = kernel_graph_.get(); + AscendAutoMonadContext context(kg); + CallInfoFinder::Run(&context); + AscendAutoMonadConverter::Run(&context); kernel_graph_->set_label_num(context.CurrentLabel()); MS_LOG(DEBUG) << "Ascend auto-monad finish."; DumpGraphForDebug(kernel_graph_);