|
|
|
@@ -1155,6 +1155,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
const std::vector<std::string> trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd), |
|
|
|
string(kNameAssignSub)}; |
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { |
|
|
|
OperatorPtr src = Convert(node); |
|
|
|
auto &inputs = node->inputs(); |
|
|
|
@@ -1167,6 +1170,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node |
|
|
|
if (IsValueNode<None>(pred)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op. |
|
|
|
std::string c_name = GetCNodeFuncName(node); |
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); |
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) { |
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name(); |
|
|
|
auto op_itor = op_cache_.find(pred.get()); |
|
|
|
if (op_itor == op_cache_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; |
|
|
|
} |
|
|
|
if (op_itor->second != nullptr && |
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && |
|
|
|
vars_.find(name) != vars_.end()) { |
|
|
|
auto variable = std::make_shared<Variable>(name); |
|
|
|
auto desc = vars_[name]->GetOutputDesc("y"); |
|
|
|
(void)variable->update_output_desc_y(desc); |
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; |
|
|
|
op_itor->second = variable; // replace parameter with variable |
|
|
|
vars_[name] = variable; |
|
|
|
} |
|
|
|
} |
|
|
|
// find in out_hadnle_cache_ first |
|
|
|
auto it = out_handle_cache_.find(pred.get()); |
|
|
|
if (it != out_handle_cache_.end()) { |
|
|
|
|