Browse Source

!252 fix rnntloss label shape and grads

Merge pull request !252 from yanzhenxiang2020/fix_rnntloss_incubator
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
176cdd9533
2 changed files with 3 additions and 4 deletions
  1. +1
    -3
      mindspore/ops/_grad/grad_nn_ops.py
  2. +2
    -1
      mindspore/ops/operations/nn_ops.py

+ 1
- 3
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -521,11 +521,9 @@ def get_bprop_l2_loss(self):
@bprop_getters.register(P.RNNTLoss)
def get_bprop_rnnt_loss(self):
"""Grad definition for `RNNTLoss` operation."""
expand = P.ExpandDims()

def bprop(acts, labels, act_lens, label_lens, out, dout):
grad_loss = out[1]
grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1)
grad = out[1]
return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
return bprop



+ 2
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -1706,7 +1706,7 @@ class RNNTLoss(PrimitiveWithInfer):

Inputs:
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`.
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`.
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`.
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
- **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.

@@ -1735,6 +1735,7 @@ class RNNTLoss(PrimitiveWithInfer):
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],)


Loading…
Cancel
Save