From: @zhao_ting_v Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34pull/15759/MERGE
| @@ -37,8 +37,8 @@ The dataset is self-generated using a third-party library called [captcha](https | |||
| ## [Environment Requirements](#contents) | |||
| - Hardware(Ascend/GPU) | |||
| - Prepare hardware environment with Ascend or GPU processor. | |||
| - Hardware(Ascend/GPU/CPU) | |||
| - Prepare hardware environment with Ascend, GPU or CPU processor. | |||
| - Framework | |||
| - [MindSpore](https://gitee.com/mindspore/mindspore) | |||
| - For more information, please check the resources below: | |||
| @@ -68,13 +68,13 @@ The dataset is self-generated using a third-party library called [captcha](https | |||
| - Running on Ascend | |||
| ```bash | |||
| # distribute training example in Ascend | |||
| # distribute training example on Ascend | |||
| $ bash run_distribute_train.sh rank_table.json ../data/train | |||
| # evaluation example in Ascend | |||
| # evaluation example on Ascend | |||
| $ bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend | |||
| # standalone training example in Ascend | |||
| # standalone training example on Ascend | |||
| $ bash run_standalone_train.sh ../data/train Ascend | |||
| ``` | |||
| @@ -88,16 +88,30 @@ The dataset is self-generated using a third-party library called [captcha](https | |||
| - Running on GPU | |||
| ```bash | |||
| # distribute training example in GPU | |||
| # distribute training example on GPU | |||
| $ bash run_distribute_train_for_gpu.sh 8 ../data/train | |||
| # standalone training example in GPU | |||
| # standalone training example on GPU | |||
| $ bash run_standalone_train.sh ../data/train GPU | |||
| # evaluation example in GPU | |||
| # evaluation example on GPU | |||
| $ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU | |||
| ``` | |||
| - Running on CPU | |||
| ```bash | |||
| # training example on CPU | |||
| $ bash run_standalone_train.sh ../data/train CPU | |||
| or | |||
| python train.py --dataset_path=./data/train --platform=CPU | |||
| # evaluation example on CPU | |||
| $ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU | |||
| or | |||
| python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU | |||
| ``` | |||
| ## [Script Description](#contents) | |||
| ### [Script and Sample Code](#contents) | |||
| @@ -42,8 +42,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 | |||
| ## 环境要求 | |||
| - 硬件(Ascend/GPU) | |||
| - 使用Ascend或GPU处理器来搭建硬件环境。 | |||
| - 硬件(Ascend/GPU/CPU) | |||
| - 使用Ascend,GPU或者CPU处理器来搭建硬件环境。 | |||
| - 框架 | |||
| - [MindSpore](https://gitee.com/mindspore/mindspore) | |||
| - 如需查看详情,请参见如下资源: | |||
| @@ -92,7 +92,7 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 | |||
| - 在GPU环境运行 | |||
| ```bash | |||
| # Ascend分布式训练示例 | |||
| # GPU分布式训练示例 | |||
| $ bash run_distribute_train_for_gpu.sh 8 ../data/train | |||
| # GPU单机训练示例 | |||
| @@ -102,6 +102,20 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请 | |||
| $ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU | |||
| ``` | |||
| - 在CPU环境运行 | |||
| ```bash | |||
| # CPU训练示例 | |||
| $ bash run_standalone_train.sh ../data/train CPU | |||
| 或者 | |||
| python train.py --dataset_path=./data/train --platform=CPU | |||
| # CPU评估示例 | |||
| $ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU | |||
| 或者 | |||
| python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU | |||
| ``` | |||
| ## 脚本说明 | |||
| ### 脚本及样例代码 | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.loss import CTCLoss | |||
| from src.config import config as cf | |||
| from src.dataset import create_dataset | |||
| from src.warpctc import StackedRNN, StackedRNNForGPU | |||
| from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU | |||
| from src.metric import WarpCTCAccuracy | |||
| set_seed(1) | |||
| @@ -32,8 +32,8 @@ set_seed(1) | |||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||
| parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") | |||
| parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") | |||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||
| help='Running platform, choose from Ascend, GPU, and default is Ascend.') | |||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'], | |||
| help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.') | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) | |||
| @@ -54,8 +54,10 @@ if __name__ == '__main__': | |||
| batch_size=cf.batch_size) | |||
| if args_opt.platform == 'Ascend': | |||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| else: | |||
| elif args_opt.platform == 'GPU': | |||
| net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| else: | |||
| net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| # load checkpoint | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,18 +14,19 @@ | |||
| # ============================================================================ | |||
| """export checkpoint file into air models""" | |||
| import argparse | |||
| import math as m | |||
| import numpy as np | |||
| from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export | |||
| from src.warpctc import StackedRNN | |||
| from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU | |||
| from src.config import config | |||
| parser = argparse.ArgumentParser(description="warpctc_export") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id") | |||
| parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.") | |||
| parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.") | |||
| parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format") | |||
| parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format") | |||
| parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", | |||
| help="device target") | |||
| args = parser.parse_args() | |||
| @@ -34,15 +35,24 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| if args.device_target == "Ascend": | |||
| context.set_context(device_id=args.device_id) | |||
| if args.file_format == "AIR" and args.device_target != "Ascend": | |||
| raise ValueError("export AIR must on Ascend") | |||
| if __name__ == "__main__": | |||
| input_size = m.ceil(config.captcha_height / 64) * 64 * 3 | |||
| captcha_width = config.captcha_width | |||
| captcha_height = config.captcha_height | |||
| batch_size = config.batch_size | |||
| hidden_size = config.hidden_size | |||
| net = StackedRNN(captcha_height * 3, batch_size, hidden_size) | |||
| image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32)) | |||
| if args.device_target == 'Ascend': | |||
| net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) | |||
| image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16)) | |||
| elif args.device_target == 'GPU': | |||
| net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) | |||
| else: | |||
| net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size) | |||
| param_dict = load_checkpoint(args.ckpt_file) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16)) | |||
| export(net, image, file_name=args.file_name, file_format=args.file_format) | |||
| @@ -1,5 +1,5 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -61,7 +61,7 @@ run_ascend() { | |||
| cd .. | |||
| } | |||
| run_gpu() { | |||
| run_gpu_cpu() { | |||
| if [ -d "eval" ]; then | |||
| rm -rf ./eval | |||
| fi | |||
| @@ -70,15 +70,13 @@ run_gpu() { | |||
| cp -r ../src ./eval | |||
| cd ./eval || exit | |||
| env >env.log | |||
| python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 & | |||
| python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 & | |||
| cd .. | |||
| } | |||
| if [ "Ascend" == $PLATFORM ]; then | |||
| run_ascend $PATH1 $PATH2 | |||
| elif [ "GPU" == $PLATFORM ]; then | |||
| run_gpu $PATH1 $PATH2 | |||
| else | |||
| echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." | |||
| run_gpu_cpu $PATH1 $PATH2 $PLATFORM | |||
| fi | |||
| @@ -1,5 +1,5 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -48,9 +48,9 @@ run_ascend() { | |||
| cd .. | |||
| } | |||
| run_gpu() { | |||
| run_gpu_cpu() { | |||
| env >env.log | |||
| python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 & | |||
| python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 & | |||
| cd .. | |||
| } | |||
| @@ -64,8 +64,6 @@ cd ./train || exit | |||
| if [ "Ascend" == $PLATFORM ]; then | |||
| run_ascend $PATH1 | |||
| elif [ "GPU" == $PLATFORM ]; then | |||
| run_gpu $PATH1 | |||
| else | |||
| echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU." | |||
| run_gpu_cpu $PATH1 $PLATFORM | |||
| fi | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -131,3 +131,41 @@ class StackedRNNForGPU(nn.Cell): | |||
| res += (self.expand_dims(self.fc(output[i]), 0),) | |||
| res = self.concat(res) | |||
| return res | |||
| class StackedRNNForCPU(nn.Cell): | |||
| """ | |||
| Define a stacked RNN network which contains two LSTM layers and one full-connect layer on CPU. | |||
| Args: | |||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||
| captcha images. | |||
| batch_size(int): batch size of input data, default is 64 | |||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||
| num_classes(int): the number of classes. | |||
| """ | |||
| def __init__(self, input_size, batch_size=64, hidden_size=512, num_classes=11): | |||
| super(StackedRNNForCPU, self).__init__() | |||
| self.batch_size = batch_size | |||
| self.input_size = input_size | |||
| k = (1 / hidden_size) ** 0.5 | |||
| self.w1 = Parameter( | |||
| np.random.uniform(-k, k, (4 * hidden_size * (input_size + hidden_size + 1), 1, 1)).astype(np.float32)) | |||
| self.w2 = Parameter( | |||
| np.random.uniform(-k, k, (4 * hidden_size * (2 * hidden_size + 1), 1, 1)).astype(np.float32)) | |||
| self.h = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32)) | |||
| self.c = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32)) | |||
| self.lstm_1 = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) | |||
| self.lstm_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size) | |||
| self.fc = nn.Dense(in_channels=hidden_size, out_channels=num_classes) | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| x = self.transpose(x, (3, 0, 2, 1)) | |||
| x = F.reshape(x, (-1, self.batch_size, self.input_size)) | |||
| y1, _, _, _, _ = self.lstm_1(x, self.h, self.c, self.w1) | |||
| y2, _, _, _, _ = self.lstm_2(y1, self.h, self.c, self.w2) | |||
| output = self.fc(y2) # y2 shape: [time_step, bs, hidden_size] output shape: [time_step, bs, num_classes]. | |||
| return output | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -28,7 +28,7 @@ from mindspore.communication.management import init, get_group_size, get_rank | |||
| from src.loss import CTCLoss | |||
| from src.config import config as cf | |||
| from src.dataset import create_dataset | |||
| from src.warpctc import StackedRNN, StackedRNNForGPU | |||
| from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU | |||
| from src.warpctc_for_train import TrainOneStepCellWithGradClip | |||
| from src.lr_schedule import get_lr | |||
| @@ -37,8 +37,8 @@ set_seed(1) | |||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||
| parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | |||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], | |||
| help='Running platform, choose from Ascend, GPU, and default is Ascend.') | |||
| parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'], | |||
| help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.') | |||
| parser.set_defaults(run_distribute=False) | |||
| args_opt = parser.parse_args() | |||
| @@ -80,8 +80,10 @@ if __name__ == '__main__': | |||
| batch_size=cf.batch_size) | |||
| if args_opt.platform == 'Ascend': | |||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| else: | |||
| elif args_opt.platform == 'GPU': | |||
| net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| else: | |||
| net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||
| net = WithLossCell(net, loss) | |||