| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -166,6 +166,63 @@ struct CallInfo { | |||
| AnfNodePtr label_param = nullptr; | |||
| }; | |||
| // | |||
| // ParameterPool cache parameters by its abstract, so that we can reuse | |||
| // parameter with same abstract to store return values. | |||
| // | |||
| class ParameterPool { | |||
| public: | |||
| explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {} | |||
| ~ParameterPool() = default; | |||
| // Create or get a parameter from pool with the given abstract. | |||
| AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) { | |||
| // Find parameter in pool by the given abstract. | |||
| auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto ¶) { | |||
| auto para_abs = para->abstract(); | |||
| // Reuse output parameter with compatible abstract. | |||
| return IsCompatible(abs, para_abs); | |||
| }); | |||
| // Return the parameter if found. | |||
| if (iter != paras_.end()) { | |||
| return *iter; | |||
| } | |||
| // If parameter not found with the given abstract, create a new one. | |||
| auto para = top_graph_->NewParameter(abs); | |||
| auto out_para = top_graph_->TransTupleToMakeTuple(para); | |||
| // This is required, so that device memory can be allocated for it. | |||
| top_graph_->AddChildGraphResult(out_para); | |||
| // Save new para to pool. | |||
| paras_.push_back(out_para); | |||
| return out_para; | |||
| } | |||
| protected: | |||
| // Check if one abstract is compatible with another abstract. | |||
| static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) { | |||
| if (a1 == nullptr || a2 == nullptr) { | |||
| return false; | |||
| } | |||
| if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) { | |||
| // This make AbstractRef compatible with AbstractTensor. | |||
| auto &t1 = static_cast<abstract::AbstractTensor &>(*a1); | |||
| auto &t2 = static_cast<abstract::AbstractTensor &>(*a2); | |||
| return t1 == t2; | |||
| } | |||
| return *a1 == *a2; | |||
| } | |||
| private: | |||
| // The top graph. | |||
| const KernelGraphPtr &top_graph_; | |||
| // Cached parameters. | |||
| std::vector<AnfNodePtr> paras_; | |||
| }; | |||
| // | |||
| // Base class for context. | |||
| // | |||
| class BaseContext { | |||
| public: | |||
| void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } | |||
| @@ -185,7 +242,7 @@ class BaseContext { | |||
| // | |||
| class AscendAutoMonadContext : public BaseContext { | |||
| public: | |||
| explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {} | |||
| explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {} | |||
| ~AscendAutoMonadContext() = default; | |||
| // Label id start from 1, and increased by 1 for each new id. | |||
| @@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext { | |||
| return out_para; | |||
| } | |||
| // Get or create a temporary parameter for the given abstract. | |||
| AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); } | |||
| const KernelGraphPtr &TopGraph() const { return top_graph_; } | |||
| // Map kernel_graph to its call info. | |||
| @@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext { | |||
| // The top graph. | |||
| const KernelGraphPtr &top_graph_; | |||
| // Map kernel_graph to its output parameter. | |||
| std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_; | |||
| // The parameter pool that cache parameters for return value. | |||
| ParameterPool param_pool_; | |||
| // Current label id. | |||
| uint32_t label_id_ = 1; | |||
| @@ -521,9 +581,18 @@ class AscendAutoMonadConverter { | |||
| auto label_node = LabelSet(call_site.return_label); | |||
| AnfNodePtr output = call_site.out_param; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| // Let output depend on the label node, this ensures the | |||
| // return label is set before output is used. | |||
| output = MakeDepend(output, label_node); | |||
| const bool is_single_call = call_site.label_indexes.empty(); | |||
| if (is_single_call) { | |||
| // For single call, let output depend on the label node, | |||
| // this ensures the return label is set before output is used. | |||
| output = MakeDepend(output, label_node); | |||
| } else { | |||
| // For multi-return call, assign the result from the temp parameter to the | |||
| // output parameter; this prevents the result from being overwritten by the next call. | |||
| auto tmp_param = context_.GetTempParameter(output->abstract()); | |||
| output = AssignAll(output, tmp_param); | |||
| monad_ = UpdateState(GetMonad(), output); | |||
| } | |||
| // Replace the call/switch node with the output. | |||
| ReplaceNode(cnode, output); | |||
| return; | |||
| @@ -603,12 +672,12 @@ class AscendAutoMonadConverter { | |||
| if (return_points.empty()) { | |||
| return; | |||
| } | |||
| // Assign output according to the return points. | |||
| AssignOutput(return_points); | |||
| // Single return point. | |||
| if (return_points.size() == 1) { | |||
| // Insert Assign for output parameter. | |||
| auto &return_point = return_points.front(); | |||
| AssignOutput(return_point); | |||
| // Insert label_goto for return. | |||
| auto &return_point = return_points.front(); | |||
| auto return_goto = LabelGoto(return_point.call_site->return_label); | |||
| AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); | |||
| kernel_graph_->set_end_goto(return_goto); | |||
| @@ -617,12 +686,9 @@ class AscendAutoMonadConverter { | |||
| // Multi return points. | |||
| std::vector<uint32_t> return_labels; | |||
| return_labels.reserve(return_points.size()); | |||
| for (auto &return_point : return_points) { | |||
| // Assign output to out_params of each return point. | |||
| AssignOutput(return_point); | |||
| // Get return labels. | |||
| return_labels.emplace_back(return_point.call_site->return_label); | |||
| } | |||
| // Get return labels from return points. | |||
| std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels), | |||
| [](const ReturnPoint &return_point) { return return_point.call_site->return_label; }); | |||
| // Insert label_switch for multi return points. | |||
| auto &label_param = call_info_.label_param; | |||
| MS_EXCEPTION_IF_NULL(label_param); | |||
| @@ -631,11 +697,18 @@ class AscendAutoMonadConverter { | |||
| kernel_graph_->set_end_goto(return_switch); | |||
| } | |||
| // Assign graph output to the output parameter for a return point. | |||
| void AssignOutput(const ReturnPoint &return_point) { | |||
| auto call_site = return_point.call_site; | |||
| // Assign graph output to the output parameter. | |||
| void AssignOutput(const std::vector<ReturnPoint> &return_points) { | |||
| // For single call: we directly assign output to the output parameter of the call site; | |||
| // For multi call: we assign output to a temp parameter, and let caller assign the | |||
| // temp parameter to a output parameter after returned. | |||
| auto call_site = return_points.front().call_site; | |||
| MS_EXCEPTION_IF_NULL(call_site); | |||
| auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output()); | |||
| const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty()); | |||
| AnfNodePtr out_param = | |||
| (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract())); | |||
| MS_EXCEPTION_IF_NULL(out_param); | |||
| auto assign_output = AssignAll(out_param, kernel_graph_->output()); | |||
| monad_ = UpdateState(GetMonad(), assign_output); | |||
| } | |||
| @@ -699,7 +772,7 @@ class AscendAutoMonadConverter { | |||
| // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. | |||
| AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } | |||
| AnfNodePtr GetAssignMonad() { | |||
| AnfNodePtr GetLinkMonad() { | |||
| if (last_monad_ != nullptr) { | |||
| return last_monad_; | |||
| } | |||
| @@ -708,7 +781,7 @@ class AscendAutoMonadConverter { | |||
| // Make an assign cnode. | |||
| CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { | |||
| auto monad = GetAssignMonad(); | |||
| auto monad = (is_link ? GetLinkMonad() : GetMonad()); | |||
| auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); | |||
| if (is_link) { | |||
| // Mark this assign is to link real argument to formal argument. | |||