|
|
|
@@ -2159,8 +2159,9 @@ class LSTM(PrimitiveWithInfer): |
|
|
|
self.num_directions = 1 |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, h_shape, c_shape, w_shape): |
|
|
|
# (batch, seq, feature) |
|
|
|
# (seq, batch_size, feature) |
|
|
|
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) |
|
|
|
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name) |
|
|
|
|
|
|
|
# h and c should be same shape |
|
|
|
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) |
|
|
|
|