From 01eaaed85fb296336182090e9cf5298e22eddc35 Mon Sep 17 00:00:00 2001 From: He Wei Date: Thu, 18 Mar 2021 10:49:06 +0800 Subject: [PATCH] [auto-monad] Fix multi-call output parameter be overwritten issue --- .../backend/session/ascend_auto_monad.cc | 117 ++++++++++++++---- 1 file changed, 95 insertions(+), 22 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index e4aa5651ed..e09c3d7bda 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -166,6 +166,63 @@ struct CallInfo { AnfNodePtr label_param = nullptr; }; +// +// ParameterPool cache parameters by its abstract, so that we can reuse +// parameter with same abstract to store return values. +// +class ParameterPool { + public: + explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {} + ~ParameterPool() = default; + + // Create or get a parameter from pool with the given abstract. + AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) { + // Find parameter in pool by the given abstract. + auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto ¶) { + auto para_abs = para->abstract(); + // Reuse output parameter with compatible abstract. + return IsCompatible(abs, para_abs); + }); + // Return the parameter if found. + if (iter != paras_.end()) { + return *iter; + } + // If parameter not found with the given abstract, create a new one. + auto para = top_graph_->NewParameter(abs); + auto out_para = top_graph_->TransTupleToMakeTuple(para); + // This is required, so that device memory can be allocated for it. 
+ top_graph_->AddChildGraphResult(out_para); + // Save new para to pool. + paras_.push_back(out_para); + return out_para; + } + + protected: + // Check if one abstract is compatible with another abstract. + static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) { + if (a1 == nullptr || a2 == nullptr) { + return false; + } + if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) { + // This make AbstractRef compatible with AbstractTensor. + auto &t1 = static_cast<abstract::AbstractTensor &>(*a1); + auto &t2 = static_cast<abstract::AbstractTensor &>(*a2); + return t1 == t2; + } + return *a1 == *a2; + } + + private: + // The top graph. + const KernelGraphPtr &top_graph_; + + // Cached parameters. + std::vector<AnfNodePtr> paras_; +}; + +// +// Base class for context. +// class BaseContext { public: void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } @@ -185,7 +242,7 @@ class BaseContext { // class AscendAutoMonadContext : public BaseContext { public: - explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {} + explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {} ~AscendAutoMonadContext() = default; // Label id start from 1, and increased by 1 for each new id. @@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext { return out_para; } + // Get or create a temporary parameter for the given abstract. + AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); } + const KernelGraphPtr &TopGraph() const { return top_graph_; } // Map kernel_graph to its call info. @@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext { // The top graph. const KernelGraphPtr &top_graph_; - // Map kernel_graph to its output parameter. - std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_; + // The parameter pool that cache parameters for return value. + ParameterPool param_pool_; // Current label id. 
uint32_t label_id_ = 1; @@ -521,9 +581,18 @@ class AscendAutoMonadConverter { auto label_node = LabelSet(call_site.return_label); AnfNodePtr output = call_site.out_param; MS_EXCEPTION_IF_NULL(output); - // Let output depend on the label node, this ensures the - // return label is set before output is used. - output = MakeDepend(output, label_node); + const bool is_single_call = call_site.label_indexes.empty(); + if (is_single_call) { + // For single call, let output depend on the label node, + // this ensures the return label is set before output is used. + output = MakeDepend(output, label_node); + } else { + // For multi-return call, assign result from temp parameter to + // output parameter, this prevent result be overwritten by next call. + auto tmp_param = context_.GetTempParameter(output->abstract()); + output = AssignAll(output, tmp_param); + monad_ = UpdateState(GetMonad(), output); + } // Replace the the call/switch node with the output. ReplaceNode(cnode, output); return; @@ -603,12 +672,12 @@ class AscendAutoMonadConverter { if (return_points.empty()) { return; } + // Assign output according the return points. + AssignOutput(return_points); // Single return point. if (return_points.size() == 1) { - // Insert Assign for output parameter. - auto &return_point = return_points.front(); - AssignOutput(return_point); // Insert label_goto for return. + auto &return_point = return_points.front(); auto return_goto = LabelGoto(return_point.call_site->return_label); AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto); kernel_graph_->set_end_goto(return_goto); @@ -617,12 +686,9 @@ class AscendAutoMonadConverter { // Multi return points. std::vector<uint32_t> return_labels; return_labels.reserve(return_points.size()); - for (auto &return_point : return_points) { - // Assign output to out_params of each return point. - AssignOutput(return_point); - // Get return labels. 
- return_labels.emplace_back(return_point.call_site->return_label); - } + // Get return labels from return points. + std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels), + [](const ReturnPoint &return_point) { return return_point.call_site->return_label; }); // Insert label_switch for multi return points. auto &label_param = call_info_.label_param; MS_EXCEPTION_IF_NULL(label_param); @@ -631,11 +697,18 @@ class AscendAutoMonadConverter { kernel_graph_->set_end_goto(return_switch); } - // Assign graph output to the output parameter for a return point. - void AssignOutput(const ReturnPoint &return_point) { - auto call_site = return_point.call_site; + // Assign graph output to the output parameter. + void AssignOutput(const std::vector<ReturnPoint> &return_points) { + // For single call: we directly assign output to the output parameter of the call site; + // For multi call: we assign output to a temp parameter, and let caller assign the + // temp parameter to a output parameter after returned. + auto call_site = return_points.front().call_site; MS_EXCEPTION_IF_NULL(call_site); - auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output()); + const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty()); + AnfNodePtr out_param = + (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract())); + MS_EXCEPTION_IF_NULL(out_param); + auto assign_output = AssignAll(out_param, kernel_graph_->output()); monad_ = UpdateState(GetMonad(), assign_output); } @@ -699,7 +772,7 @@ class AscendAutoMonadConverter { // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode. 
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); } - AnfNodePtr GetAssignMonad() { + AnfNodePtr GetLinkMonad() { if (last_monad_ != nullptr) { return last_monad_; } @@ -708,7 +781,7 @@ // Make a assign cnode. CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) { - auto monad = (is_link ? GetLinkMonad() : GetMonad()); + auto monad = (is_link ? GetLinkMonad() : GetMonad()); auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name()); if (is_link) { // Mark this assign is to link real argument to formal argument.