Browse Source

!79 fix_bug_of_ApplyRMSProp

Merge pull request !79 from fary86/fix_bug_of_ApplyRMSProp_incubator
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
c9ca0c90ea
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      mindspore/ops/operations/nn_ops.py

+ 5
- 3
mindspore/ops/operations/nn_ops.py View File

@@ -1544,9 +1544,11 @@ class ApplyRMSProp(PrimitiveWithInfer):
args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
validator.check_tensor_type_same(args, mstype.number_type, self.name)

args = {"learning_rate": learning_rate_dtype, "decay": decay_dtype,
'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
valid_types = [mstype.float16, mstype.float32]
args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
validator.check_type_same(args_decay, valid_types, self.name)
args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
return var_dtype




Loading…
Cancel
Save