From fbc696ad7560830664b0af1a260d733c9b649b8c Mon Sep 17 00:00:00 2001 From: wilfChen Date: Fri, 31 Jul 2020 09:35:23 +0800 Subject: [PATCH] gpu AdamWeightDecay --- mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc | 9 +++++---- .../backend/optimizer/gpu/adam_weight_decay_fusion.cc | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc index 41e4abee27..d9194950b4 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -60,10 +60,11 @@ const BaseRef AdamFusion::DefinePattern() const { {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; + + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + return next_param; } const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc index c95945c980..f0f4ac6f36 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc @@ -62,10 +62,11 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const { VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; + + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); + next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + return next_param; } const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,