Browse Source

!83 Adapt ApplyAdam remove outputs m and v

Merge pull request !83 from zhaozhenlong/adapt-apply-adam
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
1c7f9ddd45
3 changed files with 3 additions and 11 deletions
  1. +1
    -1
      graphengine
  2. +0
    -4
      mindspore/ccsrc/transform/op_declare.cc
  3. +2
    -6
      mindspore/ops/operations/nn_ops.py

+ 1
- 1
graphengine

@@ -1 +1 @@
-Subproject commit d345a800a4f7c32eb768ea48667d1ce00b841748
+Subproject commit 976d1e31b777d65f87333c3a125093946e682a6e

+ 0
- 4
mindspore/ccsrc/transform/op_declare.cc View File

@@ -466,11 +466,7 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
{10, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-#ifdef ENABLE_GE
-OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
-#else
 OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
-#endif

// Relu6
INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};


+ 2
- 6
mindspore/ops/operations/nn_ops.py View File

@@ -2323,11 +2323,7 @@ class Adam(PrimitiveWithInfer):
- **gradient** (Tensor) - Gradients.

Outputs:
Tuple of 3 Tensor, the updated parameters.

- **var** (Tensor) - The same shape and data type as `var`.
-        - **m** (Tensor) - The same shape and data type as `m`.
-        - **v** (Tensor) - The same shape and data type as `v`.
"""

@prim_attr_register
@@ -2340,7 +2336,7 @@ class Adam(PrimitiveWithInfer):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
-        return var_shape, m_shape, v_shape
+        return var_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
@@ -2350,7 +2346,7 @@ class Adam(PrimitiveWithInfer):
args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
"beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
-        return var_dtype, m_dtype, v_dtype
+        return var_dtype


class BinaryCrossEntropy(PrimitiveWithInfer):


Loading…
Cancel
Save