diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 049a0aa34d..c1d52a692b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2159,8 +2159,9 @@ class LSTM(PrimitiveWithInfer): self.num_directions = 1 def infer_shape(self, x_shape, h_shape, c_shape, w_shape): - # (batch, seq, feature) + # (seq, batch_size, feature) validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) + validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) # h and c should be same shape validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)