Browse Source

add ReverseSequenceGrad

tags/v0.6.0-beta
xutianchun 5 years ago
parent
commit
9baf1ba99c
2 changed files with 16 additions and 0 deletions
  1. +11
    -0
      mindspore/ops/_grad/grad_array_ops.py
  2. +5
    -0
      tests/ut/python/ops/test_ops.py

+ 11
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -580,3 +580,14 @@ def get_bprop_batch_to_space_nd(self):
dx = batch_to_space_nd_grad(dout)
return (dx,)
return bprop


@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
"""Generate bprop for ReverseSequence"""
reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

def bprop(x, seq_lengths, out, dout):
dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths)
return bprop

+ 5
- 0
tests/ut/python/ops/test_ops.py View File

@@ -1378,6 +1378,11 @@ test_case_array_ops = [
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)),
Tensor(np.array([0, 1, 1]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
('ReverseSequence', {
'block': P.ReverseSequence(1, 0),
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
Tensor(np.array([1, 2, 3]).astype(np.int32))],
'desc_bprop': [[3, 3]]}),
]

test_case_other_ops = [


Loading…
Cancel
Save