|
|
|
@@ -1706,7 +1706,7 @@ class RNNTLoss(PrimitiveWithInfer): |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. |
|
|
|
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`. |
|
|
|
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, U-1)`. |
|
|
|
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. |
|
|
|
- **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. |
|
|
|
|
|
|
|
@@ -1735,6 +1735,7 @@ class RNNTLoss(PrimitiveWithInfer): |
|
|
|
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) |
|
|
|
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) |
|
|
|
validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2]-1, Rel.EQ, self.name) |
|
|
|
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) |
|
|
|
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) |
|
|
|
costs_shape = (acts_shape[0],) |
|
|
|
|