| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -166,6 +166,63 @@ struct CallInfo { | |||||
| AnfNodePtr label_param = nullptr; | AnfNodePtr label_param = nullptr; | ||||
| }; | }; | ||||
| // | |||||
| // ParameterPool cache parameters by its abstract, so that we can reuse | |||||
| // parameter with same abstract to store return values. | |||||
| // | |||||
| class ParameterPool { | |||||
| public: | |||||
| explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {} | |||||
| ~ParameterPool() = default; | |||||
| // Create or get a parameter from pool with the given abstract. | |||||
| AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) { | |||||
| // Find parameter in pool by the given abstract. | |||||
| auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto ¶) { | |||||
| auto para_abs = para->abstract(); | |||||
| // Reuse output parameter with compatible abstract. | |||||
| return IsCompatible(abs, para_abs); | |||||
| }); | |||||
| // Return the parameter if found. | |||||
| if (iter != paras_.end()) { | |||||
| return *iter; | |||||
| } | |||||
| // If parameter not found with the given abstract, create a new one. | |||||
| auto para = top_graph_->NewParameter(abs); | |||||
| auto out_para = top_graph_->TransTupleToMakeTuple(para); | |||||
| // This is required, so that device memory can be allocated for it. | |||||
| top_graph_->AddChildGraphResult(out_para); | |||||
| // Save new para to pool. | |||||
| paras_.push_back(out_para); | |||||
| return out_para; | |||||
| } | |||||
| protected: | |||||
| // Check if one abstract is compatible with another abstract. | |||||
| static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) { | |||||
| if (a1 == nullptr || a2 == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) { | |||||
| // This make AbstractRef compatible with AbstractTensor. | |||||
| auto &t1 = static_cast<abstract::AbstractTensor &>(*a1); | |||||
| auto &t2 = static_cast<abstract::AbstractTensor &>(*a2); | |||||
| return t1 == t2; | |||||
| } | |||||
| return *a1 == *a2; | |||||
| } | |||||
| private: | |||||
| // The top graph. | |||||
| const KernelGraphPtr &top_graph_; | |||||
| // Cached parameters. | |||||
| std::vector<AnfNodePtr> paras_; | |||||
| }; | |||||
| // | |||||
| // Base class for context. | |||||
| // | |||||
| class BaseContext { | class BaseContext { | ||||
| public: | public: | ||||
| void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } | void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } | ||||
| @@ -185,7 +242,7 @@ class BaseContext { | |||||
| // | // | ||||
| class AscendAutoMonadContext : public BaseContext { | class AscendAutoMonadContext : public BaseContext { | ||||
| public: | public: | ||||
| explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {} | |||||
| explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {} | |||||
| ~AscendAutoMonadContext() = default; | ~AscendAutoMonadContext() = default; | ||||
| // Label ids start from 1 and increase by 1 for each new id. | // Label ids start from 1 and increase by 1 for each new id. | ||||
| @@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext { | |||||
| return out_para; | return out_para; | ||||
| } | } | ||||
| // Get or create a temporary parameter for the given abstract. | |||||
| AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); } | |||||
| const KernelGraphPtr &TopGraph() const { return top_graph_; } | const KernelGraphPtr &TopGraph() const { return top_graph_; } | ||||
| // Map kernel_graph to its call info. | // Map kernel_graph to its call info. | ||||
| @@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext { | |||||
| // The top graph. | // The top graph. | ||||
| const KernelGraphPtr &top_graph_; | const KernelGraphPtr &top_graph_; | ||||
| // Map kernel_graph to its output parameter. | |||||
| std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_; | |||||
| // The parameter pool that caches parameters for return values. | |||||
| ParameterPool param_pool_; | |||||
| // Current label id. | // Current label id. | ||||
| uint32_t label_id_ = 1; | uint32_t label_id_ = 1; | ||||
| @@ -521,9 +581,18 @@ class AscendAutoMonadConverter { | |||||
| auto label_node = LabelSet(call_site.return_label); | auto label_node = LabelSet(call_site.return_label); | ||||
| AnfNodePtr output = call_site.out_param; | AnfNodePtr output = call_site.out_param; | ||||
| MS_EXCEPTION_IF_NULL(output); | MS_EXCEPTION_IF_NULL(output); | ||||
| // Let output depend on the label node, this ensures the | |||||
| // return label is set before output is used. | |||||
| output = MakeDepend(output, label_node); | |||||
| const bool is_single_call = call_site.label_indexes.empty(); | |||||
| if (is_single_call) { | |||||
| // For single call, let output depend on the label node, | |||||
| // this ensures the return label is set before output is used. | |||||
| output = MakeDepend(output, label_node); | |||||
| } else { | |||||
| // For multi-return call, assign result from temp parameter to | |||||
| // output parameter, this prevents the result from being overwritten by the next call. | |||||
| auto tmp_param = context_.GetTempParameter(output->abstract()); | |||||
| output = AssignAll(output, tmp_param); | |||||
| monad_ = UpdateState(GetMonad(), output); | |||||
| } | |||||
| // Replace the call/switch node with the output. | // Replace the call/switch node with the output. | ||||
| ReplaceNode(cnode, output); | ReplaceNode(cnode, output); | ||||
| return; | return; | ||||
| @@ -603,12 +672,12 @@ class AscendAutoMonadConverter { | |||||
| if (return_points.empty()) { | if (return_points.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| // Assign output according to the return points. | |||||
| AssignOutput(return_points); | |||||
| // Single return point. | // Single return point. | ||||
| if (return_points.size() == 1) { | if (return_points.size() == 1) { | ||||
| // Insert Assign for output parameter. | |||||
| auto &return_point = return_points.front(); | |||||
| AssignOutput(return_point); | |||||
| // Insert label_goto for return. | // Insert label_goto for return. | ||||
| auto &return_point = return_points.front(); | |||||
| auto return_goto = LabelGoto(return_point.call_site->return_label); | auto return_goto = LabelGoto(return_point.call_site->return_label); | ||||
| AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); | AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); | ||||
| kernel_graph_->set_end_goto(return_goto); | kernel_graph_->set_end_goto(return_goto); | ||||
| @@ -617,12 +686,9 @@ class AscendAutoMonadConverter { | |||||
| // Multi return points. | // Multi return points. | ||||
| std::vector<uint32_t> return_labels; | std::vector<uint32_t> return_labels; | ||||
| return_labels.reserve(return_points.size()); | return_labels.reserve(return_points.size()); | ||||
| for (auto &return_point : return_points) { | |||||
| // Assign output to out_params of each return point. | |||||
| AssignOutput(return_point); | |||||
| // Get return labels. | |||||
| return_labels.emplace_back(return_point.call_site->return_label); | |||||
| } | |||||
| // Get return labels from return points. | |||||
| std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels), | |||||
| [](const ReturnPoint &return_point) { return return_point.call_site->return_label; }); | |||||
| // Insert label_switch for multi return points. | // Insert label_switch for multi return points. | ||||
| auto &label_param = call_info_.label_param; | auto &label_param = call_info_.label_param; | ||||
| MS_EXCEPTION_IF_NULL(label_param); | MS_EXCEPTION_IF_NULL(label_param); | ||||
| @@ -631,11 +697,18 @@ class AscendAutoMonadConverter { | |||||
| kernel_graph_->set_end_goto(return_switch); | kernel_graph_->set_end_goto(return_switch); | ||||
| } | } | ||||
| // Assign graph output to the output parameter for a return point. | |||||
| void AssignOutput(const ReturnPoint &return_point) { | |||||
| auto call_site = return_point.call_site; | |||||
| // Assign graph output to the output parameter. | |||||
| void AssignOutput(const std::vector<ReturnPoint> &return_points) { | |||||
| // For single call: we directly assign output to the output parameter of the call site; | |||||
| // For multi call: we assign output to a temp parameter, and let caller assign the | |||||
| // temp parameter to an output parameter after the call returns. | |||||
| auto call_site = return_points.front().call_site; | |||||
| MS_EXCEPTION_IF_NULL(call_site); | MS_EXCEPTION_IF_NULL(call_site); | ||||
| auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output()); | |||||
| const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty()); | |||||
| AnfNodePtr out_param = | |||||
| (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract())); | |||||
| MS_EXCEPTION_IF_NULL(out_param); | |||||
| auto assign_output = AssignAll(out_param, kernel_graph_->output()); | |||||
| monad_ = UpdateState(GetMonad(), assign_output); | monad_ = UpdateState(GetMonad(), assign_output); | ||||
| } | } | ||||
| @@ -699,7 +772,7 @@ class AscendAutoMonadConverter { | |||||
| // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. | // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. | ||||
| AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } | AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } | ||||
| AnfNodePtr GetAssignMonad() { | |||||
| AnfNodePtr GetLinkMonad() { | |||||
| if (last_monad_ != nullptr) { | if (last_monad_ != nullptr) { | ||||
| return last_monad_; | return last_monad_; | ||||
| } | } | ||||
| @@ -708,7 +781,7 @@ class AscendAutoMonadConverter { | |||||
| // Make an assign cnode. | // Make an assign cnode. | ||||
| CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { | CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { | ||||
| auto monad = GetAssignMonad(); | |||||
| auto monad = (is_link ? GetLinkMonad() : GetMonad()); | |||||
| auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); | auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); | ||||
| if (is_link) { | if (is_link) { | ||||
| // Mark this assign is to link real argument to formal argument. | // Mark this assign is to link real argument to formal argument. | ||||