|
|
|
@@ -96,28 +96,28 @@ class CRNN(nn.Cell): |
|
|
|
self.rnn2_bw = P.DynamicRNN(forget_bias=0.0) |
|
|
|
|
|
|
|
w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size)) |
|
|
|
self.w1 = Parameter(w1.astype(np.float16), name="w1") |
|
|
|
self.w1 = Parameter(w1.astype(np.float32), name="w1") |
|
|
|
w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size)) |
|
|
|
self.w2 = Parameter(w2.astype(np.float16), name="w2") |
|
|
|
self.w2 = Parameter(w2.astype(np.float32), name="w2") |
|
|
|
w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size)) |
|
|
|
self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw") |
|
|
|
self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw") |
|
|
|
w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size)) |
|
|
|
self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw") |
|
|
|
self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw") |
|
|
|
|
|
|
|
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1") |
|
|
|
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2") |
|
|
|
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw") |
|
|
|
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw") |
|
|
|
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1") |
|
|
|
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2") |
|
|
|
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw") |
|
|
|
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw") |
|
|
|
|
|
|
|
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
|
|
|
|
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16)) |
|
|
|
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) |
|
|
|
|
|
|
|
self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32) |
|
|
|
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) |
|
|
|
@@ -142,7 +142,6 @@ class CRNN(nn.Cell): |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.vgg(x) |
|
|
|
x = self.cast(x, mstype.float16) |
|
|
|
|
|
|
|
x = self.reshape(x, (self.batch_size, self.input_size, -1)) |
|
|
|
x = self.transpose(x, (2, 0, 1)) |
|
|
|
@@ -169,3 +168,11 @@ class CRNN(nn.Cell): |
|
|
|
output += (y2_after_fc,) |
|
|
|
output = self.concat(output) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
def crnn(config, full_precision=False):
    """Build a CRNN network, optionally converted for float16 execution.

    Args:
        config: Configuration object passed unchanged to the ``CRNN``
            constructor.
        full_precision (bool): If truthy, the freshly constructed network
            is returned as-is. Otherwise (the default) the result of
            ``to_float(mstype.float16)`` on that network is returned.

    Returns:
        The constructed network, or the value returned by its ``to_float``
        call when ``full_precision`` is falsy.
    """
    # Guard clause: full precision needs no conversion at all.
    if full_precision:
        return CRNN(config)
    return CRNN(config).to_float(mstype.float16)