From b8038874800c374e8b6f6ffb46a95dfed6692a50 Mon Sep 17 00:00:00 2001
From: huangbingjian
Date: Mon, 12 Apr 2021 10:24:49 +0800
Subject: [PATCH] update adam_fusion and adam_weight_decay_fusion

---
 .../ccsrc/backend/optimizer/gpu/adam_fusion.cc | 16 ++++++++++------
 .../ccsrc/backend/optimizer/gpu/adam_fusion.h  |  2 --
 .../optimizer/gpu/adam_weight_decay_fusion.cc  | 17 +++++++++++------
 .../optimizer/gpu/adam_weight_decay_fusion.h   |  2 --
 4 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc
index b28e95b2d5..b92cf9d1fe 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -80,11 +80,12 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
 }  // namespace
 
 const BaseRef AdamFusion::DefinePattern() const {
+  VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_});
   VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
+  VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
+
   VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
                                 VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
-
-  VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
   VectorRef next_v = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
                                 VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
 
@@ -92,10 +93,13 @@ const BaseRef AdamFusion::DefinePattern() const {
   VectorRef update = VectorRef({prim::kPrimRealDiv, next_m,
                                 VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
   VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
-  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+  VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr});
+
+  VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v});
+  VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load});
 
-  VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
-  VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
+  VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state});
+  next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param});
   next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
 
   VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
index a83c4281da..b65fd1257d 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
@@ -35,7 +35,6 @@ class AdamFusion : public PatternProcessPass {
     v_ = std::make_shared<Var>();
     gradient_ = std::make_shared<Var>();
     u_ = std::make_shared<Var>();
-    u2_ = std::make_shared<Var>();
   }
   ~AdamFusion() override = default;
   const BaseRef DefinePattern() const override;
@@ -53,7 +52,6 @@ class AdamFusion : public PatternProcessPass {
   VarPtr v_;
   VarPtr gradient_;
   VarPtr u_;
-  VarPtr u2_;
 };
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
index ee3a7dffc5..607b9147a4 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -81,23 +81,28 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
 }  // namespace
 
 const BaseRef AdamWeightDecayFusion::DefinePattern() const {
+  VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_});
   VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
+  VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
+
   VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
                                 VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
-  VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
   VectorRef next_v = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
                                 VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
 
   VectorRef update = VectorRef({prim::kPrimRealDiv, next_m,
                                 VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
-  VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
+  VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, load_param}), update});
   VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
-  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+  VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr});
+
+  VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v});
+  VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load});
 
-  VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
-  VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
+  VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state});
+  next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param});
   next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});
 
   VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h
index d29cefb222..7bf3c05ec4 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h
@@ -36,7 +36,6 @@ class AdamWeightDecayFusion : public PatternProcessPass {
    v_ = std::make_shared<Var>();
     gradient_ = std::make_shared<Var>();
     u_ = std::make_shared<Var>();
-    u2_ = std::make_shared<Var>();
   }
   ~AdamWeightDecayFusion() override = default;
   const BaseRef DefinePattern() const override;
@@ -55,7 +54,6 @@ class AdamWeightDecayFusion : public PatternProcessPass {
   VarPtr v_;
   VarPtr gradient_;
   VarPtr u_;
-  VarPtr u2_;
 };
 }  // namespace opt
 }  // namespace mindspore
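
For reference, the DefinePattern graphs updated above encode the usual Adam step, now matched on the Load of the parameter and threaded through UpdateState: next_m = beta1*m + (1-beta1)*g, next_v = beta2*v + (1-beta2)*g^2, update = next_m / (eps + sqrt(next_v)), next_param = param - lr*update, with AdamWeightDecay adding weight_decay*param to the update before the lr scaling. Below is a minimal standalone sketch (not MindSpore code) of that scalar arithmetic; the variable names mirror the pattern variables and the numeric values are hypothetical.

// Minimal sketch of the arithmetic matched by AdamFusion / AdamWeightDecayFusion.
// All values here are made-up examples, not taken from the patch.
#include <cmath>
#include <cstdio>

int main() {
  float param = 1.0f, m = 0.0f, v = 0.0f, gradient = 0.1f;
  const float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f, lr = 1e-3f, weight_decay = 0.01f;

  // AdamFusion pattern: (1 - beta1) and (1 - beta2) play the role of
  // one_sub_beta1_ / one_sub_beta2_ in the pattern graph.
  float next_m = beta1 * m + (1.0f - beta1) * gradient;             // kPrimMul / kPrimAdd on load_m
  float next_v = beta2 * v + (1.0f - beta2) * gradient * gradient;  // kPrimSquare on gradient_
  float update = next_m / (eps + std::sqrt(next_v));                // kPrimSqrt / kPrimRealDiv
  float next_param = param - lr * update;                           // kPrimSub on load_param

  // AdamWeightDecayFusion pattern: weight decay is folded into the update
  // before the learning-rate scaling.
  float new_update = weight_decay * param + update;
  float next_param_wd = param - lr * new_update;

  std::printf("adam: %f  adamw: %f\n", next_param, next_param_wd);
  return 0;
}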