diff --git a/model_zoo/official/cv/warpctc/README.md b/model_zoo/official/cv/warpctc/README.md index 247bbdb9a1..5e8d08a0d4 100644 --- a/model_zoo/official/cv/warpctc/README.md +++ b/model_zoo/official/cv/warpctc/README.md @@ -37,8 +37,8 @@ The dataset is self-generated using a third-party library called [captcha](https ## [Environment Requirements](#contents) -- Hardware(Ascend/GPU) - - Prepare hardware environment with Ascend or GPU processor. +- Hardware(Ascend/GPU/CPU) + - Prepare hardware environment with Ascend, GPU or CPU processor. - Framework - [MindSpore](https://gitee.com/mindspore/mindspore) - For more information, please check the resources below: @@ -68,13 +68,13 @@ The dataset is self-generated using a third-party library called [captcha](https - Running on Ascend ```bash - # distribute training example in Ascend + # distribute training example on Ascend $ bash run_distribute_train.sh rank_table.json ../data/train - # evaluation example in Ascend + # evaluation example on Ascend $ bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend - # standalone training example in Ascend + # standalone training example on Ascend $ bash run_standalone_train.sh ../data/train Ascend ``` @@ -88,16 +88,30 @@ The dataset is self-generated using a third-party library called [captcha](https - Running on GPU ```bash - # distribute training example in GPU + # distribute training example on GPU $ bash run_distribute_train_for_gpu.sh 8 ../data/train - # standalone training example in GPU + # standalone training example on GPU $ bash run_standalone_train.sh ../data/train GPU - # evaluation example in GPU + # evaluation example on GPU $ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU ``` + - Running on CPU + + ```bash + # training example on CPU + $ bash run_standalone_train.sh ../data/train CPU + or + python train.py --dataset_path=./data/train --platform=CPU + + # evaluation example on CPU + $ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU + or + python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU + ``` + ## [Script Description](#contents) ### [Script and Sample Code](#contents) diff --git a/model_zoo/official/cv/warpctc/README_CN.md b/model_zoo/official/cv/warpctc/README_CN.md index ebc24402f6..0d137b2e04 100644 --- a/model_zoo/official/cv/warpctc/README_CN.md +++ b/model_zoo/official/cv/warpctc/README_CN.md @@ -42,8 +42,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 ## 环境要求 -- 硬件(Ascend/GPU) - - 使用Ascend或GPU处理器来搭建硬件环境。 +- 硬件(Ascend/GPU/CPU) + - 使用Ascend,GPU或者CPU处理器来搭建硬件环境。 - 框架 - [MindSpore](https://gitee.com/mindspore/mindspore) - 如需查看详情,请参见如下资源: @@ -92,7 +92,7 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 - 在GPU环境运行 ```bash - # Ascend分布式训练示例 + # GPU分布式训练示例 $ bash run_distribute_train_for_gpu.sh 8 ../data/train # GPU单机训练示例 @@ -102,6 +102,20 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 $ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU ``` + - 在CPU环境运行 + + ```bash + # CPU训练示例 + $ bash run_standalone_train.sh ../data/train CPU + 或者 + python train.py --dataset_path=./data/train --platform=CPU + + # CPU评估示例 + $ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU + 或者 + python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU + ``` + ## 脚本说明 ### 脚本及样例代码 diff --git a/model_zoo/official/cv/warpctc/eval.py b/model_zoo/official/cv/warpctc/eval.py index 54adb98748..92d7f517f0 100755 --- a/model_zoo/official/cv/warpctc/eval.py +++ b/model_zoo/official/cv/warpctc/eval.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.loss import CTCLoss from src.config import config as cf from src.dataset import create_dataset -from src.warpctc import StackedRNN, StackedRNNForGPU +from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU from src.metric import WarpCTCAccuracy set_seed(1) @@ -32,8 +32,8 @@ set_seed(1) parser = argparse.ArgumentParser(description="Warpctc training") parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") -parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='Running platform, choose from Ascend, GPU, and default is Ascend.') +parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'], + help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.') args_opt = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) @@ -54,8 +54,10 @@ if __name__ == '__main__': batch_size=cf.batch_size) if args_opt.platform == 'Ascend': net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) - else: + elif args_opt.platform == 'GPU': net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + else: + net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) # load checkpoint param_dict = load_checkpoint(args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/warpctc/export.py b/model_zoo/official/cv/warpctc/export.py index 6f053da3b4..0754ed8105 100644 --- a/model_zoo/official/cv/warpctc/export.py +++ b/model_zoo/official/cv/warpctc/export.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,18 +14,19 @@ # ============================================================================ """export checkpoint file into air models""" import argparse +import math as m import numpy as np from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export -from src.warpctc import StackedRNN +from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU from src.config import config parser = argparse.ArgumentParser(description="warpctc_export") parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.") parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.") -parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format") +parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format") parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", help="device target") args = parser.parse_args() @@ -34,15 +35,24 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) if args.device_target == "Ascend": context.set_context(device_id=args.device_id) +if args.file_format == "AIR" and args.device_target != "Ascend": + raise ValueError("export AIR must on Ascend") + if __name__ == "__main__": + input_size = m.ceil(config.captcha_height / 64) * 64 * 3 captcha_width = config.captcha_width captcha_height = config.captcha_height batch_size = config.batch_size hidden_size = config.hidden_size - net = StackedRNN(captcha_height * 3, batch_size, hidden_size) + image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32)) + if args.device_target == 'Ascend': + net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) + image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16)) + elif args.device_target == 'GPU': + net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) + else: + net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) param_dict = load_checkpoint(args.ckpt_file) load_param_into_net(net, param_dict) net.set_train(False) - - image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16)) export(net, image, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/warpctc/scripts/run_eval.sh b/model_zoo/official/cv/warpctc/scripts/run_eval.sh index cc0e3ce252..ba47cf2c3f 100755 --- a/model_zoo/official/cv/warpctc/scripts/run_eval.sh +++ b/model_zoo/official/cv/warpctc/scripts/run_eval.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ run_ascend() { cd .. } -run_gpu() { +run_gpu_cpu() { if [ -d "eval" ]; then rm -rf ./eval fi @@ -70,15 +70,13 @@ run_gpu() { cp -r ../src ./eval cd ./eval || exit env >env.log - python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 & + python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 & cd .. } if [ "Ascend" == $PLATFORM ]; then run_ascend $PATH1 $PATH2 -elif [ "GPU" == $PLATFORM ]; then - run_gpu $PATH1 $PATH2 else - echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." + run_gpu_cpu $PATH1 $PATH2 $PLATFORM fi diff --git a/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh b/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh index 863683dd00..6ae547d10d 100755 --- a/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/warpctc/scripts/run_standalone_train.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,9 +48,9 @@ run_ascend() { cd .. } -run_gpu() { +run_gpu_cpu() { env >env.log - python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 & + python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 & cd .. } @@ -64,8 +64,6 @@ cd ./train || exit if [ "Ascend" == $PLATFORM ]; then run_ascend $PATH1 -elif [ "GPU" == $PLATFORM ]; then - run_gpu $PATH1 else - echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." + run_gpu_cpu $PATH1 $PLATFORM fi \ No newline at end of file diff --git a/model_zoo/official/cv/warpctc/src/warpctc.py b/model_zoo/official/cv/warpctc/src/warpctc.py index e80bef8365..fb09898b78 100755 --- a/model_zoo/official/cv/warpctc/src/warpctc.py +++ b/model_zoo/official/cv/warpctc/src/warpctc.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -131,3 +131,41 @@ class StackedRNNForGPU(nn.Cell): res += (self.expand_dims(self.fc(output[i]), 0),) res = self.concat(res) return res + +class StackedRNNForCPU(nn.Cell): + """ + Define a stacked RNN network which contains two LSTM layers and one full-connect layer on CPU. + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + num_classes(int): the number of classes. + """ + + def __init__(self, input_size, batch_size=64, hidden_size=512, num_classes=11): + super(StackedRNNForCPU, self).__init__() + self.batch_size = batch_size + self.input_size = input_size + k = (1 / hidden_size) ** 0.5 + self.w1 = Parameter( + np.random.uniform(-k, k, (4 * hidden_size * (input_size + hidden_size + 1), 1, 1)).astype(np.float32)) + self.w2 = Parameter( + np.random.uniform(-k, k, (4 * hidden_size * (2 * hidden_size + 1), 1, 1)).astype(np.float32)) + self.h = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32)) + self.c = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32)) + + self.lstm_1 = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) + self.lstm_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size) + + self.fc = nn.Dense(in_channels=hidden_size, out_channels=num_classes) + self.transpose = P.Transpose() + + def construct(self, x): + x = self.transpose(x, (3, 0, 2, 1)) + x = F.reshape(x, (-1, self.batch_size, self.input_size)) + y1, _, _, _, _ = self.lstm_1(x, self.h, self.c, self.w1) + y2, _, _, _, _ = self.lstm_2(y1, self.h, self.c, self.w2) + output = self.fc(y2) # y2 shape: [time_step, bs, hidden_size] output shape: [time_step, bs, num_classes]. + return output diff --git a/model_zoo/official/cv/warpctc/train.py b/model_zoo/official/cv/warpctc/train.py index 096d853515..c754878463 100755 --- a/model_zoo/official/cv/warpctc/train.py +++ b/model_zoo/official/cv/warpctc/train.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ from mindspore.communication.management import init, get_group_size, get_rank from src.loss import CTCLoss from src.config import config as cf from src.dataset import create_dataset -from src.warpctc import StackedRNN, StackedRNNForGPU +from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU from src.warpctc_for_train import TrainOneStepCellWithGradClip from src.lr_schedule import get_lr @@ -37,8 +37,8 @@ set_seed(1) parser = argparse.ArgumentParser(description="Warpctc training") parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') -parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], - help='Running platform, choose from Ascend, GPU, and default is Ascend.') +parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'], + help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.') parser.set_defaults(run_distribute=False) args_opt = parser.parse_args() @@ -80,8 +80,10 @@ if __name__ == '__main__': batch_size=cf.batch_size) if args_opt.platform == 'Ascend': net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) - else: + elif args_opt.platform == 'GPU': net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + else: + net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) net = WithLossCell(net, loss)