From 76ba5a643f0e66a98be7c198330a85c33f4b83f0 Mon Sep 17 00:00:00 2001
From: liangzelang
Date: Mon, 22 Jun 2020 11:01:13 +0800
Subject: [PATCH] fix bug of type cast

---
 mindspore/nn/optim/adam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 5a40d30d5a..ba56af6219 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -72,9 +72,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     update_with_lr = op_mul(lr, update)
     next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

-    next_v = F.depend(next_v, F.assign(param, op_cast(next_param, mstype.float16)))
-    next_v = F.depend(next_v, F.assign(m, op_cast(next_m, mstype.float16)))
-    next_v = F.depend(next_v, F.assign(v, op_cast(next_v, mstype.float16)))
+    next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
+    next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
+    next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
     return next_v
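
Note on the change: the three F.assign calls previously cast next_param, next_m, and next_v to a hard-coded mstype.float16, which silently downcast float32 parameters and moments on write-back; the fix casts each value back to the destination tensor's own dtype via F.dtype(...). Below is a minimal, hypothetical sketch (not part of the patch) of that dtype-preserving assign pattern, assuming a working MindSpore install; the DtypePreservingAssign cell and its tensors are illustrative names only, not code from adam.py.

    # Hypothetical sketch of the pattern introduced by the patch: cast the
    # computed value back to the destination tensor's own dtype (F.dtype(dst))
    # instead of hard-coding mstype.float16, so float32 parameters and moments
    # are no longer silently truncated to float16 on assignment.
    import numpy as np
    from mindspore import Tensor, Parameter, nn
    from mindspore.ops import functional as F
    import mindspore.common.dtype as mstype

    op_cast = F.cast  # same alias style as adam.py


    class DtypePreservingAssign(nn.Cell):
        """Illustrative cell: writes an update into a parameter of any dtype."""

        def __init__(self, dtype):
            super(DtypePreservingAssign, self).__init__()
            self.param = Parameter(Tensor(np.ones((2, 2)), dtype), name="param")

        def construct(self, next_param):
            # Fixed pattern: the cast target follows the parameter's dtype,
            # so a float32 parameter stays float32 after the assignment.
            return F.assign(self.param, op_cast(next_param, F.dtype(self.param)))

    # Usage (illustrative; device/context setup omitted):
    #   net = DtypePreservingAssign(mstype.float32)
    #   net(Tensor(np.full((2, 2), 0.5), mstype.float32))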