Browse Source

!8206 Fix validation about input shapes of BNTrainingUpdate.

Merge pull request !8206 from liuxiao93/fix-BNTrainingUpdate-input-check
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
18d85a8543
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      mindspore/ops/operations/nn_ops.py

+ 6
- 6
mindspore/ops/operations/nn_ops.py View File

@@ -836,12 +836,12 @@ class BNTrainingUpdate(PrimitiveWithInfer):
validator.check_equal_int(len(b), 1, "b rank", self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check_equal_int(len(variance), 1, "variance rank", self.name)
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("sum shape", sum[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("scale shape", scale[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
return (x, variance, variance, variance, variance)

def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
@@ -5436,7 +5436,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
`num_labels` indicates the number of actual labels. Blank labels are reserved.
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
The type must be int32. Each value in the tensor must not greater than `max_time`.
The type must be int32. Each value in the tensor must be equal to or less than `max_time`.

Outputs:
- **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).


Loading…
Cancel
Save