From: @ZhengBina Reviewed-by: @c_34,@wuxuejian Signed-off-by: @c_34tags/v1.3.0
| @@ -48,18 +48,17 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released | |||
| ## [Quick Start](#contents) | |||
| - After the dataset is prepared, you may start running the training or the evaluation scripts as follows: | |||
| - Running on Ascend | |||
| ```shell | |||
| # distribute training example in Ascend | |||
| $ bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] | |||
| $ bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR] | |||
| # evaluation example in Ascend | |||
| $ bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH] | |||
| $ bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] | |||
| # standalone training example in Ascend | |||
| $ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM] | |||
| $ bash run_standalone_train.sh [TRAIN_DATA_DIR] | |||
| ``` | |||
| For distributed training, a hccl configuration file with JSON format needs to be created in advance. | |||
| @@ -67,6 +66,56 @@ For training and evaluation, we use the French Street Name Signs (FSNS) released | |||
| Please follow the instructions in the link below: | |||
| [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). | |||
| - Running on ModelArts | |||
| If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows. | |||
| - Training with 8 cards on ModelArts | |||
| ```python | |||
| # (1) Upload the code folder to S3 bucket. | |||
| # (2) Click to "create training task" on the website UI interface. | |||
| # (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface. | |||
| # (4) Set the startup file to /{path}/crnn_seq2seq_ocr/train.py" on the website UI interface. | |||
| # (5) Perform a or b. | |||
| # a. setting parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml. | |||
| # 1. Set ”is_distributed=1“ | |||
| # 2. Set ”enable_modelarts=True“ | |||
| # 3. Set ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package. | |||
| # b. adding on the website UI interface. | |||
| # 1. Add ”is_distributed=1“ | |||
| # 2. Add ”enable_modelarts=True“ | |||
| # 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package. | |||
| # (6) Upload the dataset or the zip package of dataset to S3 bucket. | |||
| # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path). | |||
| # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (9) Under the item "resource pool selection", select the specification of 8 cards. | |||
| # (10) Create your job. | |||
| ``` | |||
| - evaluating with single card on ModelArts | |||
| ```python | |||
| # (1) Upload the code folder to S3 bucket. | |||
| # (2) Click to "create training task" on the website UI interface. | |||
| # (3) Set the code directory to "/{path}/crnn_seq2seq_ocr" on the website UI interface. | |||
| # (4) Set the startup file to /{path}/crnn_seq2seq_ocr/eval.py" on the website UI interface. | |||
| # (5) Perform a or b. | |||
| # a. setting parameters in /{path}/crnn_seq2seq_ocr/default_config.yaml. | |||
| # 1. Set ”enable_modelarts=True“ | |||
| # 2. Set “checkpoint_path={checkpoint_path}”({checkpoint_path} Indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.) | |||
| # 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package. | |||
| # b. adding on the website UI interface. | |||
| # 1. Set ”enable_modelarts=True“ | |||
| # 2. Set “checkpoint_path={checkpoint_path}”({checkpoint_path} Indicates the path of the weight file to be evaluated relative to the file 'eval.py', and the weight file must be included in the code directory.) | |||
| # 3. Add ”modelarts_dataset_unzip_name={filenmae}",if the data is uploaded in the form of zip package. | |||
| # (6) Upload the dataset or the zip package of dataset to S3 bucket. | |||
| # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or zip package under this path). | |||
| # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface. | |||
| # (9) Under the item "resource pool selection", select the specification of a single card. | |||
| # (10) Create your job. | |||
| ``` | |||
| ## [Script Description](#contents) | |||
| ### [Script and Sample Code](#contents) | |||
| @@ -79,9 +128,13 @@ crnn-seq2seq-ocr | |||
| │ ├── run_eval_ascend.sh # Launch Ascend evaluation | |||
| │ └── run_standalone_train.sh # Launch standalone training on Ascend(1 pcs) | |||
| ├── src | |||
| | |── scripts | |||
| │ | ├── config.py # parsing parameter configuration file of "*.yaml" | |||
| │ | ├── device_adapter.py # local or ModelArts training | |||
| │ | ├── local_adapter.py # get related environment variables in local training | |||
| │ | └── moxing_adapter.py # get related environment variables in ModelArts training | |||
| │ ├── attention_ocr.py # CRNN-Seq2Seq-OCR training wrapper | |||
| │ ├── cnn.py # VGG network | |||
| │ ├── config.py # Parameter configuration | |||
| │ ├── create_mindrecord_files.py # Create mindrecord files from images and ground truth | |||
| │ ├── dataset.py # Data preprocessing for training and evaluation | |||
| │ ├── gru.py # GRU cell wrapper | |||
| @@ -90,8 +143,9 @@ crnn-seq2seq-ocr | |||
| │ ├── seq2seq.py # CRNN-Seq2Seq-OCR model structure | |||
| │ └── utils.py # Utility functions for training and data pre-processing | |||
| │ ├── weight_init.py # weight initialization of LSTM and GRU | |||
| └── train.py # Training script | |||
| ├── eval.py # Evaluation Script | |||
| ├── general_chars.txt # general chars | |||
| └── train.py # Training script | |||
| ``` | |||
| ### [Script Parameters](#contents) | |||
| @@ -100,10 +154,10 @@ crnn-seq2seq-ocr | |||
| ```shell | |||
| # distributed training on Ascend | |||
| Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] | |||
| Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR] | |||
| # standalone training | |||
| Usage: bash run_standalone_train.sh [DATASET_PATH] | |||
| Usage: bash run_standalone_train.sh [TRAIN_DATA_DIR] | |||
| ``` | |||
| #### Parameters Configuration | |||
| @@ -116,14 +170,14 @@ Parameters for both training and evaluation can be set in config.py. | |||
| ## [Training Process](#contents) | |||
| - Set options in `config.py`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset. | |||
| - Set options in `default_config.yaml`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset. | |||
| ### [Training](#contents) | |||
| - Run `run_standalone_train.sh` for non-distributed training of CRNN-Seq2Seq-OCR model, only support Ascend now. | |||
| ``` bash | |||
| bash run_standalone_train.sh [DATASET_PATH] | |||
| bash run_standalone_train.sh [TRAIN_DATA_DIR] | |||
| ``` | |||
| #### [Distributed Training](#contents) | |||
| @@ -131,7 +185,7 @@ bash run_standalone_train.sh [DATASET_PATH] | |||
| - Run `run_distribute_train.sh` for distributed training of CRNN-Seq2Seq-OCR model on Ascend. | |||
| ``` bash | |||
| bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] | |||
| bash run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR] | |||
| ``` | |||
| Check the `train_parallel0/log.txt` and you will get outputs as following: | |||
| @@ -149,7 +203,7 @@ epoch time: 1559886.096 ms, per step time: 382.231 ms | |||
| - Run `run_eval_ascend.sh` for evaluation on Ascend. | |||
| ``` bash | |||
| bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH] | |||
| bash run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] | |||
| ``` | |||
| Check the `eval/log` and you will get outputs as following: | |||
| @@ -0,0 +1,89 @@ | |||
| # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing) | |||
| enable_modelarts: False | |||
| # Url for modelarts | |||
| data_url: "" | |||
| train_url: "" | |||
| checkpoint_url: "" | |||
| # Path for local | |||
| data_path: "/cache/data" | |||
| output_path: "/cache/train" | |||
| load_path: "/cache/checkpoint_path" | |||
| device_target: "Ascend" | |||
| enable_profiling: False | |||
| modelarts_dataset_unzip_name: None | |||
| # ============================================================================== | |||
| #train-related | |||
| is_distributed: 0 | |||
| rank_id: 0 | |||
| train_data_dir: '' | |||
| batch_size: 32 | |||
| num_epochs: 20 | |||
| keep_checkpoint_max: 20 | |||
| #eval-related | |||
| eval_batch_size: 32 | |||
| test_data_dir: '' | |||
| checkpoint_path: None | |||
| # logging-related | |||
| log_interval: 100 | |||
| pre_checkpoint_path: '' | |||
| ckpt_path: "outputs/" | |||
| ckpt_interval: None | |||
| is_save_on_master: 0 | |||
| # dataset-related | |||
| mindrecord_dir: '' | |||
| data_root: '' | |||
| annotation_file: '' | |||
| val_data_root: '' | |||
| val_annotation_file: '' | |||
| data_json: '' | |||
| go_shift: 1 | |||
| characters_dictionary: {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3} | |||
| labels_not_use: ['%#�?%', '%#背景#%', '%#不识�?%', '#%不识�?#', '%#模糊#%', '%#模糊#%'] | |||
| vocab_path: "./general_chars.txt" | |||
| # model-related | |||
| img_width: 512 | |||
| img_height: 128 | |||
| channel_size: 3 | |||
| conv_out_dim: 384 | |||
| encoder_hidden_size: 128 | |||
| decoder_hidden_size: 128 | |||
| decoder_output_size: 10000 | |||
| dropout_p: 0.1 | |||
| max_length: 64 | |||
| attn_num_layers: 1 | |||
| teacher_force_ratio: 0.5 | |||
| #optimizer-related | |||
| lr: 0.0008 | |||
| adam_beta1: 0.5 | |||
| adam_beta2: 0.999 | |||
| loss_scale: 1024 | |||
| --- | |||
| # Help description for each configuration | |||
| enable_modelarts: "Whether training on modelarts, default: False" | |||
| data_url: "Url for modelarts" | |||
| train_url: "Url for modelarts" | |||
| data_path: "The location of the input data." | |||
| output_path: "The location of the output file." | |||
| device_target: 'Target device type' | |||
| enable_profiling: 'Whether enable profiling while training, default: False' | |||
| is_distributed: 'Distribute train or not, 1 for yes, 0 for no. Default: 0' | |||
| rank_id: "Local rank of distributed. Default: 0" | |||
| train_data_dir: "Train dataset directory." | |||
| log_interval: "Logging interval steps. Default: 100" | |||
| ckpt_path: "Checkpoint save location. Default: outputs/" | |||
| pre_checkpoint_path: "Checkpoint save location." | |||
| ckpt_interval: "Save checkpoint interval. Default: None" | |||
| is_save_on_master: "Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0" | |||
| test_data_dir: "Test Dataset path" | |||
| checkpoint_path: "Checkpoint of AttentionOCR (Default:None)." | |||
| @@ -19,7 +19,6 @@ CRNN-Seq2Seq-OCR Evaluation. | |||
| import os | |||
| import codecs | |||
| import argparse | |||
| import numpy as np | |||
| import mindspore.ops.operations as P | |||
| @@ -29,11 +28,13 @@ from mindspore.common import set_seed | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import config | |||
| from src.utils import initialize_vocabulary | |||
| from src.dataset import create_ocr_val_dataset | |||
| from src.attention_ocr import AttentionOCRInfer | |||
| from src.model_utils.config import config | |||
| from src.model_utils.moxing_adapter import moxing_wrapper | |||
| from src.model_utils.device_adapter import get_device_id | |||
| set_seed(1) | |||
| @@ -75,30 +76,20 @@ def LCS_length(str1, str2): | |||
| return lcs[len1 % 2][-1] | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation") | |||
| parser.add_argument("--dataset_path", type=str, default="", | |||
| help="Test Dataset path") | |||
| parser.add_argument("--checkpoint_path", type=str, default=None, | |||
| help="Checkpoint of AttentionOCR (Default:None).") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="device where the code will be implemented, default is Ascend") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | |||
| args = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||
| @moxing_wrapper() | |||
| def run_eval(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id()) | |||
| prefix = "fsns.mindrecord" | |||
| mindrecord_dir = args.dataset_path | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if config.enable_modelarts: | |||
| mindrecord_file = os.path.join(config.data_path, prefix + "0") | |||
| else: | |||
| mindrecord_file = os.path.join(config.test_data_dir, prefix + "0") | |||
| print("mindrecord_file", mindrecord_file) | |||
| dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size) | |||
| data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True) | |||
| print("Dataset creation Done!") | |||
| #Network | |||
| # Network | |||
| network = AttentionOCRInfer(config.eval_batch_size, | |||
| int(config.img_width / 4), | |||
| config.encoder_hidden_size, | |||
| @@ -106,15 +97,16 @@ if __name__ == '__main__': | |||
| config.decoder_output_size, | |||
| config.max_length, | |||
| config.dropout_p) | |||
| ckpt = load_checkpoint(args.checkpoint_path) | |||
| checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.checkpoint_path) | |||
| ckpt = load_checkpoint(checkpoint_path) | |||
| load_param_into_net(network, ckpt) | |||
| network.set_train(False) | |||
| print("Checkpoint loading Done!") | |||
| vocab, rev_vocab = initialize_vocabulary(config.vocab_path) | |||
| eos_id = config.characters_dictionary.get("eos_id") | |||
| sos_id = config.characters_dictionary.get("go_id") | |||
| vocab_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.vocab_path) | |||
| _, rev_vocab = initialize_vocabulary(vocab_path) | |||
| eos_id = config.characters_dictionary.eos_id | |||
| sos_id = config.characters_dictionary.go_id | |||
| num_correct_char = 0 | |||
| num_total_char = 0 | |||
| @@ -125,20 +117,20 @@ if __name__ == '__main__': | |||
| incorrect_file = 'result_incorrect.txt' | |||
| with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \ | |||
| codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect: | |||
| codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect: | |||
| for data in data_loader: | |||
| images = Tensor(data["image"]) | |||
| decoder_inputs = Tensor(data["decoder_input"]) | |||
| decoder_targets = Tensor(data["decoder_target"]) | |||
| # decoder_targets = Tensor(data["decoder_target"]) | |||
| decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size), | |||
| dtype=np.float16), mstype.float16) | |||
| decoder_input = Tensor((np.ones((config.eval_batch_size, 1))*sos_id).astype(np.int32)) | |||
| decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32)) | |||
| encoder_outputs = network.encoder(images) | |||
| batch_decoded_label = [] | |||
| for di in range(decoder_inputs.shape[1]): | |||
| for _ in range(decoder_inputs.shape[1]): | |||
| decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs) | |||
| topi = P.Argmax()(decoder_output) | |||
| ni = P.ExpandDims()(topi, 1) | |||
| @@ -179,3 +171,5 @@ if __name__ == '__main__': | |||
| print('\nnum of total words = %d' % (num_total_word)) | |||
| print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char)) | |||
| print('\nAnnotation precision precision = %f' % (float(num_correct_word) / num_total_word)) | |||
| if __name__ == '__main__': | |||
| run_eval() | |||
| @@ -16,7 +16,7 @@ | |||
| if [ $# -ne 2 ] | |||
| then | |||
| echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]" | |||
| echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]" | |||
| exit 1 | |||
| fi | |||
| @@ -39,9 +39,9 @@ fi | |||
| PATH2=$(get_real_path $2) | |||
| echo $PATH2 | |||
| if [ ! -f $PATH2 ] | |||
| if [ ! -d $PATH2 ] | |||
| then | |||
| echo "error: PRETRAINED_PATH=$PATH2 is not a file" | |||
| echo "error: TRAIN_DATA_DIR=$PATH2 is not a folder" | |||
| exit 1 | |||
| fi | |||
| @@ -58,9 +58,11 @@ do | |||
| mkdir ./train_parallel$i | |||
| cp ../*.py ./train_parallel$i | |||
| cp -r ../src ./train_parallel$i | |||
| cp ../*.yaml ./train_parallel$i | |||
| cp ../*.txt ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log & | |||
| python train.py --is_distribute=1 --train_data_dir=$PATH2 &> log & | |||
| cd .. | |||
| done | |||
| @@ -16,7 +16,7 @@ | |||
| if [ $# != 2 ] | |||
| then | |||
| echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]" | |||
| echo "Usage: sh run_eval_ascend.sh [TEST_DATA_DIR] [CHECKPOINT_PATH]" | |||
| exit 1 | |||
| fi | |||
| @@ -34,7 +34,7 @@ echo $PATH2 | |||
| if [ ! -d $PATH1 ] | |||
| then | |||
| echo "error: DATASET_PATH=$PATH1 is not a folder" | |||
| echo "error: TEST_DATA_DIR=$PATH1 is not a folder" | |||
| exit 1 | |||
| fi | |||
| @@ -56,10 +56,11 @@ fi | |||
| mkdir ./eval | |||
| cp ../*.py ./eval | |||
| cp ../*.txt ./eval | |||
| cp ../*.yaml ./eval | |||
| cp *.sh ./eval | |||
| cp -r ../src ./eval | |||
| cd ./eval || exit | |||
| env > env.log | |||
| echo "start eval for device $DEVICE_ID" | |||
| python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log & | |||
| python eval.py --device_target="Ascend" --test_data_dir=$PATH1 --checkpoint_path=$PATH2 &> log & | |||
| cd .. | |||
| @@ -16,7 +16,7 @@ | |||
| if [ $# -ne 1 ] | |||
| then | |||
| echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]" | |||
| echo "Usage: sh run_standalone_train_ascend.sh [TRAIN_DATA_DIR]" | |||
| exit 1 | |||
| fi | |||
| @@ -31,9 +31,9 @@ get_real_path(){ | |||
| PATH1=$(get_real_path $1) | |||
| echo $PATH1 | |||
| if [ ! -f $PATH1 ] | |||
| if [ ! -d $PATH1 ] | |||
| then | |||
| echo "error: DATASET_PATH=$PATH1 is not a file" | |||
| echo "error: TRAIN_DATA_DIR=$PATH1 is not a folder" | |||
| exit 1 | |||
| fi | |||
| @@ -50,9 +50,11 @@ fi | |||
| mkdir ./train | |||
| cp ../*.py ./train | |||
| cp *.sh ./train | |||
| cp ../*.yaml ./train | |||
| cp ../*.txt ./train | |||
| cp -r ../src ./train | |||
| cd ./train || exit | |||
| echo "start training for device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log & | |||
| python train.py --train_data_dir=$PATH1 --is_distributed=0 &> log & | |||
| cd .. | |||
| @@ -1,62 +0,0 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """Config parameters for CRNN-Seq2Seq-OCR model.""" | |||
| from easydict import EasyDict as ed | |||
| config = ed({ | |||
| # dataset-related | |||
| "mindrecord_dir": "", | |||
| "data_root": "", | |||
| "annotation_file": "", | |||
| "val_data_root": "", | |||
| "val_annotation_file": "", | |||
| "data_json": "", | |||
| "go_shift": 1, | |||
| "characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3}, | |||
| "labels_not_use": [u'%#�?%', u'%#背景#%', u'%#不识�?%', u'#%不识�?#', u'%#模糊#%', u'%#模糊#%'], | |||
| "vocab_path": "./general_chars.txt", | |||
| #model-related | |||
| "img_width": 512, | |||
| "img_height": 128, | |||
| "channel_size": 3, | |||
| "conv_out_dim": 384, | |||
| "encoder_hidden_size": 128, | |||
| "decoder_hidden_size": 128, | |||
| "decoder_output_size": 10000, # vocab_size is the decoder_output_size, characters_class+1, last 9999 is the space | |||
| "dropout_p": 0.1, | |||
| "max_length": 64, | |||
| "attn_num_layers": 1, | |||
| "teacher_force_ratio": 0.5, | |||
| #optimizer-related | |||
| "lr": 0.0008, | |||
| "adam_beta1": 0.5, | |||
| "adam_beta2": 0.999, | |||
| "loss_scale": 1024, | |||
| #train-related | |||
| "batch_size": 32, | |||
| "num_epochs": 20, | |||
| "keep_checkpoint_max": 20, | |||
| #eval-related | |||
| "eval_batch_size": 32 | |||
| }) | |||
| @@ -19,7 +19,7 @@ import numpy as np | |||
| from mindspore.mindrecord import FileWriter | |||
| from config import config | |||
| from src.model_utils.config import config | |||
| from utils import initialize_vocabulary | |||
| @@ -24,7 +24,7 @@ import mindspore.dataset.vision.py_transforms as P | |||
| import mindspore.dataset.transforms.c_transforms as ops | |||
| import mindspore.common.dtype as mstype | |||
| from src.config import config | |||
| from src.model_utils.config import config | |||
| class AugmentationOps(): | |||
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Parse arguments""" | |||
| import os | |||
| import ast | |||
| import argparse | |||
| from pprint import pprint, pformat | |||
| import yaml | |||
| class Config: | |||
| """ | |||
| Configuration namespace. Convert dictionary to members. | |||
| """ | |||
| def __init__(self, cfg_dict): | |||
| for k, v in cfg_dict.items(): | |||
| if isinstance(v, (list, tuple)): | |||
| setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v]) | |||
| else: | |||
| setattr(self, k, Config(v) if isinstance(v, dict) else v) | |||
| def __str__(self): | |||
| return pformat(self.__dict__) | |||
| def __repr__(self): | |||
| return self.__str__() | |||
| def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"): | |||
| """ | |||
| Parse command line arguments to the configuration according to the default yaml. | |||
| Args: | |||
| parser: Parent parser. | |||
| cfg: Base configuration. | |||
| helper: Helper description. | |||
| cfg_path: Path to the default yaml config. | |||
| """ | |||
| parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]", | |||
| parents=[parser]) | |||
| helper = {} if helper is None else helper | |||
| choices = {} if choices is None else choices | |||
| for item in cfg: | |||
| if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict): | |||
| help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path) | |||
| choice = choices[item] if item in choices else None | |||
| if isinstance(cfg[item], bool): | |||
| parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice, | |||
| help=help_description) | |||
| else: | |||
| parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice, | |||
| help=help_description) | |||
| args = parser.parse_args() | |||
| return args | |||
| def parse_yaml(yaml_path): | |||
| """ | |||
| Parse the yaml config file. | |||
| Args: | |||
| yaml_path: Path to the yaml config. | |||
| """ | |||
| with open(yaml_path, 'r', encoding='utf-8') as fin: | |||
| try: | |||
| cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader) | |||
| cfgs = [x for x in cfgs] | |||
| if len(cfgs) == 1: | |||
| cfg_helper = {} | |||
| cfg = cfgs[0] | |||
| cfg_choices = {} | |||
| elif len(cfgs) == 2: | |||
| cfg, cfg_helper = cfgs | |||
| cfg_choices = {} | |||
| elif len(cfgs) == 3: | |||
| cfg, cfg_helper, cfg_choices = cfgs | |||
| else: | |||
| raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml") | |||
| print(cfg_helper) | |||
| except: | |||
| raise ValueError("Failed to parse yaml") | |||
| return cfg, cfg_helper, cfg_choices | |||
| def merge(args, cfg): | |||
| """ | |||
| Merge the base config from yaml file and command line arguments. | |||
| Args: | |||
| args: Command line arguments. | |||
| cfg: Base configuration. | |||
| """ | |||
| args_var = vars(args) | |||
| for item in args_var: | |||
| cfg[item] = args_var[item] | |||
| return cfg | |||
| def get_config(): | |||
| """ | |||
| Get Config according to the yaml file and cli arguments. | |||
| """ | |||
| parser = argparse.ArgumentParser(description="default name", add_help=False) | |||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
| parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"), | |||
| help="Config file path") | |||
| path_args, _ = parser.parse_known_args() | |||
| default, helper, choices = parse_yaml(path_args.config_path) | |||
| pprint(default) | |||
| args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) | |||
| final_config = merge(args, default) | |||
| return Config(final_config) | |||
| config = get_config() | |||
| if __name__ == '__main__': | |||
| print(config) | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Device adapter for ModelArts""" | |||
| from src.model_utils.config import config | |||
| if config.enable_modelarts: | |||
| from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||
| else: | |||
| from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id | |||
| __all__ = [ | |||
| "get_device_id", "get_device_num", "get_rank_id", "get_job_id" | |||
| ] | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Local adapter""" | |||
| import os | |||
| def get_device_id(): | |||
| device_id = os.getenv('DEVICE_ID', '0') | |||
| return int(device_id) | |||
| def get_device_num(): | |||
| device_num = os.getenv('RANK_SIZE', '1') | |||
| return int(device_num) | |||
| def get_rank_id(): | |||
| global_rank_id = os.getenv('RANK_ID', '0') | |||
| return int(global_rank_id) | |||
| def get_job_id(): | |||
| return "Local Job" | |||
| @@ -0,0 +1,123 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Moxing adapter for ModelArts""" | |||
| import os | |||
| import functools | |||
| from mindspore import context | |||
| from mindspore.profiler import Profiler | |||
| from src.model_utils.config import config | |||
| _global_sync_count = 0 | |||
| def get_device_id(): | |||
| device_id = os.getenv('DEVICE_ID', '0') | |||
| return int(device_id) | |||
| def get_device_num(): | |||
| device_num = os.getenv('RANK_SIZE', '1') | |||
| return int(device_num) | |||
| def get_rank_id(): | |||
| global_rank_id = os.getenv('RANK_ID', '0') | |||
| return int(global_rank_id) | |||
| def get_job_id(): | |||
| job_id = os.getenv('JOB_ID') | |||
| job_id = job_id if job_id != "" else "default" | |||
| return job_id | |||
| def sync_data(from_path, to_path): | |||
| """ | |||
| Download data from remote obs to local directory if the first url is remote url and the second one is local path | |||
| Upload data from local directory to remote obs in contrast. | |||
| """ | |||
| import moxing as mox | |||
| import time | |||
| global _global_sync_count | |||
| sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count) | |||
| _global_sync_count += 1 | |||
| # Each server contains 8 devices as most. | |||
| if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||
| print("from path: ", from_path) | |||
| print("to path: ", to_path) | |||
| mox.file.copy_parallel(from_path, to_path) | |||
| print("===finish data synchronization===") | |||
| try: | |||
| os.mknod(sync_lock) | |||
| # print("os.mknod({}) success".format(sync_lock)) | |||
| except IOError: | |||
| pass | |||
| print("===save flag===") | |||
| while True: | |||
| if os.path.exists(sync_lock): | |||
| break | |||
| time.sleep(1) | |||
| print("Finish sync data from {} to {}.".format(from_path, to_path)) | |||
| def moxing_wrapper(pre_process=None, post_process=None): | |||
| """ | |||
| Moxing wrapper to download dataset and upload outputs. | |||
| """ | |||
| def wrapper(run_func): | |||
| @functools.wraps(run_func) | |||
| def wrapped_func(*args, **kwargs): | |||
| # Download data from data_url | |||
| if config.enable_modelarts: | |||
| if config.data_url: | |||
| sync_data(config.data_url, config.data_path) | |||
| print("Dataset downloaded: ", os.listdir(config.data_path)) | |||
| if config.checkpoint_url: | |||
| sync_data(config.checkpoint_url, config.load_path) | |||
| print("Preload downloaded: ", os.listdir(config.load_path)) | |||
| if config.train_url: | |||
| sync_data(config.train_url, config.output_path) | |||
| print("Workspace downloaded: ", os.listdir(config.output_path)) | |||
| context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) | |||
| config.device_num = get_device_num() | |||
| config.device_id = get_device_id() | |||
| if not os.path.exists(config.output_path): | |||
| os.makedirs(config.output_path) | |||
| if pre_process: | |||
| pre_process() | |||
| if config.enable_profiling: | |||
| profiler = Profiler() | |||
| run_func(*args, **kwargs) | |||
| if config.enable_profiling: | |||
| profiler.analyse() | |||
| # Upload data to train_url | |||
| if config.enable_modelarts: | |||
| if post_process: | |||
| post_process() | |||
| if config.train_url: | |||
| print("Start to copy output directory") | |||
| sync_data(config.output_path, config.train_url) | |||
| return wrapped_func | |||
| return wrapper | |||
| @@ -16,10 +16,9 @@ | |||
| CRNN-Seq2Seq-OCR train. | |||
| """ | |||
| import os | |||
| import argparse | |||
| import datetime | |||
| import time | |||
| import os | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| @@ -31,62 +30,78 @@ from mindspore import context | |||
| from mindspore.communication.management import init | |||
| from mindspore.train.callback import ModelCheckpoint | |||
| from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import config | |||
| from src.dataset import create_ocr_train_dataset | |||
| from src.logger import get_logger | |||
| from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper | |||
| from src.model_utils.moxing_adapter import moxing_wrapper | |||
| from src.model_utils.config import config | |||
| from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num | |||
| set_seed(1) | |||
| def parse_args(): | |||
| """Parse train arguments.""" | |||
| parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training') | |||
| # device related | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="device where the code will be implemented.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=0, | |||
| help='Distribute train or not, 1 for yes, 0 for no. Default: 0') | |||
| parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0') | |||
| parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1') | |||
| #dataset related | |||
| parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.') | |||
| # logging related | |||
| parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100') | |||
| parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/') | |||
| parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Checkpoint save location.') | |||
| parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None') | |||
| parser.add_argument('--is_save_on_master', type=int, default=0, | |||
| help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0') | |||
| args, _ = parser.parse_known_args() | |||
| # logger | |||
| args.outputs_dir = os.path.join(args.ckpt_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| return args | |||
| def modelarts_pre_process(): | |||
| '''modelarts pre process function.''' | |||
| def unzip(zip_file, save_dir): | |||
| import zipfile | |||
| s_time = time.time() | |||
| if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)): | |||
| zip_isexist = zipfile.is_zipfile(zip_file) | |||
| if zip_isexist: | |||
| fz = zipfile.ZipFile(zip_file, 'r') | |||
| data_num = len(fz.namelist()) | |||
| print("Extract Start...") | |||
| print("unzip file num: {}".format(data_num)) | |||
| data_print = int(data_num / 100) if data_num > 100 else 1 | |||
| i = 0 | |||
| for file in fz.namelist(): | |||
| if i % data_print == 0: | |||
| print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True) | |||
| i += 1 | |||
| fz.extract(file, save_dir) | |||
| print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60), | |||
| int(int(time.time() - s_time) % 60))) | |||
| print("Extract Done.") | |||
| else: | |||
| print("This is not zip.") | |||
| else: | |||
| print("Zip has been extracted.") | |||
| if config.modelarts_dataset_unzip_name: | |||
| zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip") | |||
| save_dir_1 = os.path.join(config.data_path) | |||
| sync_lock = "/tmp/unzip_sync.lock" | |||
| # Each server contains 8 devices as most. | |||
| if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock): | |||
| print("Zip file path: ", zip_file_1) | |||
| print("Unzip file save dir: ", save_dir_1) | |||
| unzip(zip_file_1, save_dir_1) | |||
| print("===Finish extract data synchronization===") | |||
| try: | |||
| os.mknod(sync_lock) | |||
| except IOError: | |||
| pass | |||
| while True: | |||
| if os.path.exists(sync_lock): | |||
| break | |||
| time.sleep(1) | |||
| print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) | |||
| config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path) | |||
| @moxing_wrapper(pre_process=modelarts_pre_process) | |||
| def train(): | |||
| """Train function.""" | |||
| args = parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id()) | |||
| if args.is_distributed: | |||
| rank = args.rank_id | |||
| device_num = args.device_num | |||
| if config.is_distributed: | |||
| rank = get_rank_id() | |||
| device_num = get_device_num() | |||
| context.set_auto_parallel_context(device_num=device_num, | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| @@ -96,25 +111,31 @@ def train(): | |||
| device_num = 1 | |||
| # Logger | |||
| args.logger = get_logger(args.outputs_dir, rank) | |||
| args.rank_save_ckpt_flag = 0 | |||
| if args.is_save_on_master: | |||
| config.outputs_dir = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| config.logger = get_logger(config.outputs_dir, rank) | |||
| config.rank_save_ckpt_flag = 0 | |||
| if config.is_save_on_master: | |||
| if rank == 0: | |||
| args.rank_save_ckpt_flag = 1 | |||
| config.rank_save_ckpt_flag = 1 | |||
| else: | |||
| args.rank_save_ckpt_flag = 1 | |||
| config.rank_save_ckpt_flag = 1 | |||
| # DATASET | |||
| dataset = create_ocr_train_dataset(args.mindrecord_file, | |||
| prefix = "fsns.mindrecord" | |||
| if config.enable_modelarts: | |||
| mindrecord_file = os.path.join(config.data_path, prefix + "0") | |||
| else: | |||
| mindrecord_file = os.path.join(config.train_data_dir, prefix + "0") | |||
| dataset = create_ocr_train_dataset(mindrecord_file, | |||
| config.batch_size, | |||
| rank_size=device_num, | |||
| rank_id=rank) | |||
| args.steps_per_epoch = dataset.get_dataset_size() | |||
| args.logger.info('Finish loading dataset') | |||
| config.steps_per_epoch = dataset.get_dataset_size() | |||
| config.logger.info('Finish loading dataset') | |||
| if not args.ckpt_interval: | |||
| args.ckpt_interval = args.steps_per_epoch | |||
| args.logger.save_args(args) | |||
| if not config.ckpt_interval: | |||
| config.ckpt_interval = config.steps_per_epoch | |||
| config.logger.save_args(config) | |||
| network = AttentionOCR(config.batch_size, | |||
| int(config.img_width / 4), | |||
| @@ -124,8 +145,10 @@ def train(): | |||
| config.max_length, | |||
| config.dropout_p) | |||
| if args.pre_checkpoint_path: | |||
| param_dict = load_checkpoint(args.pre_checkpoint_path) | |||
| if config.pre_checkpoint_path: | |||
| config.pre_checkpoint_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), config.pre_checkpoint_path | |||
| ) | |||
| param_dict = load_checkpoint(config.pre_checkpoint_path) | |||
| load_param_into_net(network, param_dict) | |||
| network = AttentionOCRWithLossCell(network, config.max_length) | |||
| @@ -136,13 +159,13 @@ def train(): | |||
| network = TrainingWrapper(network, opt, sens=config.loss_scale) | |||
| args.logger.info('Finished get network') | |||
| config.logger.info('Finished get network') | |||
| callback = [TimeMonitor(data_size=1), LossMonitor()] | |||
| if args.rank_save_ckpt_flag: | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, | |||
| if config.rank_save_ckpt_flag: | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=config.steps_per_epoch, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/') | |||
| save_ckpt_path = os.path.join(config.outputs_dir, 'checkpoints' + '/') | |||
| ckpt_cb = ModelCheckpoint(config=ckpt_config, | |||
| directory=save_ckpt_path, | |||
| prefix="crnn_seq2seq_ocr") | |||
| @@ -151,7 +174,7 @@ def train(): | |||
| model = Model(network) | |||
| model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False) | |||
| args.logger.info('==========Training Done===============') | |||
| config.logger.info('==========Training Done===============') | |||
| if __name__ == "__main__": | |||