|
|
|
@@ -115,8 +115,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, |
|
|
|
op_sqrt = P.Sqrt() |
|
|
|
scatter_add = P.ScatterAdd(use_locking) |
|
|
|
|
|
|
|
assign_m = F.assign(m, op_mul(beta1, m)) |
|
|
|
assign_v = F.assign(v, op_mul(beta2, v)) |
|
|
|
F.assign(m, op_mul(beta1, m)) |
|
|
|
F.assign(v, op_mul(beta2, v)) |
|
|
|
|
|
|
|
grad_indices = gradient.indices |
|
|
|
grad_value = gradient.values |
|
|
|
@@ -131,17 +131,15 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, |
|
|
|
|
|
|
|
if use_nesterov: |
|
|
|
m_temp = next_m * _scaler_ten |
|
|
|
assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) |
|
|
|
F.assign(m, op_mul(beta1, next_m)) |
|
|
|
div_value = scatter_add(m, |
|
|
|
op_mul(grad_indices, _scaler_one), |
|
|
|
op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) |
|
|
|
param_update = div_value / (op_sqrt(next_v) + eps) |
|
|
|
|
|
|
|
m_recover = F.assign(m, m_temp / _scaler_ten) |
|
|
|
F.assign(m, m_temp / _scaler_ten) |
|
|
|
|
|
|
|
|
|
|
|
F.control_depend(m_temp, assign_m_nesterov) |
|
|
|
F.control_depend(assign_m_nesterov, div_value) |
|
|
|
F.control_depend(param_update, m_recover) |
|
|
|
else: |
|
|
|
param_update = next_m / (op_sqrt(next_v) + eps) |
|
|
|
|
|
|
|
@@ -149,8 +147,7 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, |
|
|
|
|
|
|
|
next_param = param - lr_t * param_update |
|
|
|
|
|
|
|
F.control_depend(assign_m, next_m) |
|
|
|
F.control_depend(assign_v, next_v) |
|
|
|
|
|
|
|
|
|
|
|
success = F.depend(success, F.assign(param, next_param)) |
|
|
|
success = F.depend(success, F.assign(m, next_m)) |
|
|
|
|