Browse Source

fix identity grad

tags/v1.1.0
yanzhenxiang2020 5 years ago
parent
commit
cb391ba234
2 changed files with 18 additions and 6 deletions
  1. +10
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +8
    -6
      mindspore/ops/operations/array_ops.py

+ 10
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -456,6 +456,16 @@ def get_bprop_sparse_gather_v2(self):
return bprop return bprop




@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
"""Generate bprop for Identity"""

def bprop(x, out, dout):
return (dout,)

return bprop


@bprop_getters.register(inner.Range) @bprop_getters.register(inner.Range)
def get_bprop_range(self): def get_bprop_range(self):
"""Generate bprop for Range""" """Generate bprop for Range"""


+ 8
- 6
mindspore/ops/operations/array_ops.py View File

@@ -4104,14 +4104,14 @@ class Meshgrid(PrimitiveWithInfer):
Tensors, A Tuple of N N-D Tensor objects. Tensors, A Tuple of N N-D Tensor objects.


Examples: Examples:
>>> x = np.array([1, 2, 3, 4]).astype(np.int32)
>>> y = np.array([5, 6, 7]).astype(np.int32)
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
>>> y = Tensor(np.array([5, 6, 7]).astype(np.int32))
>>> z = Tensor(np.array([8, 9, 0, 1, 2]).astype(np.int32))
>>> inputs = (x, y, z) >>> inputs = (x, y, z)
>>> meshgrid = ops.Meshgrid(indexing="xy") >>> meshgrid = ops.Meshgrid(indexing="xy")
>>> output = meshgrid(inputs) >>> output = meshgrid(inputs)
>>> print(output) >>> print(output)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
(Tensor(shape=[3, 4, 6], dtype=Int32, value=
[[[1, 1, 1, 1, 1], [[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2], [2, 2, 2, 2, 2],
[3, 3, 3, 3, 3], [3, 3, 3, 3, 3],
@@ -4124,7 +4124,7 @@ class Meshgrid(PrimitiveWithInfer):
[2, 2, 2, 2, 2], [2, 2, 2, 2, 2],
[3, 3, 3, 3, 3], [3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]]]), [4, 4, 4, 4, 4]]]),
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
Tensor(shape=[3, 4, 6], dtype=Int32, value=
[[[5, 5, 5, 5, 5], [[[5, 5, 5, 5, 5],
[5, 5, 5, 5, 5], [5, 5, 5, 5, 5],
[5, 5, 5, 5, 5], [5, 5, 5, 5, 5],
@@ -4137,7 +4137,7 @@ class Meshgrid(PrimitiveWithInfer):
[7, 7, 7, 7, 7], [7, 7, 7, 7, 7],
[7, 7, 7, 7, 7], [7, 7, 7, 7, 7],
[7, 7, 7, 7, 7]]]), [7, 7, 7, 7, 7]]]),
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
Tensor(shape=[3, 4, 6], dtype=Int32, value=
[[[8, 9, 0, 1, 2], [[[8, 9, 0, 1, 2],
[8, 9, 0, 1, 2], [8, 9, 0, 1, 2],
[8, 9, 0, 1, 2], [8, 9, 0, 1, 2],
@@ -4611,6 +4611,8 @@ class Identity(PrimitiveWithInfer):
"""Initialize identity""" """Initialize identity"""


def __infer__(self, x): def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid('x', x['dtype'], mstype.number_type + (mstype.bool_,), self.name)
out = {'shape': x['shape'], out = {'shape': x['shape'],
'dtype': x['dtype'], 'dtype': x['dtype'],
'value': None} 'value': None}


Loading…
Cancel
Save