@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of the parameter after updating.
     """
-    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        success = F.depend(success, next_param)
-        return success
+
+        return op_cast(next_param, F.dtype(param))
+
     return gradient
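The three chained F.depend(..., F.assign(...)) calls above tie the writes to param, m, and v into the data flow, so graph mode cannot reorder or prune them before the value is returned; the hunk then returns the cast next_param directly instead of threading a success flag through F.depend. The arithmetic that produces next_m, next_v, and next_param sits between the two hunks and is not shown here; the sketch below restates it in plain NumPy, assuming the usual AdamWeightDecay formulation (exponential moving averages, no bias correction, decoupled weight decay). The name adamw_step and the bare-array types are illustrative, not part of the patch.

import numpy as np

def adamw_step(param, m, v, gradient, lr, beta1, beta2, eps, weight_decay,
               decay_flag=True):
    # First- and second-moment exponential moving averages.
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * gradient ** 2
    # Scaled step; eps is added to the root of the second moment.
    update = next_m / (eps + np.sqrt(next_v))
    if decay_flag:
        # Decoupled weight decay, applied to the parameter, not the gradient.
        update = update + weight_decay * param
    next_param = param - lr * update
    # The graph version writes these back via F.assign and returns next_param.
    return next_param, next_m, next_v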
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
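The trailing @_adam_opt.register(...) decorator, whose type list is cut off by this excerpt, registers _update_run_op as one overload of a MultitypeFuncGraph; the optimizer selects the overload from the runtime types of the arguments. A minimal sketch of that dispatch pattern, using an illustrative graph name and type signature rather than the real one:

import numpy as np
import mindspore as ms
from mindspore.ops import MultitypeFuncGraph

_demo_opt = MultitypeFuncGraph("demo_opt")

# Overload chosen when the call site passes (Number, Tensor, Tensor).
@_demo_opt.register("Number", "Tensor", "Tensor")
def _demo_update(lr, param, gradient):
    return param - lr * gradient

param = ms.Tensor(np.ones(3, dtype=np.float32))
grad = ms.Tensor(np.full(3, 0.5, dtype=np.float32))
print(_demo_opt(0.1, param, grad))  # dispatches on argument types

In the optimizer itself such a graph is typically applied across all parameters with HyperMap, which maps the registered overload over tuples of parameters, moments, and gradients.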