| @@ -61,7 +61,6 @@ class StackedRNN(nn.Cell): | |||||
| self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16)) | self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16)) | ||||
| self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16)) | self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16)) | ||||
| self.cast = P.Cast() | |||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.transpose = P.Transpose() | self.transpose = P.Transpose() | ||||
| self.matmul = nn.MatMul() | self.matmul = nn.MatMul() | ||||
| @@ -118,6 +117,7 @@ class StackedRNNForGPU(nn.Cell): | |||||
| self.transpose = P.Transpose() | self.transpose = P.Transpose() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.cast(x, mstype.float32) | |||||
| x = self.transpose(x, (3, 0, 2, 1)) | x = self.transpose(x, (3, 0, 2, 1)) | ||||
| x = self.reshape(x, (-1, self.batch_size, self.input_size)) | x = self.reshape(x, (-1, self.batch_size, self.input_size)) | ||||
| output, _ = self.lstm(x, (self.h, self.c)) | output, _ = self.lstm(x, (self.h, self.c)) | ||||