From d52277a9a45ed79a4bbd2187708ad50c3d38cb63 Mon Sep 17 00:00:00 2001 From: fary86 Date: Thu, 30 Apr 2020 11:12:29 +0800 Subject: [PATCH] Fix checking bug of ApplyCenteredRMSProp --- mindspore/ops/operations/nn_ops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 66656b559e..785fafe13b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1658,9 +1658,11 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) - args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype, - "epsilon": epsilon_dtype} - validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + valid_types = [mstype.float16, mstype.float32] + args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype} + validator.check_type_same(args_rho, valid_types, self.name) + args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype} + validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True) return var_dtype