|
|
|
@@ -60,10 +60,11 @@ const BaseRef AdamFusion::DefinePattern() const { |
|
|
|
{prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); |
|
|
|
VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); |
|
|
|
VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); |
|
|
|
VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); |
|
|
|
VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); |
|
|
|
VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); |
|
|
|
return depend3; |
|
|
|
|
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); |
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); |
|
|
|
next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); |
|
|
|
return next_param; |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { |
|
|
|
|