Browse Source

!14795 modify adam_fusion and adam_weight_decay_fusion

From: @huangbingjian
Reviewed-by: @zh_qh,@wilfchen
Signed-off-by: @wilfchen
pull/14795/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
72b7d78bf9
4 changed files with 21 additions and 16 deletions
  1. +10
    -6
      mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc
  2. +0
    -2
      mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
  3. +11
    -6
      mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
  4. +0
    -2
      mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h

+ 10
- 6
mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -80,11 +80,12 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u
} // namespace

const BaseRef AdamFusion::DefinePattern() const {
+ VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_});
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
+ VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});

VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});

- VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
VectorRef next_v =
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
@@ -92,10 +93,13 @@ const BaseRef AdamFusion::DefinePattern() const {
VectorRef update =
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
- VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+ VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr});

+ VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v});
+ VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load});

- VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
- VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
+ VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state});
+ next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});

VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});


+ 0
- 2
mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h View File

@@ -35,7 +35,6 @@ class AdamFusion : public PatternProcessPass {
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
u_ = std::make_shared<Var>();
- u2_ = std::make_shared<Var>();
}
~AdamFusion() override = default;
const BaseRef DefinePattern() const override;
@@ -53,7 +52,6 @@ class AdamFusion : public PatternProcessPass {
VarPtr v_;
VarPtr gradient_;
VarPtr u_;
- VarPtr u2_;
};
} // namespace opt
} // namespace mindspore


+ 11
- 6
mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc View File

@@ -1,5 +1,5 @@
/**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -81,23 +81,28 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay,
} // namespace

const BaseRef AdamWeightDecayFusion::DefinePattern() const {
+ VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_});
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_});
+ VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});

VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}),
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
- VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_});
VectorRef next_v =
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}),
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});

VectorRef update =
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
- VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update});
+ VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, load_param}), update});

VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update});
- VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+ VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr});

+ VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v});
+ VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load});

- VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_});
- VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param});
+ VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state});
+ next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param});
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param});

VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state});


+ 0
- 2
mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h View File

@@ -36,7 +36,6 @@ class AdamWeightDecayFusion : public PatternProcessPass {
v_ = std::make_shared<Var>();
gradient_ = std::make_shared<Var>();
u_ = std::make_shared<Var>();
- u2_ = std::make_shared<Var>();
}
~AdamWeightDecayFusion() override = default;
const BaseRef DefinePattern() const override;
@@ -55,7 +54,6 @@ class AdamWeightDecayFusion : public PatternProcessPass {
VarPtr v_;
VarPtr gradient_;
VarPtr u_;
- VarPtr u2_;
};
} // namespace opt
} // namespace mindspore


Loading…
Cancel
Save