|
|
|
@@ -149,7 +149,7 @@ class LSTM(Cell): |
|
|
|
if self.batch_first: |
|
|
|
x = self.transpose1(x, (1, 0, 2)) |
|
|
|
h0, c0 = hx |
|
|
|
output, hn, cn, _ = self.lstm(x, h0, c0, self.weight) |
|
|
|
output, hn, cn, _, _ = self.lstm(x, h0, c0, self.weight) |
|
|
|
if self.batch_first: |
|
|
|
output = self.transpose2(output, (1, 0, 2)) |
|
|
|
return (output, (hn, cn)) |