Browse Source

!15288 Fix check string attr of DynamicRnn.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/15288/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
721bcca85b
1 changed files with 3 additions and 0 deletions
  1. +3
    -0
      mindspore/ops/operations/nn_ops.py

+ 3
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -7261,8 +7261,11 @@ class DynamicRNN(PrimitiveWithInfer):
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
validator.check_value_type("cell_type", cell_type, [str], self.name)
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
validator.check_value_type("direction", direction, [str], self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
validator.check_value_type("activation", activation, [str], self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)

def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):


Loading…
Cancel
Save