|
|
|
@@ -23,6 +23,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from ..._checkparam import Rel |
|
|
|
|
|
|
|
__all__ = ['LSTM', 'LSTMCell'] |
|
|
|
|
|
|
|
@@ -123,6 +124,8 @@ class LSTM(Cell): |
|
|
|
self.num_layers = num_layers |
|
|
|
self.has_bias = has_bias |
|
|
|
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) |
|
|
|
self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name) |
|
|
|
self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name) |
|
|
|
self.dropout = float(dropout) |
|
|
|
self.bidirectional = bidirectional |
|
|
|
if self.batch_first: |
|
|
|
|