Browse Source

fix bug of type cast

tags/v0.6.0-beta
liangzelang 5 years ago
parent
commit
76ba5a643f
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/nn/optim/adam.py

+ 3
- 3
mindspore/nn/optim/adam.py View File

@@ -72,9 +72,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_v = F.depend(next_v, F.assign(param, op_cast(next_param, mstype.float16)))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, mstype.float16)))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, mstype.float16)))
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_v




Loading…
Cancel
Save