|
|
|
@@ -349,8 +349,8 @@ class _DynamicLSTMAscend(Cell): |
|
|
|
outputs, h, c, _, _, _, _, _ = self.lstm(self.cast(x, self.dtype), \ |
|
|
|
self.cast(self.transpose(weight, (1, 0)), self.dtype), \ |
|
|
|
self.cast(bias, self.dtype), None, \ |
|
|
|
self.cast(h_0[0].view(1, *h_0[0].shape), self.dtype), \ |
|
|
|
self.cast(h_0[1].view(1, *h_0[1].shape), self.dtype)) |
|
|
|
self.cast(P.ExpandDims()(h_0[0], 0), self.dtype), \ |
|
|
|
self.cast(P.ExpandDims()(h_0[1], 0), self.dtype)) |
|
|
|
if seq_length is not None: |
|
|
|
h = get_hidden(h, seq_length) |
|
|
|
c = get_hidden(c, seq_length) |
|
|
|
|