Browse Source

!13825 Fix CRNN precision by removing the manual float16 assignment

From: @c_34
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @linqingke
pull/13825/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
bca26cf68e
4 changed files with 30 additions and 23 deletions
  1. +2
    -2
      model_zoo/official/cv/crnn/eval.py
  2. +2
    -2
      model_zoo/official/cv/crnn/export.py
  3. +24
    -17
      model_zoo/official/cv/crnn/src/crnn.py
  4. +2
    -2
      model_zoo/official/cv/crnn/train.py

+ 2
- 2
model_zoo/official/cv/crnn/eval.py View File

@@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn import crnn
from src.metric import CRNNAccuracy

set_seed(1)
@@ -60,7 +60,7 @@ if __name__ == '__main__':
loss = CTCLoss(max_sequence_length=config.num_step,
max_label_length=max_text_length,
batch_size=config.batch_size)
net = CRNN(config)
net = crnn(config)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)


+ 2
- 2
model_zoo/official/cv/crnn/export.py View File

@@ -20,7 +20,7 @@ import numpy as np
import mindspore as ms
from mindspore import Tensor, context, load_checkpoint, export

from src.crnn import CRNN
from src.crnn import crnn
from src.config import config1 as config

parser = argparse.ArgumentParser(description="CRNN_export")
@@ -37,7 +37,7 @@ if args.device_target == "Ascend":

if __name__ == "__main__":
config.batch_size = 1
net = CRNN(config)
net = crnn(config)

load_checkpoint(args.ckpt_file, net=net)
net.set_train(False)


+ 24
- 17
model_zoo/official/cv/crnn/src/crnn.py View File

@@ -96,28 +96,28 @@ class CRNN(nn.Cell):
self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)

w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1 = Parameter(w1.astype(np.float16), name="w1")
self.w1 = Parameter(w1.astype(np.float32), name="w1")
w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2 = Parameter(w2.astype(np.float16), name="w2")
self.w2 = Parameter(w2.astype(np.float32), name="w2")
w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")
w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")
self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")

self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")

self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))

self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))

self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
@@ -142,7 +142,6 @@ class CRNN(nn.Cell):

def construct(self, x):
x = self.vgg(x)
x = self.cast(x, mstype.float16)

x = self.reshape(x, (self.batch_size, self.input_size, -1))
x = self.transpose(x, (2, 0, 1))
@@ -169,3 +168,11 @@ class CRNN(nn.Cell):
output += (y2_after_fc,)
output = self.concat(output)
return output


def crnn(config, full_precision=False):
    """Build a CRNN network, in mixed precision by default.

    Args:
        config: model configuration object forwarded to ``CRNN``.
        full_precision: if True, leave the network in float32;
            otherwise cast its computation to float16 (mixed precision).

    Returns:
        The constructed CRNN network cell.
    """
    network = CRNN(config)
    # Mixed precision is the default path; only skip the float16 cast
    # when the caller explicitly asks for full precision.
    return network if full_precision else network.to_float(mstype.float16)

+ 2
- 2
model_zoo/official/cv/crnn/train.py View File

@@ -26,7 +26,7 @@ from mindspore.communication.management import init, get_group_size, get_rank

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn import crnn
from src.crnn_for_train import TrainOneStepCellWithGradClip

set_seed(1)
@@ -83,7 +83,7 @@ if __name__ == '__main__':
loss = CTCLoss(max_sequence_length=config.num_step,
max_label_length=max_text_length,
batch_size=config.batch_size)
net = CRNN(config)
net = crnn(config)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)

net = WithLossCell(net, loss)


Loading…
Cancel
Save