Merge pull request !2732 from gengdongjie/mastertags/v0.6.0-beta
| @@ -716,7 +716,7 @@ def get_bprop_basic_lstm_cell(self): | |||
| def bprop(x, h, c, w, b, out, dout): | |||
| _, _, it, jt, ft, ot, tanhct = out | |||
| dct, dht, _, _, _, _, _ = dout | |||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) | |||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) | |||
| dxt, dht = basic_lstm_cell_input_grad(dgate, w) | |||
| dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | |||
| return dxt, dht, dct_1, dw, db | |||
| @@ -29,8 +29,8 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \ | |||
| .input(1, "dht", False, "required", "all") \ | |||
| .input(2, "dct", False, "required", "all") \ | |||
| .input(3, "it", False, "required", "all") \ | |||
| .input(4, "ft", False, "required", "all") \ | |||
| .input(5, "jt", False, "required", "all") \ | |||
| .input(4, "jt", False, "required", "all") \ | |||
| .input(5, "ft", False, "required", "all") \ | |||
| .input(6, "ot", False, "required", "all") \ | |||
| .input(7, "tanhct", False, "required", "all") \ | |||
| .output(0, "dgate", False, "required", "all") \ | |||
| @@ -0,0 +1,137 @@ | |||
| # Warpctc Example | |||
| ## Description | |||
| These is an example of training Warpctc with self-generated captcha image dataset in MindSpore. | |||
| ## Requirements | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| - Generate captcha images. | |||
| > The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately. | |||
| > ``` | |||
| > $ cd scripts | |||
| > $ sh run_process_data.sh | |||
| > | |||
| > # after execution, you will find the dataset like the follows: | |||
| > . | |||
| > └─warpctc | |||
| > └─data | |||
| > ├─ train # train dataset | |||
| > └─ test # evaluate dataset | |||
| > ... | |||
| ## Structure | |||
| ```shell | |||
| . | |||
| └──warpct | |||
| ├── README.md | |||
| ├── script | |||
| ├── run_distribute_train.sh # launch distributed training(8 pcs) | |||
| ├── run_eval.sh # launch evaluation | |||
| ├── run_process_data.sh # launch dataset generation | |||
| └── run_standalone_train.sh # launch standalone training(1 pcs) | |||
| ├── src | |||
| ├── config.py # parameter configuration | |||
| ├── dataset.py # data preprocessing | |||
| ├── loss.py # ctcloss definition | |||
| ├── lr_generator.py # generate learning rate for each step | |||
| ├── metric.py # accuracy metric for warpctc network | |||
| ├── warpctc.py # warpctc network definition | |||
| └── warpctc_for_train.py # warp network with grad, loss and gradient clip | |||
| ├── eval.py # eval net | |||
| ├── process_data.py # dataset generation script | |||
| └── train.py # train net | |||
| ``` | |||
| ## Parameter configuration | |||
| Parameters for both training and evaluation can be set in config.py. | |||
| ``` | |||
| "max_captcha_digits": 4, # max number of digits in each | |||
| "captcha_width": 160, # width of captcha images | |||
| "captcha_height": 64, # height of capthca images | |||
| "batch_size": 64, # batch size of input tensor | |||
| "epoch_size": 30, # only valid for taining, which is always 1 for inference | |||
| "hidden_size": 512, # hidden size in LSTM layers | |||
| "learning_rate": 0.01, # initial learning rate | |||
| "momentum": 0.9 # momentum of SGD optimizer | |||
| "save_checkpoint": True, # whether save checkpoint or not | |||
| "save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step | |||
| "keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint | |||
| "save_checkpoint_path": "./", # path to save checkpoint | |||
| ``` | |||
| ## Running the example | |||
| ### Train | |||
| #### Usage | |||
| ``` | |||
| # distributed training | |||
| Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] | |||
| # standalone training | |||
| Usage: sh run_standalone_train.sh [DATASET_PATH] | |||
| ``` | |||
| #### Launch | |||
| ``` | |||
| # distribute training example | |||
| sh run_distribute_train.sh rank_table.json ../data/train | |||
| # standalone training example | |||
| sh run_standalone_train.sh ../data/train | |||
| ``` | |||
| > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | |||
| #### Result | |||
| Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. | |||
| ``` | |||
| # distribute training result(8 pcs) | |||
| Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944] | |||
| Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951] | |||
| Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385] | |||
| Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193] | |||
| Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809] | |||
| ... | |||
| ``` | |||
| ### Evaluation | |||
| #### Usage | |||
| ``` | |||
| # evaluation | |||
| Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ``` | |||
| #### Launch | |||
| ``` | |||
| # evaluation example | |||
| sh run_eval.sh ../data/test warpctc-30-98.ckpt | |||
| ``` | |||
| > checkpoint can be produced in training process. | |||
| #### Result | |||
| Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. | |||
| ``` | |||
| result: {'WarpCTCAccuracy': 0.9901472929936306} | |||
| ``` | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Warpctc evaluation""" | |||
| import os | |||
| import math as m | |||
| import random | |||
| import argparse | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore import dataset as de | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.loss import CTCLoss | |||
| from src.config import config as cf | |||
| from src.dataset import create_dataset | |||
| from src.warpctc import StackedRNN | |||
| from src.metric import WarpCTCAccuracy | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| de.config.set_seed(1) | |||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||
| parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") | |||
| parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") | |||
| args_opt = parser.parse_args() | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| save_graphs=False, | |||
| device_id=device_id) | |||
| if __name__ == '__main__': | |||
| max_captcha_digits = cf.max_captcha_digits | |||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | |||
| # create dataset | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| # define loss | |||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||
| # define net | |||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| # load checkpoint | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| # define model | |||
| model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) | |||
| # start evaluation | |||
| res = model.eval(dataset) | |||
| print("result:", res, flush=True) | |||
| @@ -0,0 +1,71 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate train and test dataset""" | |||
| import os | |||
| import math as m | |||
| import random | |||
| from multiprocessing import Process | |||
| from captcha.image import ImageCaptcha | |||
| def _generate_captcha_per_process(path, total, start, end, img_width, img_height, max_digits): | |||
| captcha = ImageCaptcha(width=img_width, height=img_height) | |||
| filename_head = '{:0>' + str(len(str(total))) + '}-' | |||
| for i in range(start, end): | |||
| digits = '' | |||
| digits_length = random.randint(1, max_digits) | |||
| for _ in range(0, digits_length): | |||
| integer = random.randint(0, 9) | |||
| digits += str(integer) | |||
| captcha.write(digits, os.path.join(path, filename_head.format(i) + digits + '.png')) | |||
| def generate_captcha(name, img_num, img_width, img_height, max_digits, process_num=16): | |||
| """ | |||
| generate captcha images | |||
| Args: | |||
| name(str): name of folder, under which captcha images are saved in | |||
| img_num(int): number of generated captcha images | |||
| img_width(int): width of generated captcha images | |||
| img_height(int): height of generated captcha images | |||
| max_digits(int): max number of digits in each captcha images. For each captcha images, number of digits is in | |||
| range [1,max_digits] | |||
| process_num(int): number of process to generate captcha images, default is 16 | |||
| """ | |||
| cur_script_path = os.path.dirname(os.path.realpath(__file__)) | |||
| path = os.path.join(cur_script_path, "data", name) | |||
| print("Generating dataset [{}] under {}...".format(name, path)) | |||
| if os.path.exists(path): | |||
| os.system("rm -rf {}".format(path)) | |||
| os.system("mkdir -p {}".format(path)) | |||
| img_num_per_thread = m.ceil(img_num / process_num) | |||
| processes = [] | |||
| for i in range(process_num): | |||
| start = i * img_num_per_thread | |||
| end = start + img_num_per_thread if i != (process_num - 1) else img_num | |||
| p = Process(target=_generate_captcha_per_process, | |||
| args=(path, img_num, start, end, img_width, img_height, max_digits)) | |||
| p.start() | |||
| processes.append(p) | |||
| for p in processes: | |||
| p.join() | |||
| print("Generating dataset [{}] finished, total number is {}!".format(name, img_num)) | |||
| if __name__ == '__main__': | |||
| generate_captcha("test", img_num=10000, img_width=160, img_height=64, max_digits=4) | |||
| generate_captcha("train", img_num=50000, img_width=160, img_height=64, max_digits=4) | |||
| @@ -0,0 +1,62 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 2 ]; then | |||
| echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]" | |||
| exit 1 | |||
| fi | |||
| get_real_path() { | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| PATH1=$(get_real_path $1) | |||
| PATH2=$(get_real_path $2) | |||
| if [ ! -f $PATH1 ]; then | |||
| echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" | |||
| exit 1 | |||
| fi | |||
| if [ ! -d $PATH2 ]; then | |||
| echo "error: DATASET_PATH=$PATH2 is not a directory" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=8 | |||
| export RANK_SIZE=8 | |||
| export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 | |||
| export RANK_TABLE_FILE=$PATH1 | |||
| for ((i = 0; i < ${DEVICE_NUM}; i++)); do | |||
| export DEVICE_ID=$i | |||
| export RANK_ID=$i | |||
| rm -rf ./train_parallel$i | |||
| mkdir ./train_parallel$i | |||
| cp ../*.py ./train_parallel$i | |||
| cp *.sh ./train_parallel$i | |||
| cp -r ../src ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| env >env.log | |||
| python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & | |||
| cd .. | |||
| done | |||
| @@ -0,0 +1,60 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 2 ]; then | |||
| echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" | |||
| exit 1 | |||
| fi | |||
| get_real_path() { | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| PATH1=$(get_real_path $1) | |||
| PATH2=$(get_real_path $2) | |||
| if [ ! -d $PATH1 ]; then | |||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | |||
| exit 1 | |||
| fi | |||
| if [ ! -f $PATH2 ]; then | |||
| echo "error: CHECKPOINT_PATH=$PATH2 is not a file" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=1 | |||
| export DEVICE_ID=0 | |||
| export RANK_SIZE=$DEVICE_NUM | |||
| export RANK_ID=0 | |||
| if [ -d "eval" ]; then | |||
| rm -rf ./eval | |||
| fi | |||
| mkdir ./eval | |||
| cp ../*.py ./eval | |||
| cp *.sh ./eval | |||
| cp -r ../src ./eval | |||
| cd ./eval || exit | |||
| env >env.log | |||
| echo "start evaluation for device $DEVICE_ID" | |||
| python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & | |||
| cd .. | |||
| @@ -0,0 +1,20 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| CUR_PATH=$(dirname $PWD/$0) | |||
| cd $CUR_PATH/../ && | |||
| python process_data.py && | |||
| cd - || exit | |||
| @@ -0,0 +1,54 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 1 ]; then | |||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" | |||
| exit 1 | |||
| fi | |||
| get_real_path() { | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| PATH1=$(get_real_path $1) | |||
| if [ ! -d $PATH1 ]; then | |||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=1 | |||
| export DEVICE_ID=0 | |||
| export RANK_ID=0 | |||
| export RANK_SIZE=1 | |||
| if [ -d "train" ]; then | |||
| rm -rf ./train | |||
| fi | |||
| mkdir ./train | |||
| cp ../*.py ./train | |||
| cp *.sh ./train | |||
| cp -r ../src ./train | |||
| cd ./train || exit | |||
| echo "start training for device $DEVICE_ID" | |||
| env >env.log | |||
| python train.py --dataset=$PATH1 &>log & | |||
| cd .. | |||
| @@ -0,0 +1,31 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Network parameters.""" | |||
| from easydict import EasyDict | |||
| config = EasyDict({ | |||
| "max_captcha_digits": 4, | |||
| "captcha_width": 160, | |||
| "captcha_height": 64, | |||
| "batch_size": 64, | |||
| "epoch_size": 30, | |||
| "hidden_size": 512, | |||
| "learning_rate": 0.01, | |||
| "momentum": 0.9, | |||
| "save_checkpoint": True, | |||
| "save_checkpoint_steps": 98, | |||
| "keep_checkpoint_max": 30, | |||
| "save_checkpoint_path": "./", | |||
| }) | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Dataset preprocessing.""" | |||
| import os | |||
| import math as m | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.dataset.transforms.c_transforms as c | |||
| import mindspore.dataset.transforms.vision.c_transforms as vc | |||
| from PIL import Image | |||
| from src.config import config as cf | |||
| class _CaptchaDataset(): | |||
| """ | |||
| create train or evaluation dataset for warpctc | |||
| Args: | |||
| img_root_dir(str): root path of images | |||
| max_captcha_digits(int): max number of digits in images. | |||
| blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label | |||
| length is less than max_captcha_digits, the remaining labels are padding with blank. | |||
| """ | |||
| def __init__(self, img_root_dir, max_captcha_digits, blank=10): | |||
| if not os.path.exists(img_root_dir): | |||
| raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) | |||
| self.img_root_dir = img_root_dir | |||
| self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] | |||
| self.max_captcha_digits = max_captcha_digits | |||
| self.blank = blank | |||
| def __len__(self): | |||
| return len(self.img_names) | |||
| def __getitem__(self, item): | |||
| img_name = self.img_names[item] | |||
| im = Image.open(os.path.join(self.img_root_dir, img_name)) | |||
| r, g, b = im.split() | |||
| im = Image.merge("RGB", (b, g, r)) | |||
| image = np.array(im) | |||
| label_str = os.path.splitext(img_name)[0] | |||
| label_str = label_str[label_str.find('-') + 1:] | |||
| label = [int(i) for i in label_str] | |||
| label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) | |||
| label = np.array(label) | |||
| return image, label | |||
| def create_dataset(dataset_path, repeat_num=1, batch_size=1): | |||
| """ | |||
| create train or evaluation dataset for warpctc | |||
| Args: | |||
| dataset_path(int): dataset path | |||
| repeat_num(int): dataset repetition num, default is 1 | |||
| batch_size(int): batch size of generated dataset, default is 1 | |||
| """ | |||
| rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1 | |||
| rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0 | |||
| dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) | |||
| ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||
| ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) | |||
| image_trans = [ | |||
| vc.Rescale(1.0 / 255.0, 0.0), | |||
| vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), | |||
| vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), | |||
| vc.HWC2CHW() | |||
| ] | |||
| label_trans = [ | |||
| c.TypeCast(mstype.int32) | |||
| ] | |||
| ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) | |||
| ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) | |||
| ds = ds.batch(batch_size) | |||
| ds = ds.repeat(repeat_num) | |||
| return ds | |||
| @@ -0,0 +1,49 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CTC Loss.""" | |||
| import numpy as np | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| class CTCLoss(_Loss): | |||
| """ | |||
| CTCLoss definition | |||
| Args: | |||
| max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image | |||
| width | |||
| max_label_length(int): max number of label length for each input. | |||
| batch_size(int): batch size of input logits | |||
| """ | |||
| def __init__(self, max_sequence_length, max_label_length, batch_size): | |||
| super(CTCLoss, self).__init__() | |||
| self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32), | |||
| name="sequence_length") | |||
| labels_indices = [] | |||
| for i in range(batch_size): | |||
| for j in range(max_label_length): | |||
| labels_indices.append([i, j]) | |||
| self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices") | |||
| self.reshape = P.Reshape() | |||
| self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True) | |||
| def construct(self, logit, label): | |||
| labels_values = self.reshape(label, (-1,)) | |||
| loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) | |||
| return loss | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Learning rate generator.""" | |||
| def get_lr(epoch_size, step_size, lr_init): | |||
| """ | |||
| generate learning rate for each step, which decays in every 10 epoch | |||
| Args: | |||
| epoch_size(int): total epoch number | |||
| step_size(int): total step number in each step | |||
| lr_init(int): initial learning rate | |||
| Returns: | |||
| List, learning rate array | |||
| """ | |||
| lr = lr_init | |||
| lrs = [] | |||
| for i in range(1, epoch_size + 1): | |||
| if i % 10 == 0: | |||
| lr *= 0.1 | |||
| lrs.extend([lr for _ in range(step_size)]) | |||
| return lrs | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Metric for accuracy evaluation.""" | |||
| from mindspore import nn | |||
| BLANK_LABLE = 10 | |||
| class WarpCTCAccuracy(nn.Metric): | |||
| """ | |||
| Define accuracy metric for warpctc network. | |||
| """ | |||
| def __init__(self): | |||
| super(WarpCTCAccuracy).__init__() | |||
| self._correct_num = 0 | |||
| self._total_num = 0 | |||
| self._count = 0 | |||
| def clear(self): | |||
| self._correct_num = 0 | |||
| self._total_num = 0 | |||
| def update(self, *inputs): | |||
| if len(inputs) != 2: | |||
| raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||
| y_pred = self._convert_data(inputs[0]) | |||
| y = self._convert_data(inputs[1]) | |||
| self._count += 1 | |||
| pred_lbls = self._get_prediction(y_pred) | |||
| for b_idx, target in enumerate(y): | |||
| if self._is_eq(pred_lbls[b_idx], target): | |||
| self._correct_num += 1 | |||
| self._total_num += 1 | |||
| def eval(self): | |||
| if self._total_num == 0: | |||
| raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') | |||
| return self._correct_num / self._total_num | |||
| @staticmethod | |||
| def _is_eq(pred_lbl, target): | |||
| """ | |||
| check whether predict label is equal to target label | |||
| """ | |||
| target = target.tolist() | |||
| pred_diff = len(target) - len(pred_lbl) | |||
| if pred_diff > 0: | |||
| # padding by BLANK_LABLE | |||
| pred_lbl.extend([BLANK_LABLE] * pred_diff) | |||
| return pred_lbl == target | |||
| @staticmethod | |||
| def _get_prediction(y_pred): | |||
| """ | |||
| parse predict result to labels | |||
| """ | |||
| seq_len, batch_size, _ = y_pred.shape | |||
| indices = y_pred.argmax(axis=2) | |||
| lens = [seq_len] * batch_size | |||
| pred_lbls = [] | |||
| for i in range(batch_size): | |||
| idx = indices[:, i] | |||
| last_idx = BLANK_LABLE | |||
| pred_lbl = [] | |||
| for j in range(lens[i]): | |||
| cur_idx = idx[j] | |||
| if cur_idx not in [last_idx, BLANK_LABLE]: | |||
| pred_lbl.append(cur_idx) | |||
| last_idx = cur_idx | |||
| pred_lbls.append(pred_lbl) | |||
| return pred_lbls | |||
| @@ -0,0 +1,90 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Warpctc network definition.""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| class StackedRNN(nn.Cell): | |||
| """ | |||
| Define a stacked RNN network which contains two LSTM layers and one full-connect layer. | |||
| Args: | |||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||
| captcha images. | |||
| batch_size(int): batch size of input data, default is 64 | |||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||
| """ | |||
| def __init__(self, input_size, batch_size=64, hidden_size=512): | |||
| super(StackedRNN, self).__init__() | |||
| self.batch_size = batch_size | |||
| self.input_size = input_size | |||
| self.num_classes = 11 | |||
| self.reshape = P.Reshape() | |||
| self.cast = P.Cast() | |||
| k = (1 / hidden_size) ** 0.5 | |||
| self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||
| self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||
| self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1)) | |||
| .astype(np.float16), name="w1") | |||
| self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1)) | |||
| .astype(np.float16), name="w2") | |||
| self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1") | |||
| self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2") | |||
| self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||
| self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||
| self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh") | |||
| self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) | |||
| self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) | |||
| self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), | |||
| bias_init=Tensor(self.fc_bias)) | |||
| self.fc.to_float(mstype.float32) | |||
| self.expand_dims = P.ExpandDims() | |||
| self.concat = P.Concat() | |||
| self.transpose = P.Transpose() | |||
| def construct(self, x): | |||
| x = self.cast(x, mstype.float16) | |||
| x = self.transpose(x, (3, 0, 2, 1)) | |||
| x = self.reshape(x, (-1, self.batch_size, self.input_size)) | |||
| h1 = self.h1 | |||
| c1 = self.c1 | |||
| h2 = self.h2 | |||
| c2 = self.c2 | |||
| c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1) | |||
| c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) | |||
| h2_after_fc = self.fc(h2) | |||
| output = self.expand_dims(h2_after_fc, 0) | |||
| for i in range(1, F.shape(x)[0]): | |||
| c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1) | |||
| c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) | |||
| h2_after_fc = self.fc(h2) | |||
| h2_after_fc = self.expand_dims(h2_after_fc, 0) | |||
| output = self.concat((output, h2_after_fc)) | |||
| return output | |||
| @@ -0,0 +1,114 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Automatic differentiation with grad clip.""" | |||
| from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, | |||
| _get_parallel_mode) | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.cell import Cell | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| import numpy as np | |||
| compute_norm = C.MultitypeFuncGraph("compute_norm") | |||
| @compute_norm.register("Tensor") | |||
| def _compute_norm(grad): | |||
| norm = nn.Norm() | |||
| norm = norm(F.cast(grad, mstype.float32)) | |||
| ret = F.expand_dims(F.cast(norm, mstype.float32), 0) | |||
| return ret | |||
| grad_div = C.MultitypeFuncGraph("grad_div") | |||
| @grad_div.register("Tensor", "Tensor") | |||
| def _grad_div(val, grad): | |||
| div = P.Div() | |||
| mul = P.Mul() | |||
| grad = mul(grad, 10.0) | |||
| ret = div(grad, val) | |||
| return ret | |||
| class TrainOneStepCellWithGradClip(Cell): | |||
| """ | |||
| Network training package class. | |||
| Wraps the network with an optimizer. The resulting Cell be trained with input data and label. | |||
| Backward graph with grad clip will be created in the construct function to do parameter updating. | |||
| Different parallel modes are available to run the training. | |||
| Args: | |||
| network (Cell): The training network. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||
| Inputs: | |||
| - data (Tensor) - Tensor of shape :(N, ...). | |||
| - label (Tensor) - Tensor of shape :(N, ...). | |||
| Outputs: | |||
| Tensor, a scalar Tensor with shape :math:`()`. | |||
| """ | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| self.network.add_flags(defer_inline=True) | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.hyper_map = C.HyperMap() | |||
| self.greater = P.Greater() | |||
| self.select = P.Select() | |||
| self.norm = nn.Norm(keep_dims=True) | |||
| self.dtype = P.DType() | |||
| self.cast = P.Cast() | |||
| self.concat = P.Concat(axis=0) | |||
| self.ten = Tensor(np.array([10.0]).astype(np.float32)) | |||
| parallel_mode = _get_parallel_mode() | |||
| if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = _get_mirror_mean() | |||
| degree = _get_device_num() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, data, label): | |||
| weights = self.weights | |||
| loss = self.network(data, label) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| grads = self.grad(self.network, weights)(data, label, sens) | |||
| norm = self.hyper_map(F.partial(compute_norm), grads) | |||
| norm = self.concat(norm) | |||
| norm = self.norm(norm) | |||
| cond = self.greater(norm, self.cast(self.ten, self.dtype(norm))) | |||
| clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm))) | |||
| grads = self.hyper_map(F.partial(grad_div, clip_val), grads) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| @@ -0,0 +1,84 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Warpctc training""" | |||
| import os | |||
| import math as m | |||
| import random | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore import dataset as de | |||
| from mindspore.train.model import Model, ParallelMode | |||
| from mindspore.nn.wrap import WithLossCell | |||
| from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint | |||
| from mindspore.communication.management import init | |||
| from src.loss import CTCLoss | |||
| from src.config import config as cf | |||
| from src.dataset import create_dataset | |||
| from src.warpctc import StackedRNN | |||
| from src.warpctc_for_train import TrainOneStepCellWithGradClip | |||
| from src.lr_schedule import get_lr | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| de.config.set_seed(1) | |||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||
| parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") | |||
| parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | |||
| args_opt = parser.parse_args() | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| save_graphs=False, | |||
| device_id=device_id) | |||
| if __name__ == '__main__': | |||
| if args_opt.run_distribute: | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=args_opt.device_num, | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| init() | |||
| max_captcha_digits = cf.max_captcha_digits | |||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | |||
| # create dataset | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| # define lr | |||
| lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num | |||
| lr = get_lr(cf.epoch_size, step_size, lr_init) | |||
| # define loss | |||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||
| # define net | |||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||
| # define opt | |||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||
| net = WithLossCell(net, loss) | |||
| net = TrainOneStepCellWithGradClip(net, opt).set_train() | |||
| # define model | |||
| model = Model(net) | |||
| # define callbacks | |||
| callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | |||
| if cf.save_checkpoint: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, | |||
| keep_checkpoint_max=cf.keep_checkpoint_max) | |||
| ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) | |||
| callbacks.append(ckpt_cb) | |||
| model.train(cf.epoch_size, dataset, callbacks=callbacks) | |||