From 3f3f59771dd36b86a71ce2d8cdcd069d5f773edc Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Tue, 30 Jun 2020 10:07:52 +0800 Subject: [PATCH] fix rnnt label shape and grad --- mindspore/ops/_grad/grad_nn_ops.py | 4 +--- mindspore/ops/operations/nn_ops.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 1254f9e7a2..c692d94559 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -521,11 +521,9 @@ def get_bprop_l2_loss(self): @bprop_getters.register(P.RNNTLoss) def get_bprop_rnnt_loss(self): """Grad definition for `RNNTLoss` operation.""" - expand = P.ExpandDims() def bprop(acts, labels, act_lens, label_lens, out, dout): - grad_loss = out[1] - grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1) + grad = out[1] return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens) return bprop diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index ce8536c001..9ff20fecc0 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1706,7 +1706,7 @@ class RNNTLoss(PrimitiveWithInfer): Inputs: - **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. - - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`. + - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`. - **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. - **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. @@ -1735,6 +1735,7 @@ class RNNTLoss(PrimitiveWithInfer): validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) costs_shape = (acts_shape[0],)