Browse Source

!8765 fix lstm input dataytype bug in warpctc

From: @gengdongjie
Reviewed-by: @oacjiewen,@liangchenghui
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
85ddfb246a
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      model_zoo/official/cv/warpctc/src/warpctc.py

+ 1
- 1
model_zoo/official/cv/warpctc/src/warpctc.py View File

@@ -61,7 +61,6 @@ class StackedRNN(nn.Cell):
self.fc_weight = Tensor(np.random.random((hidden_size, num_class)).astype(np.float16))
self.fc_bias = Tensor(np.random.random(self.num_class).astype(np.float16))

self.cast = P.Cast()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.matmul = nn.MatMul()
@@ -118,6 +117,7 @@ class StackedRNNForGPU(nn.Cell):
self.transpose = P.Transpose()

def construct(self, x):
x = self.cast(x, mstype.float32)
x = self.transpose(x, (3, 0, 2, 1))
x = self.reshape(x, (-1, self.batch_size, self.input_size))
output, _ = self.lstm(x, (self.h, self.c))


Loading…
Cancel
Save