|
|
|
@@ -7261,8 +7261,11 @@ class DynamicRNN(PrimitiveWithInfer): |
|
|
|
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name) |
|
|
|
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name) |
|
|
|
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name) |
|
|
|
validator.check_value_type("cell_type", cell_type, [str], self.name) |
|
|
|
self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name) |
|
|
|
validator.check_value_type("direction", direction, [str], self.name) |
|
|
|
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) |
|
|
|
validator.check_value_type("activation", activation, [str], self.name) |
|
|
|
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape): |
|
|
|
|