diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 7ba341fd56..888e262b26 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2157,8 +2157,9 @@ class LSTM(PrimitiveWithInfer): self.num_directions = 1 def infer_shape(self, x_shape, h_shape, c_shape, w_shape): - # (batch, seq, feature) + # (seq, batch_size, feature) validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) + validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) # h and c should be same shape validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)