From 8eba2c514750faec18c6a9caac3d8c2d03bdf740 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Sat, 9 May 2020 14:23:41 +0800 Subject: [PATCH] dims check --- mindspore/ops/operations/nn_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 7ba341fd56..888e262b26 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2157,8 +2157,9 @@ class LSTM(PrimitiveWithInfer): self.num_directions = 1 def infer_shape(self, x_shape, h_shape, c_shape, w_shape): - # (batch, seq, feature) + # (seq, batch_size, feature) validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) + validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) # h and c should be same shape validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)