Merge pull request !2732 from gengdongjie/mastertags/v0.6.0-beta
| @@ -716,7 +716,7 @@ def get_bprop_basic_lstm_cell(self): | |||||
| def bprop(x, h, c, w, b, out, dout): | def bprop(x, h, c, w, b, out, dout): | ||||
| _, _, it, jt, ft, ot, tanhct = out | _, _, it, jt, ft, ot, tanhct = out | ||||
| dct, dht, _, _, _, _, _ = dout | dct, dht, _, _, _, _, _ = dout | ||||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) | |||||
| dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) | |||||
| dxt, dht = basic_lstm_cell_input_grad(dgate, w) | dxt, dht = basic_lstm_cell_input_grad(dgate, w) | ||||
| dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) | ||||
| return dxt, dht, dct_1, dw, db | return dxt, dht, dct_1, dw, db | ||||
| @@ -29,8 +29,8 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \ | |||||
| .input(1, "dht", False, "required", "all") \ | .input(1, "dht", False, "required", "all") \ | ||||
| .input(2, "dct", False, "required", "all") \ | .input(2, "dct", False, "required", "all") \ | ||||
| .input(3, "it", False, "required", "all") \ | .input(3, "it", False, "required", "all") \ | ||||
| .input(4, "ft", False, "required", "all") \ | |||||
| .input(5, "jt", False, "required", "all") \ | |||||
| .input(4, "jt", False, "required", "all") \ | |||||
| .input(5, "ft", False, "required", "all") \ | |||||
| .input(6, "ot", False, "required", "all") \ | .input(6, "ot", False, "required", "all") \ | ||||
| .input(7, "tanhct", False, "required", "all") \ | .input(7, "tanhct", False, "required", "all") \ | ||||
| .output(0, "dgate", False, "required", "all") \ | .output(0, "dgate", False, "required", "all") \ | ||||
| @@ -0,0 +1,137 @@ | |||||
| # Warpctc Example | |||||
| ## Description | |||||
| These is an example of training Warpctc with self-generated captcha image dataset in MindSpore. | |||||
| ## Requirements | |||||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||||
| - Generate captcha images. | |||||
| > The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately. | |||||
| > ``` | |||||
| > $ cd scripts | |||||
| > $ sh run_process_data.sh | |||||
| > | |||||
| > # after execution, you will find the dataset like the follows: | |||||
| > . | |||||
| > └─warpctc | |||||
| > └─data | |||||
| > ├─ train # train dataset | |||||
| > └─ test # evaluate dataset | |||||
| > ... | |||||
| ## Structure | |||||
| ```shell | |||||
| . | |||||
| └──warpct | |||||
| ├── README.md | |||||
| ├── script | |||||
| ├── run_distribute_train.sh # launch distributed training(8 pcs) | |||||
| ├── run_eval.sh # launch evaluation | |||||
| ├── run_process_data.sh # launch dataset generation | |||||
| └── run_standalone_train.sh # launch standalone training(1 pcs) | |||||
| ├── src | |||||
| ├── config.py # parameter configuration | |||||
| ├── dataset.py # data preprocessing | |||||
| ├── loss.py # ctcloss definition | |||||
| ├── lr_generator.py # generate learning rate for each step | |||||
| ├── metric.py # accuracy metric for warpctc network | |||||
| ├── warpctc.py # warpctc network definition | |||||
| └── warpctc_for_train.py # warp network with grad, loss and gradient clip | |||||
| ├── eval.py # eval net | |||||
| ├── process_data.py # dataset generation script | |||||
| └── train.py # train net | |||||
| ``` | |||||
| ## Parameter configuration | |||||
| Parameters for both training and evaluation can be set in config.py. | |||||
| ``` | |||||
| "max_captcha_digits": 4, # max number of digits in each | |||||
| "captcha_width": 160, # width of captcha images | |||||
| "captcha_height": 64, # height of capthca images | |||||
| "batch_size": 64, # batch size of input tensor | |||||
| "epoch_size": 30, # only valid for taining, which is always 1 for inference | |||||
| "hidden_size": 512, # hidden size in LSTM layers | |||||
| "learning_rate": 0.01, # initial learning rate | |||||
| "momentum": 0.9 # momentum of SGD optimizer | |||||
| "save_checkpoint": True, # whether save checkpoint or not | |||||
| "save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step | |||||
| "keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint | |||||
| "save_checkpoint_path": "./", # path to save checkpoint | |||||
| ``` | |||||
| ## Running the example | |||||
| ### Train | |||||
| #### Usage | |||||
| ``` | |||||
| # distributed training | |||||
| Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] | |||||
| # standalone training | |||||
| Usage: sh run_standalone_train.sh [DATASET_PATH] | |||||
| ``` | |||||
| #### Launch | |||||
| ``` | |||||
| # distribute training example | |||||
| sh run_distribute_train.sh rank_table.json ../data/train | |||||
| # standalone training example | |||||
| sh run_standalone_train.sh ../data/train | |||||
| ``` | |||||
| > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | |||||
| #### Result | |||||
| Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. | |||||
| ``` | |||||
| # distribute training result(8 pcs) | |||||
| Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944] | |||||
| Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951] | |||||
| Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385] | |||||
| Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193] | |||||
| Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809] | |||||
| ... | |||||
| ``` | |||||
| ### Evaluation | |||||
| #### Usage | |||||
| ``` | |||||
| # evaluation | |||||
| Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] | |||||
| ``` | |||||
| #### Launch | |||||
| ``` | |||||
| # evaluation example | |||||
| sh run_eval.sh ../data/test warpctc-30-98.ckpt | |||||
| ``` | |||||
| > checkpoint can be produced in training process. | |||||
| #### Result | |||||
| Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. | |||||
| ``` | |||||
| result: {'WarpCTCAccuracy': 0.9901472929936306} | |||||
| ``` | |||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Warpctc evaluation""" | |||||
| import os | |||||
| import math as m | |||||
| import random | |||||
| import argparse | |||||
| import numpy as np | |||||
| from mindspore import context | |||||
| from mindspore import dataset as de | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.loss import CTCLoss | |||||
| from src.config import config as cf | |||||
| from src.dataset import create_dataset | |||||
| from src.warpctc import StackedRNN | |||||
| from src.metric import WarpCTCAccuracy | |||||
| random.seed(1) | |||||
| np.random.seed(1) | |||||
| de.config.set_seed(1) | |||||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||||
| parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") | |||||
| parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") | |||||
| args_opt = parser.parse_args() | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend", | |||||
| save_graphs=False, | |||||
| device_id=device_id) | |||||
| if __name__ == '__main__': | |||||
| max_captcha_digits = cf.max_captcha_digits | |||||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | |||||
| # create dataset | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) | |||||
| step_size = dataset.get_dataset_size() | |||||
| # define loss | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||||
| # define net | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| # load checkpoint | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| net.set_train(False) | |||||
| # define model | |||||
| model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) | |||||
| # start evaluation | |||||
| res = model.eval(dataset) | |||||
| print("result:", res, flush=True) | |||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Generate train and test dataset""" | |||||
| import os | |||||
| import math as m | |||||
| import random | |||||
| from multiprocessing import Process | |||||
| from captcha.image import ImageCaptcha | |||||
| def _generate_captcha_per_process(path, total, start, end, img_width, img_height, max_digits): | |||||
| captcha = ImageCaptcha(width=img_width, height=img_height) | |||||
| filename_head = '{:0>' + str(len(str(total))) + '}-' | |||||
| for i in range(start, end): | |||||
| digits = '' | |||||
| digits_length = random.randint(1, max_digits) | |||||
| for _ in range(0, digits_length): | |||||
| integer = random.randint(0, 9) | |||||
| digits += str(integer) | |||||
| captcha.write(digits, os.path.join(path, filename_head.format(i) + digits + '.png')) | |||||
| def generate_captcha(name, img_num, img_width, img_height, max_digits, process_num=16): | |||||
| """ | |||||
| generate captcha images | |||||
| Args: | |||||
| name(str): name of folder, under which captcha images are saved in | |||||
| img_num(int): number of generated captcha images | |||||
| img_width(int): width of generated captcha images | |||||
| img_height(int): height of generated captcha images | |||||
| max_digits(int): max number of digits in each captcha images. For each captcha images, number of digits is in | |||||
| range [1,max_digits] | |||||
| process_num(int): number of process to generate captcha images, default is 16 | |||||
| """ | |||||
| cur_script_path = os.path.dirname(os.path.realpath(__file__)) | |||||
| path = os.path.join(cur_script_path, "data", name) | |||||
| print("Generating dataset [{}] under {}...".format(name, path)) | |||||
| if os.path.exists(path): | |||||
| os.system("rm -rf {}".format(path)) | |||||
| os.system("mkdir -p {}".format(path)) | |||||
| img_num_per_thread = m.ceil(img_num / process_num) | |||||
| processes = [] | |||||
| for i in range(process_num): | |||||
| start = i * img_num_per_thread | |||||
| end = start + img_num_per_thread if i != (process_num - 1) else img_num | |||||
| p = Process(target=_generate_captcha_per_process, | |||||
| args=(path, img_num, start, end, img_width, img_height, max_digits)) | |||||
| p.start() | |||||
| processes.append(p) | |||||
| for p in processes: | |||||
| p.join() | |||||
| print("Generating dataset [{}] finished, total number is {}!".format(name, img_num)) | |||||
| if __name__ == '__main__': | |||||
| generate_captcha("test", img_num=10000, img_width=160, img_height=64, max_digits=4) | |||||
| generate_captcha("train", img_num=50000, img_width=160, img_height=64, max_digits=4) | |||||
| @@ -0,0 +1,62 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 2 ]; then | |||||
| echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| get_real_path() { | |||||
| if [ "${1:0:1}" == "/" ]; then | |||||
| echo "$1" | |||||
| else | |||||
| echo "$(realpath -m $PWD/$1)" | |||||
| fi | |||||
| } | |||||
| PATH1=$(get_real_path $1) | |||||
| PATH2=$(get_real_path $2) | |||||
| if [ ! -f $PATH1 ]; then | |||||
| echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| if [ ! -d $PATH2 ]; then | |||||
| echo "error: DATASET_PATH=$PATH2 is not a directory" | |||||
| exit 1 | |||||
| fi | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=8 | |||||
| export RANK_SIZE=8 | |||||
| export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 | |||||
| export RANK_TABLE_FILE=$PATH1 | |||||
| for ((i = 0; i < ${DEVICE_NUM}; i++)); do | |||||
| export DEVICE_ID=$i | |||||
| export RANK_ID=$i | |||||
| rm -rf ./train_parallel$i | |||||
| mkdir ./train_parallel$i | |||||
| cp ../*.py ./train_parallel$i | |||||
| cp *.sh ./train_parallel$i | |||||
| cp -r ../src ./train_parallel$i | |||||
| cd ./train_parallel$i || exit | |||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||||
| env >env.log | |||||
| python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & | |||||
| cd .. | |||||
| done | |||||
| @@ -0,0 +1,60 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 2 ]; then | |||||
| echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| get_real_path() { | |||||
| if [ "${1:0:1}" == "/" ]; then | |||||
| echo "$1" | |||||
| else | |||||
| echo "$(realpath -m $PWD/$1)" | |||||
| fi | |||||
| } | |||||
| PATH1=$(get_real_path $1) | |||||
| PATH2=$(get_real_path $2) | |||||
| if [ ! -d $PATH1 ]; then | |||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | |||||
| exit 1 | |||||
| fi | |||||
| if [ ! -f $PATH2 ]; then | |||||
| echo "error: CHECKPOINT_PATH=$PATH2 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_SIZE=$DEVICE_NUM | |||||
| export RANK_ID=0 | |||||
| if [ -d "eval" ]; then | |||||
| rm -rf ./eval | |||||
| fi | |||||
| mkdir ./eval | |||||
| cp ../*.py ./eval | |||||
| cp *.sh ./eval | |||||
| cp -r ../src ./eval | |||||
| cd ./eval || exit | |||||
| env >env.log | |||||
| echo "start evaluation for device $DEVICE_ID" | |||||
| python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & | |||||
| cd .. | |||||
| @@ -0,0 +1,20 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| CUR_PATH=$(dirname $PWD/$0) | |||||
| cd $CUR_PATH/../ && | |||||
| python process_data.py && | |||||
| cd - || exit | |||||
| @@ -0,0 +1,54 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| if [ $# != 1 ]; then | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" | |||||
| exit 1 | |||||
| fi | |||||
| get_real_path() { | |||||
| if [ "${1:0:1}" == "/" ]; then | |||||
| echo "$1" | |||||
| else | |||||
| echo "$(realpath -m $PWD/$1)" | |||||
| fi | |||||
| } | |||||
| PATH1=$(get_real_path $1) | |||||
| if [ ! -d $PATH1 ]; then | |||||
| echo "error: DATASET_PATH=$PATH1 is not a directory" | |||||
| exit 1 | |||||
| fi | |||||
| ulimit -u unlimited | |||||
| export DEVICE_NUM=1 | |||||
| export DEVICE_ID=0 | |||||
| export RANK_ID=0 | |||||
| export RANK_SIZE=1 | |||||
| if [ -d "train" ]; then | |||||
| rm -rf ./train | |||||
| fi | |||||
| mkdir ./train | |||||
| cp ../*.py ./train | |||||
| cp *.sh ./train | |||||
| cp -r ../src ./train | |||||
| cd ./train || exit | |||||
| echo "start training for device $DEVICE_ID" | |||||
| env >env.log | |||||
| python train.py --dataset=$PATH1 &>log & | |||||
| cd .. | |||||
| @@ -0,0 +1,31 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Network parameters.""" | |||||
| from easydict import EasyDict | |||||
| config = EasyDict({ | |||||
| "max_captcha_digits": 4, | |||||
| "captcha_width": 160, | |||||
| "captcha_height": 64, | |||||
| "batch_size": 64, | |||||
| "epoch_size": 30, | |||||
| "hidden_size": 512, | |||||
| "learning_rate": 0.01, | |||||
| "momentum": 0.9, | |||||
| "save_checkpoint": True, | |||||
| "save_checkpoint_steps": 98, | |||||
| "keep_checkpoint_max": 30, | |||||
| "save_checkpoint_path": "./", | |||||
| }) | |||||
| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Dataset preprocessing.""" | |||||
| import os | |||||
| import math as m | |||||
| import numpy as np | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.c_transforms as c | |||||
| import mindspore.dataset.transforms.vision.c_transforms as vc | |||||
| from PIL import Image | |||||
| from src.config import config as cf | |||||
| class _CaptchaDataset(): | |||||
| """ | |||||
| create train or evaluation dataset for warpctc | |||||
| Args: | |||||
| img_root_dir(str): root path of images | |||||
| max_captcha_digits(int): max number of digits in images. | |||||
| blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label | |||||
| length is less than max_captcha_digits, the remaining labels are padding with blank. | |||||
| """ | |||||
| def __init__(self, img_root_dir, max_captcha_digits, blank=10): | |||||
| if not os.path.exists(img_root_dir): | |||||
| raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) | |||||
| self.img_root_dir = img_root_dir | |||||
| self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] | |||||
| self.max_captcha_digits = max_captcha_digits | |||||
| self.blank = blank | |||||
| def __len__(self): | |||||
| return len(self.img_names) | |||||
| def __getitem__(self, item): | |||||
| img_name = self.img_names[item] | |||||
| im = Image.open(os.path.join(self.img_root_dir, img_name)) | |||||
| r, g, b = im.split() | |||||
| im = Image.merge("RGB", (b, g, r)) | |||||
| image = np.array(im) | |||||
| label_str = os.path.splitext(img_name)[0] | |||||
| label_str = label_str[label_str.find('-') + 1:] | |||||
| label = [int(i) for i in label_str] | |||||
| label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) | |||||
| label = np.array(label) | |||||
| return image, label | |||||
| def create_dataset(dataset_path, repeat_num=1, batch_size=1): | |||||
| """ | |||||
| create train or evaluation dataset for warpctc | |||||
| Args: | |||||
| dataset_path(int): dataset path | |||||
| repeat_num(int): dataset repetition num, default is 1 | |||||
| batch_size(int): batch size of generated dataset, default is 1 | |||||
| """ | |||||
| rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1 | |||||
| rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0 | |||||
| dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) | |||||
| ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||||
| ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) | |||||
| image_trans = [ | |||||
| vc.Rescale(1.0 / 255.0, 0.0), | |||||
| vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), | |||||
| vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), | |||||
| vc.HWC2CHW() | |||||
| ] | |||||
| label_trans = [ | |||||
| c.TypeCast(mstype.int32) | |||||
| ] | |||||
| ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) | |||||
| ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) | |||||
| ds = ds.batch(batch_size) | |||||
| ds = ds.repeat(repeat_num) | |||||
| return ds | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """CTC Loss.""" | |||||
| import numpy as np | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| class CTCLoss(_Loss): | |||||
| """ | |||||
| CTCLoss definition | |||||
| Args: | |||||
| max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image | |||||
| width | |||||
| max_label_length(int): max number of label length for each input. | |||||
| batch_size(int): batch size of input logits | |||||
| """ | |||||
| def __init__(self, max_sequence_length, max_label_length, batch_size): | |||||
| super(CTCLoss, self).__init__() | |||||
| self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32), | |||||
| name="sequence_length") | |||||
| labels_indices = [] | |||||
| for i in range(batch_size): | |||||
| for j in range(max_label_length): | |||||
| labels_indices.append([i, j]) | |||||
| self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices") | |||||
| self.reshape = P.Reshape() | |||||
| self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True) | |||||
| def construct(self, logit, label): | |||||
| labels_values = self.reshape(label, (-1,)) | |||||
| loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) | |||||
| return loss | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Learning rate generator.""" | |||||
| def get_lr(epoch_size, step_size, lr_init): | |||||
| """ | |||||
| generate learning rate for each step, which decays in every 10 epoch | |||||
| Args: | |||||
| epoch_size(int): total epoch number | |||||
| step_size(int): total step number in each step | |||||
| lr_init(int): initial learning rate | |||||
| Returns: | |||||
| List, learning rate array | |||||
| """ | |||||
| lr = lr_init | |||||
| lrs = [] | |||||
| for i in range(1, epoch_size + 1): | |||||
| if i % 10 == 0: | |||||
| lr *= 0.1 | |||||
| lrs.extend([lr for _ in range(step_size)]) | |||||
| return lrs | |||||
| @@ -0,0 +1,89 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Metric for accuracy evaluation.""" | |||||
| from mindspore import nn | |||||
| BLANK_LABLE = 10 | |||||
| class WarpCTCAccuracy(nn.Metric): | |||||
| """ | |||||
| Define accuracy metric for warpctc network. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(WarpCTCAccuracy).__init__() | |||||
| self._correct_num = 0 | |||||
| self._total_num = 0 | |||||
| self._count = 0 | |||||
| def clear(self): | |||||
| self._correct_num = 0 | |||||
| self._total_num = 0 | |||||
| def update(self, *inputs): | |||||
| if len(inputs) != 2: | |||||
| raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | |||||
| y_pred = self._convert_data(inputs[0]) | |||||
| y = self._convert_data(inputs[1]) | |||||
| self._count += 1 | |||||
| pred_lbls = self._get_prediction(y_pred) | |||||
| for b_idx, target in enumerate(y): | |||||
| if self._is_eq(pred_lbls[b_idx], target): | |||||
| self._correct_num += 1 | |||||
| self._total_num += 1 | |||||
| def eval(self): | |||||
| if self._total_num == 0: | |||||
| raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') | |||||
| return self._correct_num / self._total_num | |||||
| @staticmethod | |||||
| def _is_eq(pred_lbl, target): | |||||
| """ | |||||
| check whether predict label is equal to target label | |||||
| """ | |||||
| target = target.tolist() | |||||
| pred_diff = len(target) - len(pred_lbl) | |||||
| if pred_diff > 0: | |||||
| # padding by BLANK_LABLE | |||||
| pred_lbl.extend([BLANK_LABLE] * pred_diff) | |||||
| return pred_lbl == target | |||||
| @staticmethod | |||||
| def _get_prediction(y_pred): | |||||
| """ | |||||
| parse predict result to labels | |||||
| """ | |||||
| seq_len, batch_size, _ = y_pred.shape | |||||
| indices = y_pred.argmax(axis=2) | |||||
| lens = [seq_len] * batch_size | |||||
| pred_lbls = [] | |||||
| for i in range(batch_size): | |||||
| idx = indices[:, i] | |||||
| last_idx = BLANK_LABLE | |||||
| pred_lbl = [] | |||||
| for j in range(lens[i]): | |||||
| cur_idx = idx[j] | |||||
| if cur_idx not in [last_idx, BLANK_LABLE]: | |||||
| pred_lbl.append(cur_idx) | |||||
| last_idx = cur_idx | |||||
| pred_lbls.append(pred_lbl) | |||||
| return pred_lbls | |||||
| @@ -0,0 +1,90 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Warpctc network definition.""" | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, Parameter | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| class StackedRNN(nn.Cell): | |||||
| """ | |||||
| Define a stacked RNN network which contains two LSTM layers and one full-connect layer. | |||||
| Args: | |||||
| input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for | |||||
| captcha images. | |||||
| batch_size(int): batch size of input data, default is 64 | |||||
| hidden_size(int): the hidden size in LSTM layers, default is 512 | |||||
| """ | |||||
| def __init__(self, input_size, batch_size=64, hidden_size=512): | |||||
| super(StackedRNN, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.input_size = input_size | |||||
| self.num_classes = 11 | |||||
| self.reshape = P.Reshape() | |||||
| self.cast = P.Cast() | |||||
| k = (1 / hidden_size) ** 0.5 | |||||
| self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||||
| self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||||
| self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1)) | |||||
| .astype(np.float16), name="w1") | |||||
| self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1)) | |||||
| .astype(np.float16), name="w2") | |||||
| self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1") | |||||
| self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2") | |||||
| self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||||
| self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) | |||||
| self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh") | |||||
| self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) | |||||
| self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) | |||||
| self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), | |||||
| bias_init=Tensor(self.fc_bias)) | |||||
| self.fc.to_float(mstype.float32) | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.concat = P.Concat() | |||||
| self.transpose = P.Transpose() | |||||
| def construct(self, x): | |||||
| x = self.cast(x, mstype.float16) | |||||
| x = self.transpose(x, (3, 0, 2, 1)) | |||||
| x = self.reshape(x, (-1, self.batch_size, self.input_size)) | |||||
| h1 = self.h1 | |||||
| c1 = self.c1 | |||||
| h2 = self.h2 | |||||
| c2 = self.c2 | |||||
| c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1) | |||||
| c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) | |||||
| h2_after_fc = self.fc(h2) | |||||
| output = self.expand_dims(h2_after_fc, 0) | |||||
| for i in range(1, F.shape(x)[0]): | |||||
| c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1) | |||||
| c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) | |||||
| h2_after_fc = self.fc(h2) | |||||
| h2_after_fc = self.expand_dims(h2_after_fc, 0) | |||||
| output = self.concat((output, h2_after_fc)) | |||||
| return output | |||||
| @@ -0,0 +1,114 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Automatic differentiation with grad clip.""" | |||||
| from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, | |||||
| _get_parallel_mode) | |||||
| from mindspore.train.parallel_utils import ParallelMode | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.nn.cell import Cell | |||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.tensor import Tensor | |||||
| import numpy as np | |||||
| compute_norm = C.MultitypeFuncGraph("compute_norm") | |||||
| @compute_norm.register("Tensor") | |||||
| def _compute_norm(grad): | |||||
| norm = nn.Norm() | |||||
| norm = norm(F.cast(grad, mstype.float32)) | |||||
| ret = F.expand_dims(F.cast(norm, mstype.float32), 0) | |||||
| return ret | |||||
| grad_div = C.MultitypeFuncGraph("grad_div") | |||||
| @grad_div.register("Tensor", "Tensor") | |||||
| def _grad_div(val, grad): | |||||
| div = P.Div() | |||||
| mul = P.Mul() | |||||
| grad = mul(grad, 10.0) | |||||
| ret = div(grad, val) | |||||
| return ret | |||||
| class TrainOneStepCellWithGradClip(Cell): | |||||
| """ | |||||
| Network training package class. | |||||
| Wraps the network with an optimizer. The resulting Cell be trained with input data and label. | |||||
| Backward graph with grad clip will be created in the construct function to do parameter updating. | |||||
| Different parallel modes are available to run the training. | |||||
| Args: | |||||
| network (Cell): The training network. | |||||
| optimizer (Cell): Optimizer for updating the weights. | |||||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||||
| Inputs: | |||||
| - data (Tensor) - Tensor of shape :(N, ...). | |||||
| - label (Tensor) - Tensor of shape :(N, ...). | |||||
| Outputs: | |||||
| Tensor, a scalar Tensor with shape :math:`()`. | |||||
| """ | |||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_grad() | |||||
| self.network.add_flags(defer_inline=True) | |||||
| self.weights = optimizer.parameters | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.sens = sens | |||||
| self.reducer_flag = False | |||||
| self.grad_reducer = None | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.greater = P.Greater() | |||||
| self.select = P.Select() | |||||
| self.norm = nn.Norm(keep_dims=True) | |||||
| self.dtype = P.DType() | |||||
| self.cast = P.Cast() | |||||
| self.concat = P.Concat(axis=0) | |||||
| self.ten = Tensor(np.array([10.0]).astype(np.float32)) | |||||
| parallel_mode = _get_parallel_mode() | |||||
| if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||||
| self.reducer_flag = True | |||||
| if self.reducer_flag: | |||||
| mean = _get_mirror_mean() | |||||
| degree = _get_device_num() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| def construct(self, data, label): | |||||
| weights = self.weights | |||||
| loss = self.network(data, label) | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| grads = self.grad(self.network, weights)(data, label, sens) | |||||
| norm = self.hyper_map(F.partial(compute_norm), grads) | |||||
| norm = self.concat(norm) | |||||
| norm = self.norm(norm) | |||||
| cond = self.greater(norm, self.cast(self.ten, self.dtype(norm))) | |||||
| clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm))) | |||||
| grads = self.hyper_map(F.partial(grad_div, clip_val), grads) | |||||
| if self.reducer_flag: | |||||
| # apply grad reducer on grads | |||||
| grads = self.grad_reducer(grads) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| @@ -0,0 +1,84 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Warpctc training""" | |||||
| import os | |||||
| import math as m | |||||
| import random | |||||
| import argparse | |||||
| import numpy as np | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore import dataset as de | |||||
| from mindspore.train.model import Model, ParallelMode | |||||
| from mindspore.nn.wrap import WithLossCell | |||||
| from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint | |||||
| from mindspore.communication.management import init | |||||
| from src.loss import CTCLoss | |||||
| from src.config import config as cf | |||||
| from src.dataset import create_dataset | |||||
| from src.warpctc import StackedRNN | |||||
| from src.warpctc_for_train import TrainOneStepCellWithGradClip | |||||
| from src.lr_schedule import get_lr | |||||
| random.seed(1) | |||||
| np.random.seed(1) | |||||
| de.config.set_seed(1) | |||||
| parser = argparse.ArgumentParser(description="Warpctc training") | |||||
| parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") | |||||
| parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.') | |||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') | |||||
| args_opt = parser.parse_args() | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="Ascend", | |||||
| save_graphs=False, | |||||
| device_id=device_id) | |||||
| if __name__ == '__main__': | |||||
| if args_opt.run_distribute: | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| init() | |||||
| max_captcha_digits = cf.max_captcha_digits | |||||
| input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 | |||||
| # create dataset | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size) | |||||
| step_size = dataset.get_dataset_size() | |||||
| # define lr | |||||
| lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num | |||||
| lr = get_lr(cf.epoch_size, step_size, lr_init) | |||||
| # define loss | |||||
| loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) | |||||
| # define net | |||||
| net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) | |||||
| # define opt | |||||
| opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) | |||||
| net = WithLossCell(net, loss) | |||||
| net = TrainOneStepCellWithGradClip(net, opt).set_train() | |||||
| # define model | |||||
| model = Model(net) | |||||
| # define callbacks | |||||
| callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] | |||||
| if cf.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, | |||||
| keep_checkpoint_max=cf.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) | |||||
| callbacks.append(ckpt_cb) | |||||
| model.train(cf.epoch_size, dataset, callbacks=callbacks) | |||||