From 88c92cd2637d5284ccb0aa1b8a53c37c57c3b61a Mon Sep 17 00:00:00 2001 From: jjfeing Date: Wed, 28 Apr 2021 10:52:15 +0800 Subject: [PATCH] clear parameter when param_info clone --- .../optimizer/pass/convert_const_input_to_tensor_input.cc | 7 ++++++- mindspore/ccsrc/utils/utils.h | 1 + mindspore/core/ir/param_info.h | 2 ++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc index 0abf418288..363a506157 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -17,7 +17,7 @@ #include #include -#include +#include #include "ir/graph_utils.h" #include "backend/optimizer/common/helper.h" @@ -62,6 +62,11 @@ AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kerne AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(cnode); + const std::set no_need_to_convert_nodes = {kStackOpName}; + auto node_type = AnfAlgo::GetCNodeName(cnode); + if (no_need_to_convert_nodes.find(node_type) != no_need_to_convert_nodes.end()) { + return nullptr; + } std::vector new_inputs; auto kernel_graph = func_graph->cast>(); auto inputs = cnode->inputs(); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 4552c3233b..0437ce3255 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -119,6 +119,7 @@ constexpr auto kTransDataOpName = "TransData"; constexpr auto kStackInitOpName = "StackInit"; constexpr auto kStackPushOpName = "StackPush"; constexpr auto kStackPopOpName = "StackPop"; +constexpr auto kStackOpName = "Stack"; constexpr auto kStackDestroyOpName = "StackDestroy"; constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad"; constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad"; diff --git a/mindspore/core/ir/param_info.h b/mindspore/core/ir/param_info.h index e6edf6ab63..cba7dbc407 100644 --- a/mindspore/core/ir/param_info.h +++ b/mindspore/core/ir/param_info.h @@ -72,6 +72,7 @@ class ParamInfo { this->be_cloned_ = true; this->be_cloned_index_.push_back(index); clone->init_in_server_ = this->init_in_server_; + clone->ClearParameter(); return clone; } @@ -88,6 +89,7 @@ class ParamInfo { void set_cache_shape(const std::vector &cache_shape) { cache_shape_ = cache_shape; } ParameterPtr parameter() { return parameter_; } void set_parameter(const ParameterPtr ¶meter) { parameter_ = parameter; } + void ClearParameter() { parameter_ = nullptr; } private: std::string name_{"Parameter"};