From 846d8691dae8830ac68bf412b473dfbd108dca47 Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Thu, 19 Nov 2020 10:59:38 +0800 Subject: [PATCH] align the datatype of lstm input in warpctc --- model_zoo/official/cv/warpctc/src/warpctc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_zoo/official/cv/warpctc/src/warpctc.py b/model_zoo/official/cv/warpctc/src/warpctc.py index 73b5173c5d..dc8a491784 100755 --- a/model_zoo/official/cv/warpctc/src/warpctc.py +++ b/model_zoo/official/cv/warpctc/src/warpctc.py @@ -61,7 +61,6 @@ class StackedRNN(nn.Cell): self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16)) self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16)) - self.cast = P.Cast() self.reshape = P.Reshape() self.transpose = P.Transpose() self.matmul = nn.MatMul() @@ -118,6 +117,7 @@ class StackedRNNForGPU(nn.Cell): self.transpose = P.Transpose() def construct(self, x): + x = self.cast(x, mstype.float32) x = self.transpose(x, (3, 0, 2, 1)) x = self.reshape(x, (-1, self.batch_size, self.input_size)) output, _ = self.lstm(x, (self.h, self.c))