|
|
|
@@ -1,5 +1,5 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* Copyright 2020-2021 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
@@ -81,23 +81,28 @@ AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef AdamWeightDecayFusion::DefinePattern() const { |
|
|
|
VectorRef load_param = VectorRef({prim::kPrimLoad, param_, u_}); |
|
|
|
VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); |
|
|
|
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); |
|
|
|
|
|
|
|
VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), |
|
|
|
VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); |
|
|
|
VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); |
|
|
|
VectorRef next_v = |
|
|
|
VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), |
|
|
|
VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); |
|
|
|
|
|
|
|
VectorRef update = |
|
|
|
VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); |
|
|
|
VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); |
|
|
|
VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, load_param}), update}); |
|
|
|
|
|
|
|
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); |
|
|
|
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); |
|
|
|
VectorRef next_param = VectorRef({prim::kPrimSub, load_param, update_with_lr}); |
|
|
|
|
|
|
|
VectorRef tuple_load = VectorRef({prim::kPrimMakeTuple, load_param, load_m, load_v}); |
|
|
|
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u_, tuple_load}); |
|
|
|
|
|
|
|
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); |
|
|
|
VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); |
|
|
|
VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, next_state}); |
|
|
|
next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_param}); |
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); |
|
|
|
|
|
|
|
VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); |
|
|
|
|