Browse Source

!2381 fix some type cast bug

Merge pull request !2381 from liangzelang/dev
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
60cd188ab8
2 changed files with 5 additions and 5 deletions
  1. +3
    -3
      mindspore/nn/optim/adam.py
  2. +2
    -2
      mindspore/ops/operations/nn_ops.py

+ 3
- 3
mindspore/nn/optim/adam.py View File

@@ -72,9 +72,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, mstype.float16)))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, mstype.float16)))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, mstype.float16)))
return next_v




+ 2
- 2
mindspore/ops/operations/nn_ops.py View File

@@ -1544,9 +1544,9 @@ class ApplyMomentum(PrimitiveWithInfer):
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T),
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T),
sig_dtype.T1),
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2)
)

@prim_attr_register


Loading…
Cancel
Save