From: @huangbingjian Reviewed-by: @zh_qh,@wilfchen Signed-off-by: @wilfchen pull/14795/MERGE
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -80,11 +80,12 @@ AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AdamFusion::DefinePattern() const { | const BaseRef AdamFusion::DefinePattern() const { | ||||
| VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_}); | |||||
| VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | ||||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||||
| VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | ||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | ||||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||||
| VectorRef next_v = | VectorRef next_v = | ||||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | ||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | ||||
| @@ -92,10 +93,13 @@ const BaseRef AdamFusion::DefinePattern() const { | |||||
| VectorRef update = | VectorRef update = | ||||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | ||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); | ||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||||
| VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr}); | |||||
| VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v}); | |||||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load}); | |||||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); | |||||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); | |||||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state}); | |||||
| next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param}); | |||||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | ||||
| VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | ||||
| @@ -35,7 +35,6 @@ class AdamFusion : public PatternProcessPass { | |||||
| v_ = std::make_shared<Var>(); | v_ = std::make_shared<Var>(); | ||||
| gradient_ = std::make_shared<Var>(); | gradient_ = std::make_shared<Var>(); | ||||
| u_ = std::make_shared<Var>(); | u_ = std::make_shared<Var>(); | ||||
| u2_ = std::make_shared<Var>(); | |||||
| } | } | ||||
| ~AdamFusion() override = default; | ~AdamFusion() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -53,7 +52,6 @@ class AdamFusion : public PatternProcessPass { | |||||
| VarPtr v_; | VarPtr v_; | ||||
| VarPtr gradient_; | VarPtr gradient_; | ||||
| VarPtr u_; | VarPtr u_; | ||||
| VarPtr u2_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -81,23 +81,28 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, | |||||
| } // namespace | } // namespace | ||||
| const BaseRef AdamWeightDecayFusion::DefinePattern() const { | const BaseRef AdamWeightDecayFusion::DefinePattern() const { | ||||
| VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_}); | |||||
| VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); | ||||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||||
| VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), | ||||
| VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); | ||||
| VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); | |||||
| VectorRef next_v = | VectorRef next_v = | ||||
| VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), | ||||
| VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); | ||||
| VectorRef update = | VectorRef update = | ||||
| VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); | ||||
| VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); | |||||
| VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, load_param}), update}); | |||||
| VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); | ||||
| VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); | |||||
| VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr}); | |||||
| VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v}); | |||||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load}); | |||||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); | |||||
| VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); | |||||
| VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state}); | |||||
| next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param}); | |||||
| next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); | ||||
| VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); | ||||
| @@ -36,7 +36,6 @@ class AdamWeightDecayFusion : public PatternProcessPass { | |||||
| v_ = std::make_shared<Var>(); | v_ = std::make_shared<Var>(); | ||||
| gradient_ = std::make_shared<Var>(); | gradient_ = std::make_shared<Var>(); | ||||
| u_ = std::make_shared<Var>(); | u_ = std::make_shared<Var>(); | ||||
| u2_ = std::make_shared<Var>(); | |||||
| } | } | ||||
| ~AdamWeightDecayFusion() override = default; | ~AdamWeightDecayFusion() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| @@ -55,7 +54,6 @@ class AdamWeightDecayFusion : public PatternProcessPass { | |||||
| VarPtr v_; | VarPtr v_; | ||||
| VarPtr gradient_; | VarPtr gradient_; | ||||
| VarPtr u_; | VarPtr u_; | ||||
| VarPtr u2_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||