Browse Source

!1029 lstm operators dims check

Merge pull request !1029 from chenweifeng/lstm_ops_check
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
bbc094f2a2
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      mindspore/ops/operations/nn_ops.py

+ 2
- 1
mindspore/ops/operations/nn_ops.py View File

@@ -2159,8 +2159,9 @@ class LSTM(PrimitiveWithInfer):
self.num_directions = 1

def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
# (batch, seq, feature)
# (seq, batch_size, feature)
validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
validator.check_integer("x[2]", x_shape[2], self.input_size, Rel.EQ, self.name)

# h and c should be same shape
validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)


Loading…
Cancel
Save