Browse Source

dims check

tags/v0.3.0-alpha
wilfChen 5 years ago
parent
commit
8eba2c5147
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      mindspore/ops/operations/nn_ops.py

+ 2
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2157,8 +2157,9 @@ class LSTM(PrimitiveWithInfer):
self.num_directions = 1

def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
# (batch, seq, feature)
# (seq, batch_size, feature)
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)

# h and c should be same shape
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)


Loading…
Cancel
Save