
!5893 fix bugs of adam and lamb optimizer

Merge pull request !5893 from fary86/fix_adam_lamb_bug
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
commit 174de814fb
2 changed files with 7 additions and 7 deletions
  1. mindspore/nn/optim/adam.py  (+3 -3)
  2. mindspore/nn/optim/lamb.py  (+4 -4)

mindspore/nn/optim/adam.py  (+3 -3)

@@ -51,7 +51,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of v after updating.
     """
-    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -81,8 +80,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        success = F.depend(success, next_param)
-    return success
+
+        return op_cast(next_param, F.dtype(param))
+    return gradient
 
 
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",

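The adam.py change drops the boolean `success` flag: when `optim_filter` is true, `_update_run_op` now returns the updated parameter cast back to the parameter's own dtype, and otherwise it passes the gradient through unchanged. Below is a minimal plain-NumPy sketch of that control flow, not MindSpore code; the `adam_update` name, the float16 weight, and the omission of bias correction and weight decay are all simplifications for illustration.

import numpy as np

def adam_update(beta1, beta2, eps, lr, param, m, v, gradient, optim_filter):
    """Return the updated parameter in param's dtype, or the untouched gradient."""
    if optim_filter:
        # The moment updates run in float32 for numerical stability.
        param32 = param.astype(np.float32)
        grad32 = gradient.astype(np.float32)
        m[:] = beta1 * m + (1.0 - beta1) * grad32
        v[:] = beta2 * v + (1.0 - beta2) * np.square(grad32)
        next_param = param32 - lr * m / (np.sqrt(v) + eps)
        # Cast back to the parameter's own dtype before writing it in place,
        # so a float16 weight never silently becomes float32.
        param[:] = next_param.astype(param.dtype)
        return next_param.astype(param.dtype)
    # Filtered-out parameters pass their gradient through unchanged.
    return gradient

param = np.ones(4, dtype=np.float16)
m = np.zeros(4, dtype=np.float32)
v = np.zeros(4, dtype=np.float32)
grad = np.full(4, 0.5, dtype=np.float16)
out = adam_update(0.9, 0.999, 1e-8, 0.01, param, m, v, grad, optim_filter=True)
print(out.dtype)  # float16 -- the result matches the parameter, not a bool flag
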
mindspore/nn/optim/lamb.py  (+4 -4)

@@ -104,11 +104,11 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v
 
         next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
 
-        next_param = F.depend(next_param, F.assign(param, next_param))
-        next_param = F.depend(next_param, F.assign(m, next_m))
-        next_param = F.depend(next_param, F.assign(v, next_v))
+        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
 
-        return next_param
+        return op_cast(next_param, F.dtype(param))
     return gradient

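The lamb.py change applies the same dtype discipline to LAMB: the float32 results (`next_param`, `next_m`, `next_v`) are cast with `op_cast(..., F.dtype(...))` before being assigned into the stored parameter and moments, and the return value is likewise cast to the parameter's dtype. The short NumPy sketch below illustrates just that write-back step (the `lamb_write_back` helper and the float16/float32 buffer dtypes are assumptions for the example, not code from the commit).

import numpy as np

def lamb_write_back(param, m, v, next_param_fp32, next_m_fp32, next_v_fp32):
    """Cast each float32 result back to the dtype of the buffer it overwrites."""
    param[:] = next_param_fp32.astype(param.dtype)   # e.g. a float16 weight
    m[:] = next_m_fp32.astype(m.dtype)               # float32 moments stay float32
    v[:] = next_v_fp32.astype(v.dtype)
    # The caller also receives the new parameter in the parameter's own dtype.
    return next_param_fp32.astype(param.dtype)

param = np.ones(3, dtype=np.float16)
m = np.zeros(3, dtype=np.float32)
v = np.zeros(3, dtype=np.float32)
out = lamb_write_back(param, m, v,
                      np.full(3, 0.99, dtype=np.float32),
                      np.full(3, 0.10, dtype=np.float32),
                      np.full(3, 0.01, dtype=np.float32))
print(out.dtype)  # float16

Keeping the write-back and the return value in the parameter's own dtype is presumably what lets float16 parameters go through these optimizers without a dtype mismatch, since the internal math runs on `param_fp32`.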


