| @@ -37,6 +37,7 @@ static constexpr size_t kCNodeSwitchLayerBranch = 2; | |||||
| static constexpr size_t kCNodeSwitchLayerLength = 3; | static constexpr size_t kCNodeSwitchLayerLength = 3; | ||||
| static constexpr size_t kCNodeAssignTarget = 1; | static constexpr size_t kCNodeAssignTarget = 1; | ||||
| static constexpr size_t kCNodeAssignSource = 2; | static constexpr size_t kCNodeAssignSource = 2; | ||||
| static constexpr size_t kCNodeAssignDestination = 1; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -291,13 +292,15 @@ void AscendControlParser::EraseAssign(std::shared_ptr<ReferenceCounter> paramete | |||||
| } | } | ||||
| auto &assign_node = assign_iter->second; | auto &assign_node = assign_iter->second; | ||||
| MS_EXCEPTION_IF_NULL(assign_node); | MS_EXCEPTION_IF_NULL(assign_node); | ||||
| if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { | |||||
| auto source = assign_node->input(kCNodeAssignSource); | |||||
| auto destination = assign_node->input(kCNodeAssignDestination); | |||||
| // not assign node or assign destination is transdata which for ref parameter(write 2 times) -> continue | |||||
| if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign) || IsPrimitiveCNode(destination, prim::KPrimTransData)) { | |||||
| parameter_count->EraseElem(para); | parameter_count->EraseElem(para); | ||||
| continue; | continue; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); | MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); | ||||
| EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); | EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); | ||||
| auto source = assign_node->input(kCNodeAssignSource); | |||||
| MS_EXCEPTION_IF_NULL(source); | MS_EXCEPTION_IF_NULL(source); | ||||
| auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; | auto visit_source = AnfAlgo::VisitKernelWithReturnType(source, 0).first; | ||||
| parameter_count->AddWriteCount(para, -1); | parameter_count->AddWriteCount(para, -1); | ||||