From 81077588fe7cae6b0c0b78099394e05ecf3d4a32 Mon Sep 17 00:00:00 2001
From: zhaozhenlong
Date: Wed, 29 Apr 2020 10:00:51 +0800
Subject: [PATCH] adapt ApplyAdam remove outputs m and v

---
 graphengine                             | 2 +-
 mindspore/ccsrc/transform/op_declare.cc | 4 ----
 mindspore/ops/operations/nn_ops.py      | 8 ++------
 3 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/graphengine b/graphengine
index d345a800a4..976d1e31b7 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit d345a800a4f7c32eb768ea48667d1ce00b841748
+Subproject commit 976d1e31b777d65f87333c3a125093946e682a6e
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 76fec9e21c..79b2808b2e 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -466,11 +466,7 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
                         {10, INPUT_DESC(grad)}};
 ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-#ifdef ENABLE_GE
-OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
-#else
 OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
-#endif
 
 // Relu6
 INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 0687806bb2..6371fc4654 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2323,11 +2323,7 @@ class Adam(PrimitiveWithInfer):
         - **gradient** (Tensor) - Gradients.
 
     Outputs:
-        Tuple of 3 Tensor, the updated parameters.
-        - **var** (Tensor) - The same shape and data type as `var`.
-        - **m** (Tensor) - The same shape and data type as `m`.
-        - **v** (Tensor) - The same shape and data type as `v`.
     """
 
     @prim_attr_register
@@ -2340,7 +2336,7 @@ class Adam(PrimitiveWithInfer):
         validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
         validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        return var_shape, m_shape, v_shape
+        return var_shape
 
     def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                     beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@@ -2350,7 +2346,7 @@ class Adam(PrimitiveWithInfer):
         args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
                 "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
         validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
-        return var_dtype, m_dtype, v_dtype
+        return var_dtype
 
 
 class BinaryCrossEntropy(PrimitiveWithInfer):
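
Note (not part of the patch): a minimal sketch of what this change means for callers, assuming MindSpore's P.Adam primitive takes its inputs in the order used by infer_shape/infer_dtype above; the AdamNet cell, its parameter names, and the shape argument are made up for illustration. After this commit the op returns only the updated `var` rather than a (var, m, v) tuple, so call sites should stop unpacking three outputs.

    # Illustrative only -- not part of the patch.
    import numpy as np

    import mindspore.nn as nn
    from mindspore import Parameter, Tensor
    from mindspore.ops import operations as P


    class AdamNet(nn.Cell):
        """Hypothetical one-step Adam update wrapper (illustrative only)."""

        def __init__(self, shape):
            super(AdamNet, self).__init__()
            # P.Adam is the primitive patched above; after this commit it exposes
            # a single output instead of the (var, m, v) tuple.
            self.apply_adam = P.Adam()
            self.var = Parameter(Tensor(np.ones(shape).astype(np.float32)), name="var")
            self.m = Parameter(Tensor(np.zeros(shape).astype(np.float32)), name="m")
            self.v = Parameter(Tensor(np.zeros(shape).astype(np.float32)), name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            # Input order follows infer_shape/infer_dtype in the hunks above.
            # var/m/v are still passed in (and updated by the kernel), but only
            # the updated var comes back as the op's return value.
            return self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power,
                                   lr, beta1, beta2, epsilon, grad)

Calling such a cell with scalar Tensors for the hyper-parameters and a gradient Tensor should now yield a single Tensor with the shape and dtype of `var`, matching the updated infer_shape/infer_dtype return values.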