Browse Source

!4927 fix bug for identity

Merge pull request !4927 from flywind/pynative_identity
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
aedd6de6d5
4 changed files with 30 additions and 4 deletions
  1. +9
    -2
      mindspore/nn/cell.py
  2. +1
    -1
      mindspore/ops/functional.py
  3. +2
    -1
      mindspore/ops/operations/__init__.py
  4. +18
    -0
      mindspore/ops/operations/other_ops.py

+ 9
- 2
mindspore/nn/cell.py View File

@@ -853,8 +853,15 @@ class Cell:
self.add_flags_recursive(**flags)
return self

def set_grad(self, mode=True):
self.requires_grad = mode
def set_grad(self, requires_grad=True):
"""
Sets the cell flag for gradient.

Args:
requires_grad (bool): Specifies if the net need to grad, if it is
True, cell will construct backward network in pynative mode. Default: True.
"""
self.requires_grad = requires_grad
return self

def set_train(self, mode=True):


+ 1
- 1
mindspore/ops/functional.py View File

@@ -82,6 +82,7 @@ pack = P.Pack()
partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()

tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
@@ -135,7 +136,6 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
identity = Primitive('identity')
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()


+ 2
- 1
mindspore/ops/operations/__init__.py View File

@@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
@@ -268,6 +268,7 @@ __all__ = [
'MakeRefKey',
'Partial',
'Depend',
'identity',
'AvgPool',
# Back Primitive
'Equal',


+ 18
- 0
mindspore/ops/operations/other_ops.py View File

@@ -560,3 +560,21 @@ class Pull(PrimitiveWithInfer):

def infer_dtype(self, key_dtype, weight_dtype):
return mstype.float32

class identity(Primitive):
"""
Make a identify primitive, used for pynative mode.

Inputs:
- **x** (Any) - identity input value.

Outputs:
The same as input.
"""

@prim_attr_register
def __init__(self):
pass

def __call__(self, x):
return x

Loading…
Cancel
Save