Browse Source

!3796 Gpu AdamWeightDecay fusion

Merge pull request !3796 from chenweifeng/AdamWeighDecayFusionFix
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
607cb58ae5
2 changed files with 10 additions and 8 deletions
  1. +5
    -4
      mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc
  2. +5
    -4
      mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc

+ 5
- 4
mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc View File

@@ -60,10 +60,11 @@ const BaseRef AdamFusion::DefinePattern() const {
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
return depend3;

next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
return next_param;
} }


const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {


+ 5
- 4
mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc View File

@@ -62,10 +62,11 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const {


VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
return depend3;

next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})});
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})});
return next_param;
} }


const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,


Loading…
Cancel
Save