From 5f2bfb5679d39be091053ae1d807bbd2954bc49a Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Fri, 24 Apr 2020 14:35:12 +0800 Subject: [PATCH] trans const to variable in assign case --- mindspore/ccsrc/transform/convert.cc | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) mode change 100755 => 100644 mindspore/ccsrc/transform/convert.cc diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc old mode 100755 new mode 100644 index 2daa86b960..36faa5787a --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -1154,6 +1154,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { } } +const std::vector trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd), + string(kNameAssignSub)}; + void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { OperatorPtr src = Convert(node); auto &inputs = node->inputs(); @@ -1166,6 +1169,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node if (IsValueNode(pred)) { continue; } + // transform "Const" op to "Variable" op when the next node is "Assign" op. + std::string c_name = GetCNodeFuncName(node); + auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); + if (!training_ && pos != trans_var_list.end() && pred->isa()) { + std::string name = std::static_pointer_cast(pred)->name(); + auto op_itor = op_cache_.find(pred.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; + } + if (op_itor->second != nullptr && + (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && + vars_.find(name) != vars_.end()) { + auto variable = std::make_shared(name); + auto desc = vars_[name]->GetOutputDesc("y"); + (void)variable->update_output_desc_y(desc); + MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; + } + } // find in out_hadnle_cache_ first auto it = out_handle_cache_.find(pred.get()); if (it != out_handle_cache_.end()) {