Merge pull request !6383 from linqingke/fasterrcnn

+++ b/README.md
@@ -0,0 +1,354 @@
# Contents

- [CNNCTC Description](#cnnctc-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CNNCTC Description](#contents)

This paper makes three major contributions to scene text recognition (STR).
First, we examine the inconsistencies between training and evaluation datasets, and the performance gap that results from them.
Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously
unexplored module combinations. Third, we analyze the module-wise contributions to performance in terms of accuracy, speed,
and memory demand, under one consistent set of training and evaluation datasets. Such analyses clear up the hindrances in
current comparisons to understanding the performance gains of the existing modules.

[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, "What is wrong with scene text recognition model comparisons? Dataset and model analysis," ArXiv, vol. abs/1904.01906, 2019.
# [Model Architecture](#contents)

This is an example of training the CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
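For orientation, here is a minimal sketch (assuming an Ascend environment and the defaults in `src/config.py`) of how the network defined in `src/cnn_ctc.py` maps an input batch to per-timestep class logits, the same way `eval.py` uses it; the zero-filled input is illustrative only:

```
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype

from src.cnn_ctc import CNNCTC_Model
from src.config import Config_CNNCTC

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

cfg = Config_CNNCTC()
net = CNNCTC_Model(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)

# A dummy batch of two 32x100 RGB crops, used only to probe the output shape.
dummy = Tensor(np.zeros((2, 3, cfg.IMG_H, cfg.IMG_W)), mstype.float32)
logits = net(dummy)
print(logits.shape)  # expected: (2, 26, 37) = (batch, FINAL_FEATURE_WIDTH, NUM_CLASS)
```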
# [Dataset](#contents)

The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.

- step 1:

All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).

- step 2:

Uncompress the downloaded file, and rename the MJSynth dataset to MJ, the SynthText dataset to ST, and the IIIT dataset to IIIT.

- step 3:

Move the above three datasets into the `cnnctc_data` folder; the structure should be as below:
```
|--- CNNCTC/
    |--- cnnctc_data/
        |--- ST/
            data.mdb
            lock.mdb
        |--- MJ/
            data.mdb
            lock.mdb
        |--- IIIT/
            data.mdb
            lock.mdb
        ......
```
- step 4:

Preprocess the dataset by running:

```
python src/preprocess_dataset.py
```

This takes around 75 minutes.
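Before preprocessing, you can sanity-check a downloaded .lmdb file with a few lines of Python. This is a minimal sketch, assuming the key layout used by `src/dataset.py` (`num-samples`, `label-%09d`, `image-%09d`, 1-indexed) and a hypothetical local path:

```
import io
import lmdb
from PIL import Image

# Adjust the path to wherever the cnnctc_data folder was placed.
env = lmdb.open('cnnctc_data/IIIT', max_readers=32, readonly=True,
                lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    label = txn.get(b'label-%09d' % 1).decode('utf-8')  # keys start at 1
    img = Image.open(io.BytesIO(txn.get(b'image-%09d' % 1))).convert('RGB')
print('samples:', n_samples, 'first label:', label, 'image size:', img.size)
```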
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the network accuracy achieved with single-precision training. Mixed precision training speeds up computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
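Concretely, mixed precision is driven by the `amp_level` argument of `Model`, optionally combined with a loss-scale manager. A minimal sketch, where `net`, `loss` and `opt` are placeholders for the objects built in `train.py`, and 8096 mirrors `LOSS_SCALE` in `config.py`:

```
from mindspore import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# "O2" casts the network to float16; keep_batchnorm_fp32=False additionally
# runs BatchNorm in float16. The fixed loss scale guards against underflow.
scale_manager = FixedLossScaleManager(8096, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt,
              amp_level="O2", keep_batchnorm_fp32=False,
              loss_scale_manager=scale_manager)
```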
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare a hardware environment with Ascend processors. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)

- Install dependencies:

```
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```

- Standalone Training:

```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```

- Distributed Training:

```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```

- Evaluation:

```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

The entire code structure is as follows:

```
|--- CNNCTC/
    |--- README.md                          // descriptions about cnnctc
    |--- train.py                           // train script
    |--- eval.py                            // eval script
    |--- scripts
        |--- run_standalone_train_ascend.sh // shell script for standalone training on Ascend
        |--- run_distribute_train_ascend.sh // shell script for distributed training on Ascend
        |--- run_eval_ascend.sh             // shell script for evaluation on Ascend
    |--- src
        |--- __init__.py                    // init file
        |--- cnn_ctc.py                     // cnn_ctc network
        |--- config.py                      // total config
        |--- callback.py                    // loss callback file
        |--- dataset.py                     // dataset processing
        |--- util.py                        // routine operations
        |--- generate_hccn_file.py          // generate the distributed rank table json file
        |--- preprocess_dataset.py          // preprocess the dataset
```
## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `config.py` (see the override sketch after this list).

Arguments:

* `CHARACTER`: Character labels.
* `NUM_CLASS`: The number of classes, including all character labels and the `<blank>` label for CTCLoss.
* `HIDDEN_SIZE`: Model hidden size.
* `FINAL_FEATURE_WIDTH`: The number of output features (timesteps).
* `IMG_H`: The height of the input image.
* `IMG_W`: The width of the input image.
* `TRAIN_DATASET_PATH`: The path to the training dataset.
* `TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the data order.
* `TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must together ensure that the input data has a fixed shape.
* `TRAIN_DATASET_SIZE`: Training dataset size.
* `TEST_DATASET_PATH`: The path to the test dataset.
* `TEST_BATCH_SIZE`: Test batch size.
* `TEST_DATASET_SIZE`: Test dataset size.
* `TRAIN_EPOCHS`: Total training epochs.
* `CKPT_PATH`: The path to a model checkpoint file; can be used to resume training and for evaluation.
* `SAVE_PATH`: The path for saving model checkpoint files.
* `LR`: Learning rate for standalone training.
* `LR_PARA`: Learning rate for distributed training.
* `MOMENTUM`: Momentum.
* `LOSS_SCALE`: Loss scale to prevent gradient underflow.
* `SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
* `KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
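Note that these are class attributes of `Config_CNNCTC` rather than true command-line flags, so overriding a value is a one-line assignment. A small sketch (the checkpoint path here is hypothetical):

```
from src.config import Config_CNNCTC

cfg = Config_CNNCTC()
cfg.TRAIN_BATCH_SIZE = 96                     # must stay consistent with the index file
cfg.CKPT_PATH = './ckpt/cnnctc_example.ckpt'  # hypothetical checkpoint path
print(cfg.NUM_CLASS)                          # 37 = 36 characters + 1 CTC <blank>
```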
## [Training Process](#contents)

### Training

- Standalone Training:

```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```

Results and checkpoints are written to the `./train` folder. Logs can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

- Distributed Training:

```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```

Results and checkpoints are written to the `./train_parallel_{i}` folder for device `i` respectively.
Logs can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.

`$RANK_TABLE_FILE` is needed when you are running a distributed task on Ascend.
`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.
### Training Result

Training results are stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint files together with results like the following in `loss.log`.

```
# distribute training result(8p)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
```
## [Evaluation Process](#contents)

### Evaluation

- Evaluation:

```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```

The model will be evaluated on the IIIT dataset; sample results and the overall accuracy will be printed.
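Decoding in `eval.py` is greedy: take the argmax class at every timestep, collapse consecutive repeats, and drop the blank. The internals of `CTCLabelConverter` are not shown in this excerpt, so the sketch below is a generic stand-in that assumes the common convention of the blank at index 0 with characters offset by 1:

```
import numpy as np

CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'

def greedy_ctc_decode(logits, blank=0):
    """logits: [T, num_class] scores for one image; returns the decoded string."""
    best = np.argmax(logits, axis=1)
    out, prev = [], blank
    for idx in best:
        if idx != blank and idx != prev:    # collapse repeats, skip blanks
            out.append(CHARACTER[idx - 1])  # non-blank classes assumed offset by 1
        prev = idx
    return ''.join(out)
```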
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                            |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; Memory, 755 GB           |
| Uploaded Date              | 09/28/2020 (month/day/year)                                   |
| MindSpore Version          | 1.0.0                                                         |
| Dataset                    | MJSynth, SynthText                                            |
| Training Parameters        | epoch=3, batch_size=192                                       |
| Optimizer                  | RMSProp                                                       |
| Loss Function              | CTCLoss                                                       |
| Speed                      | 1pc: 300 ms/step; 8pcs: 310 ms/step                           |
| Total time                 | 1pc: 18 hours; 8pcs: 2.3 hours                                |
| Parameters (M)             | 177                                                           |
| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc |
### Evaluation Performance

| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910                  |
| Uploaded Date       | 09/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 192                         |
| outputs             | Accuracy                    |
| Accuracy            | 85%                         |
| Model for inference | 675M (.ckpt file)           |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). The following is a simple example of the steps:

- Running on Ascend

```
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
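For Ascend 310 deployment it is also common to export the trained checkpoint to an offline format. A sketch assuming MindSpore's `export` interface (the variable names follow the example above; the file formats available depend on the MindSpore version):

```
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export

# Reuse the net loaded above; one dummy input fixes the graph's input shape.
dummy = Tensor(np.zeros((1, 3, 32, 100), np.float32))
export(net, dummy, file_name='cnnctc', file_format='AIR')
```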
### Continue Training on the Pretrained Model

- Running on Ascend

```
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_cnnctc", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

+++ b/eval.py
@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""
import argparse
import time

import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset

from src.util import CTCLabelConverter, AverageMeter
from src.config import Config_CNNCTC
from src.dataset import IIIT_Generator_batch
from src.cnn_ctc import CNNCTC_Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    save_graphs_path=".", enable_auto_mixed_precision=False)


def test_dataset_creator():
    ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
    return ds


def test(config):
    ds = test_dataset_creator()

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CKPT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    model_run_time = AverageMeter()
    npu_to_cpu_time = AverageMeter()
    postprocess_time = AverageMeter()

    count = 0
    correct_count = 0
    for data in ds.create_tuple_iterator():
        img, _, text, _, length = data
        img_tensor = Tensor(img, mstype.float32)

        model_run_begin = time.time()
        model_predict = net(img_tensor)
        model_run_end = time.time()
        model_run_time.update(model_run_end - model_run_begin)

        npu_to_cpu_begin = time.time()
        model_predict = np.squeeze(model_predict.asnumpy())
        npu_to_cpu_end = time.time()
        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)

        postprocess_begin = time.time()
        preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
        preds_index = np.argmax(model_predict, 2)
        preds_index = np.reshape(preds_index, [-1])
        preds_str = converter.decode(preds_index, preds_size)
        postprocess_end = time.time()
        postprocess_time.update(postprocess_end - postprocess_begin)

        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

        # Skip the first (warm-up) batch when averaging the timers.
        if count == 0:
            model_run_time.reset()
            npu_to_cpu_time.reset()
            postprocess_time.reset()
        else:
            print('---------model run time--------', model_run_time.avg)
            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
            print('---------postprocess run time--------', postprocess_time.avg)

        print("Prediction samples: \n", preds_str[:5])
        print("Ground truth: \n", label_str[:5])
        for pred, label in zip(preds_str, label_str):
            if pred == label:
                correct_count += 1
            count += 1

    print('accuracy: ', correct_count / count)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CNNCTC evaluation")
    parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
    args_opt = parser.parse_args()

    cfg = Config_CNNCTC()
    if args_opt.ckpt_path != "":
        cfg.CKPT_PATH = args_opt.ckpt_path
    test(cfg)

+++ b/scripts/run_distribute_train_ascend.sh
@@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

current_exec_path=$(pwd)
echo ${current_exec_path}

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2

python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
export RANK_TABLE_FILE=$PATH1
export RANK_SIZE=8

ulimit -u unlimited

for((i=0;i<$RANK_SIZE;i++));
do
    rm -rf ./train_parallel_$i
    mkdir ./train_parallel_$i
    cp ./*.py ./train_parallel_$i
    cp ./scripts/*.sh ./train_parallel_$i
    cp -r ./src ./train_parallel_$i
    cd ./train_parallel_$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    if [ -f $PATH2 ]
    then
        python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
    else
        python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
    fi
    cd .. || exit
done

+++ b/scripts/run_eval_ascend.sh
@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 1 ]
then
    echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
    exit 1
fi

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)
echo $PATH1

if [ ! -f $PATH1 ]
then
    echo "error: TRAINED_CKPT=$PATH1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ./*.py ./eval
cp ./scripts/*.sh ./eval
cp -r ./src ./eval
cd ./eval || exit
echo "start inferring for device $DEVICE_ID"
env > env.log
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
cd .. || exit

+++ b/scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

get_real_path(){
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m $PWD/$1)"
  fi
}

PATH1=$(get_real_path $1)

ulimit -u unlimited
# Use device 0 unless DEVICE_ID is already set in the environment.
export DEVICE_ID=${DEVICE_ID:-0}

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp ./scripts/*.sh ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ -f $PATH1 ]
then
    python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
else
    python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
fi
cd .. || exit

+++ b/src/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""src init file"""

+++ b/src/callback.py
@@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss callback"""
import time

from mindspore.train.callback import Callback

from .util import AverageMeter


class LossCallBack(Callback):
    """
    Monitor the loss and the average step time during training.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """

    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0.")
        self._per_print_times = per_print_times
        self.loss_avg = AverageMeter()
        self.timer = AverageMeter()
        self.start_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num

        # Restart the running averages every 2000 steps within an epoch.
        if cur_step_in_epoch % 2000 == 1:
            self.loss_avg = AverageMeter()
            self.timer = AverageMeter()
            self.start_time = time.time()
        else:
            self.timer.update(time.time() - self.start_time)
            self.start_time = time.time()

        self.loss_avg.update(loss)

        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            message = "epoch: %s step: %s , loss is %s, average time per step is %s" % (
                cb_params.cur_epoch_num, cur_step_in_epoch,
                self.loss_avg.avg, self.timer.avg)
            with open("./loss.log", "a+") as loss_file:
                loss_file.write(message)
                loss_file.write("\n")
            print(message)

+++ b/src/cnn_ctc.py
@@ -0,0 +1,255 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc network define"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal, initializer
import mindspore.common.dtype as mstype


class CNNCTC_Model(nn.Cell):
    """CNN+CTC model: a ResNet feature extractor followed by a per-timestep dense classifier."""

    def __init__(self, num_class, hidden_size, final_feature_width):
        super(CNNCTC_Model, self).__init__()

        self.num_class = num_class
        self.hidden_size = hidden_size
        self.final_feature_width = final_feature_width

        self.FeatureExtraction = ResNet_FeatureExtractor()
        self.Prediction = nn.Dense(self.hidden_size, self.num_class)

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.FeatureExtraction(x)
        x = self.transpose(x, (0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        x = self.reshape(x, (-1, self.hidden_size))
        x = self.Prediction(x)
        x = self.reshape(x, (-1, self.final_feature_width, self.num_class))
        return x


class WithLossCell(nn.Cell):
    """Wraps the backbone with the CTC loss for training."""

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, img, label_indices, text, sequence_length):
        model_predict = self._backbone(img)
        return self._loss_fn(model_predict, label_indices, text, sequence_length)

    @property
    def backbone_network(self):
        return self._backbone


class ctc_loss(nn.Cell):
    """Mean CTC loss over the batch."""

    def __init__(self):
        super(ctc_loss, self).__init__()

        self.loss = P.CTCLoss(preprocess_collapse_repeated=False,
                              ctc_merge_repeated=True,
                              ignore_longer_outputs_than_inputs=False)

        self.mean = P.ReduceMean()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, inputs, labels_indices, labels_values, sequence_length):
        inputs = self.transpose(inputs, (1, 0, 2))  # CTCLoss expects [T, b, num_class]
        loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length)
        loss = self.mean(loss)
        return loss


class ResNet_FeatureExtractor(nn.Cell):
    def __init__(self):
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(3, 512, BasicBlock, [1, 2, 5, 3])

    def construct(self, featuremap):
        return self.ConvNet(featuremap)


class ResNet(nn.Cell):
    def __init__(self, input_channel, output_channel, block, layers):
        super(ResNet, self).__init__()

        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad')
        self.bn0_1 = ms_fused_bn(int(output_channel / 16))
        self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad')
        self.bn0_2 = ms_fused_bn(self.inplanes)
        self.relu = P.ReLU()

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1,
                                pad_mode='pad')
        self.bn1 = ms_fused_bn(self.output_channel_block[0])

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1])
        self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1,
                                pad_mode='pad')
        self.bn2 = ms_fused_bn(self.output_channel_block[1])

        self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (1, 1)))
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid')
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2])
        self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1,
                                pad_mode='pad')
        self.bn3 = ms_fused_bn(self.output_channel_block[2])

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3])
        self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1),
                                  pad_mode='valid')
        self.bn4_1 = ms_fused_bn(self.output_channel_block[3])
        self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0,
                                  pad_mode='valid')
        self.bn4_2 = ms_fused_bn(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.SequentialCell(
                [ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride),
                 ms_fused_bn(planes * block.expansion)]
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.pad(x)
        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.layer4(x)
        x = self.pad(x)
        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)
        return x


class BasicBlock(nn.Cell):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad')
        self.bn1 = ms_fused_bn(planes)
        self.conv2 = ms_conv3x3(planes, planes, stride=stride, padding=1, pad_mode='pad')
        self.bn2 = ms_fused_bn(planes)
        self.relu = P.ReLU()
        self.downsample = downsample
        self.add = P.TensorAdd()

    def construct(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out = self.add(out, residual)
        out = self.relu(out)

        return out


def weight_variable(shape, factor=0.1, half_precision=False):
    if half_precision:
        return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16)
    return TruncatedNormal(0.02)


def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 3x3 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 3, 3))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 1x1 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 1, 1))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 2x2 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 2, 2))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_fused_bn(channels, momentum=0.1):
    """Get a fused batchnorm."""
    return nn.BatchNorm2d(channels, momentum=momentum)

+++ b/src/config.py
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network config setting, will be used in train.py and eval.py"""


class Config_CNNCTC():
    # model config
    CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
    NUM_CLASS = len(CHARACTER) + 1
    HIDDEN_SIZE = 512
    FINAL_FEATURE_WIDTH = 26

    # dataset config
    IMG_H = 32
    IMG_W = 100
    TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
    TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
    TRAIN_BATCH_SIZE = 192
    TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
    TEST_BATCH_SIZE = 256
    TEST_DATASET_SIZE = 2976
    TRAIN_EPOCHS = 3

    # training config
    CKPT_PATH = ''
    SAVE_PATH = './'
    LR = 1e-4
    LR_PARA = 5e-4
    MOMENTUM = 0.8
    LOSS_SCALE = 8096
    SAVE_CKPT_PER_N_STEP = 2000
    KEEP_CKPT_MAX_NUM = 5

+++ b/src/dataset.py
@@ -0,0 +1,265 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc dataset"""
import sys
import pickle
import math

import six
import numpy as np
from PIL import Image
import lmdb
from mindspore.communication.management import get_rank, get_group_size

from .util import CTCLabelConverter
from .config import Config_CNNCTC

config = Config_CNNCTC()


class NormalizePAD():
    def __init__(self, max_size, PAD_type='right'):
        self.max_size = max_size
        self.PAD_type = PAD_type

    def __call__(self, img):
        # to tensor layout: HWC uint8 -> CHW float32 in [0, 1]
        img = np.array(img, dtype=np.float32)
        img = img.transpose([2, 0, 1])
        img = np.true_divide(img, 255)
        # normalize to [-1, 1]
        img = np.subtract(img, 0.5)
        img = np.true_divide(img, 0.5)
        # right-pad the width, repeating the last column as a border pad
        _, _, w = img.shape
        Pad_img = np.zeros(shape=self.max_size, dtype=np.float32)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border Pad
            Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w))
        return Pad_img


class AlignCollate():
    def __init__(self, imgH=32, imgW=100):
        self.imgH = imgH
        self.imgW = imgW

    def __call__(self, images):
        resized_max_w = self.imgW
        input_channel = 3
        transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

        resized_images = []
        for image in images:
            w, h = image.size
            ratio = w / float(h)
            if math.ceil(self.imgH * ratio) > self.imgW:
                resized_w = self.imgW
            else:
                resized_w = math.ceil(self.imgH * ratio)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
            resized_images.append(transform(resized_image))

        image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0)
        return image_tensors


def get_img_from_lmdb(env, index):
    with env.begin(write=False) as txn:
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key).decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)

        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')  # for color image
        except IOError:
            print(f'Corrupted image for {index}')
            # make dummy image and dummy label for corrupted image.
            img = Image.new('RGB', (config.IMG_W, config.IMG_H))
            label = '[dummy_label]'

    label = label.lower()
    return img, label


class ST_MJ_Generator_batch_fixed_length:
    def __init__(self):
        self.align_collector = AlignCollate()
        self.converter = CTCLabelConverter(config.CHARACTER)
        self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
                             meminit=False)
        if not self.env:
            print('cannot open lmdb from %s' % (config.TRAIN_DATASET_PATH))
            raise ValueError(config.TRAIN_DATASET_PATH)

        with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
            self.st_mj_filtered_index_list = pickle.load(f)

        print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
        self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE
        self.batch_size = config.TRAIN_BATCH_SIZE

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        img_ret = []
        text_ret = []

        for i in range(item * self.batch_size, (item + 1) * self.batch_size):
            index = self.st_mj_filtered_index_list[i]
            img, label = get_img_from_lmdb(self.env, index)
            img_ret.append(img)
            text_ret.append(label)

        img_ret = self.align_collector(img_ret)
        text_ret, length = self.converter.encode(text_ret)

        label_indices = []
        for i, _ in enumerate(length):
            for j in range(length[i]):
                label_indices.append((i, j))
        label_indices = np.array(label_indices, np.int64)
        sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
        text_ret = text_ret.astype(np.int32)

        return img_ret, label_indices, text_ret, sequence_length


class ST_MJ_Generator_batch_fixed_length_para:
    def __init__(self):
        self.align_collector = AlignCollate()
        self.converter = CTCLabelConverter(config.CHARACTER)
        self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
                             meminit=False)
        if not self.env:
            print('cannot open lmdb from %s' % (config.TRAIN_DATASET_PATH))
            raise ValueError(config.TRAIN_DATASET_PATH)

        with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
            self.st_mj_filtered_index_list = pickle.load(f)

        print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
        self.rank_id = get_rank()
        self.rank_size = get_group_size()
        self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size
        self.batch_size = config.TRAIN_BATCH_SIZE

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        img_ret = []
        text_ret = []

        rank_item = (item * self.rank_size) + self.rank_id
        for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size):
            index = self.st_mj_filtered_index_list[i]
            img, label = get_img_from_lmdb(self.env, index)
            img_ret.append(img)
            text_ret.append(label)

        img_ret = self.align_collector(img_ret)
        text_ret, length = self.converter.encode(text_ret)

        label_indices = []
        for i, _ in enumerate(length):
            for j in range(length[i]):
                label_indices.append((i, j))
        label_indices = np.array(label_indices, np.int64)
        sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
        text_ret = text_ret.astype(np.int32)

        return img_ret, label_indices, text_ret, sequence_length


def IIIT_Generator_batch():
    max_len = int((26 + 1) // 2)

    align_collector = AlignCollate()
    converter = CTCLabelConverter(config.CHARACTER)

    env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
    if not env:
        print('cannot open lmdb from %s' % (config.TEST_DATASET_PATH))
        sys.exit(1)

    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))

        # Filtering: drop labels that are too long or contain characters
        # outside the configured charset.
        filtered_index_list = []
        for index in range(nSamples):
            index += 1  # lmdb starts with 1
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')

            if len(label) > max_len:
                continue
            illegal_sample = False
            for char_item in label.lower():
                if char_item not in config.CHARACTER:
                    illegal_sample = True
                    break
            if illegal_sample:
                continue

            filtered_index_list.append(index)

    img_ret = []
    text_ret = []
    print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')

    for index in filtered_index_list:
        img, label = get_img_from_lmdb(env, index)
        img_ret.append(img)
        text_ret.append(label)

        if len(img_ret) == config.TEST_BATCH_SIZE:
            img_ret = align_collector(img_ret)
            text_ret, length = converter.encode(text_ret)

            label_indices = []
            for i, _ in enumerate(length):
                for j in range(length[i]):
                    label_indices.append((i, j))
            label_indices = np.array(label_indices, np.int64)
            sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32)
            text_ret = text_ret.astype(np.int32)

            yield img_ret, label_indices, text_ret, sequence_length, length

            img_ret = []
            text_ret = []

+++ b/src/generate_hccn_file.py
@@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate ascend rank file"""
import os
import json
import socket
import argparse

parser = argparse.ArgumentParser(description="ascend distribute rank.")
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank table file path.")


def main(rank_table_file):
    nproc_per_node = 8

    visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']

    server_id = socket.gethostbyname(socket.gethostname())

    # Read the device IPs from the hccn configuration.
    with open('/etc/hccn.conf', 'r') as fin:
        hccn_configs = fin.readlines()

    device_ips = {}
    for hccn_item in hccn_configs:
        hccn_item = hccn_item.strip()
        if hccn_item.startswith('address_'):
            device_id, device_ip = hccn_item.split('=')
            device_id = device_id.split('_')[1]
            device_ips[device_id] = device_ip
            print('device_id:{}, device_ip:{}'.format(device_id, device_ip))

    hccn_table = {}
    hccn_table['board_id'] = '0x002f'  # A+K
    # hccn_table['board_id'] = '0x0000'  # A+X
    hccn_table['chip_info'] = '910'
    hccn_table['deploy_mode'] = 'lab'
    hccn_table['group_count'] = '1'
    hccn_table['group_list'] = []

    instance_list = []
    for instance_id in range(nproc_per_node):
        instance = {}
        instance['devices'] = []
        device_id = visible_devices[instance_id]
        device_ip = device_ips[device_id]
        instance['devices'].append({
            'device_id': device_id,
            'device_ip': device_ip,
        })
        instance['rank_id'] = str(instance_id)
        instance['server_id'] = server_id
        instance_list.append(instance)

    hccn_table['group_list'].append({
        'device_num': str(nproc_per_node),
        'server_num': '1',
        'group_name': '',
        'instance_count': str(nproc_per_node),
        'instance_list': instance_list,
    })
    hccn_table['para_plane_nic_location'] = 'device'
    hccn_table['para_plane_nic_name'] = []
    for instance_id in range(nproc_per_node):
        eth_id = visible_devices[instance_id]
        hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
    hccn_table['para_plane_nic_num'] = str(nproc_per_node)
    hccn_table['status'] = 'completed'

    with open(rank_table_file, 'w') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)


if __name__ == '__main__':
    args_opt = parser.parse_args()
    rank_table = args_opt.rank_file
    if os.path.exists(rank_table):
        print('Rank table file exists.')
    else:
        print('Generating rank table file.')
        main(rank_table)
        print('Rank table file generated')

+++ b/src/preprocess_dataset.py
@@ -0,0 +1,171 @@
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """preprocess dataset""" | |||
| import random | |||
| import pickle | |||
| import numpy as np | |||
| import lmdb | |||
| from tqdm import tqdm | |||
| def combine_lmdbs(lmdb_paths, lmdb_save_path): | |||
| max_len = int((26 + 1) // 2) | |||
| character = '0123456789abcdefghijklmnopqrstuvwxyz' | |||
| env_save = lmdb.open( | |||
| lmdb_save_path, | |||
| map_size=1099511627776) | |||
| cnt = 0 | |||
| for lmdb_path in lmdb_paths: | |||
| env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||
| with env.begin(write=False) as txn: | |||
| nSamples = int(txn.get('num-samples'.encode())) | |||
| nSamples = nSamples | |||
| # Filtering | |||
| for index in tqdm(range(nSamples)): | |||
| index += 1 # lmdb starts with 1 | |||
| label_key = 'label-%09d'.encode() % index | |||
| label = txn.get(label_key).decode('utf-8') | |||
| if len(label) > max_len: | |||
| continue | |||
| illegal_sample = False | |||
| for char_item in label.lower(): | |||
| if char_item not in character: | |||
| illegal_sample = True | |||
| break | |||
| if illegal_sample: | |||
| continue | |||
| img_key = 'image-%09d'.encode() % index | |||
| imgbuf = txn.get(img_key) | |||
| with env_save.begin(write=True) as txn_save: | |||
| cnt += 1 | |||
| label_key_save = 'label-%09d'.encode() % cnt | |||
| label_save = label.encode() | |||
| image_key_save = 'image-%09d'.encode() % cnt | |||
| image_save = imgbuf | |||
| txn_save.put(label_key_save, label_save) | |||
| txn_save.put(image_key_save, image_save) | |||
| nSamples = cnt | |||
| with env_save.begin(write=True) as txn_save: | |||
| txn_save.put('num-samples'.encode(), str(nSamples).encode()) | |||
| def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000): | |||
| label_length_dict = {} | |||
| env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||
| with env.begin(write=False) as txn: | |||
| nSamples = int(txn.get('num-samples'.encode())) | |||
| nSamples = nSamples | |||
| for index in tqdm(range(nSamples)): | |||
| index += 1 # lmdb starts with 1 | |||
| label_key = 'label-%09d'.encode() % index | |||
| label = txn.get(label_key).decode('utf-8') | |||
| label_length = len(label) | |||
| if label_length in label_length_dict: | |||
| label_length_dict[label_length] += 1 | |||
| else: | |||
| label_length_dict[label_length] = 1 | |||
| sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True) | |||
| label_length_sum = 0 | |||
| label_num = 0 | |||
| lengths = [] | |||
| p = [] | |||
| for l, num in sorted_label_length: | |||
| label_length_sum += l * num | |||
| label_num += num | |||
| p.append(num) | |||
| lengths.append(l) | |||
| for i, _ in enumerate(p): | |||
| p[i] /= label_num | |||
| average_overall_length = int(label_length_sum / label_num * batch_size) | |||
| def get_combinations_of_fix_length(fix_length, items, p, batch_size): | |||
| ret = np.random.choice(items, batch_size - 1, True, p) | |||
| cur_sum = sum(ret) | |||
| ret = list(ret) | |||
| if fix_length - cur_sum in items: | |||
| ret.append(fix_length - cur_sum) | |||
| else: | |||
| return None | |||
| return ret | |||
| result = [] | |||
| while len(result) < num_of_combinations: | |||
| ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size) | |||
| if ret is not None: | |||
| result.append(ret) | |||
| return result | |||
| def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000): | |||
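| """Sample num_of_iters batches of lmdb indices whose label lengths follow randomly chosen combinations, then pickle the flat list.""" | |||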
| length_index_dict = {} | |||
| env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) | |||
| with env.begin(write=False) as txn: | |||
| nSamples = int(txn.get('num-samples'.encode())) | |||
| for index in tqdm(range(nSamples)): | |||
| index += 1 # lmdb starts with 1 | |||
| label_key = 'label-%09d'.encode() % index | |||
| label = txn.get(label_key).decode('utf-8') | |||
| label_length = len(label) | |||
| if label_length in length_index_dict: | |||
| length_index_dict[label_length].append(index) | |||
| else: | |||
| length_index_dict[label_length] = [index] | |||
| ret = [] | |||
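| # for each iteration, pick a random length combination and a random sample index for every required length | |||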
| for _ in range(num_of_iters): | |||
| comb = random.choice(combinations) | |||
| for l in comb: | |||
| ret.append(random.choice(length_index_dict[l])) | |||
| with open(pkl_save_path, 'wb') as f: | |||
| pickle.dump(ret, f, -1) | |||
| if __name__ == '__main__': | |||
| # step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file | |||
| print('Begin to combine multiple lmdb datasets') | |||
| combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/', | |||
| '/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'], | |||
| '/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') | |||
| # step 2: generate the order of input data, guaranteeing that each input batch has a fixed shape | |||
| print('Begin to generate the index order of input data') | |||
| combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') | |||
| generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination, | |||
| '/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl') | |||
| print('Done') | |||
| @@ -0,0 +1,102 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """util file""" | |||
| import numpy as np | |||
| class AverageMeter(): | |||
| """Computes and stores the average and current value""" | |||
| def __init__(self): | |||
| self.reset() | |||
| def reset(self): | |||
| self.val = 0 | |||
| self.avg = 0 | |||
| self.sum = 0 | |||
| self.count = 0 | |||
| def update(self, val, n=1): | |||
| self.val = val | |||
| self.sum += val * n | |||
| self.count += n | |||
| self.avg = self.sum / self.count | |||
| class CTCLabelConverter(): | |||
| """ Convert between text-label and text-index """ | |||
| def __init__(self, character): | |||
| # character (str): set of the possible characters. | |||
| dict_character = list(character) | |||
| self.dict = {} | |||
| for i, char in enumerate(dict_character): | |||
| self.dict[char] = i | |||
| self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss, stored at the last index | |||
| self.dict['[blank]'] = len(dict_character) | |||
| def encode(self, text): | |||
| """convert text-label into text-index. | |||
| input: | |||
| text: text labels of each image. [batch_size] | |||
| output: | |||
| text: concatenated text index for CTCLoss. | |||
| [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] | |||
| length: length of each text. [batch_size] | |||
| """ | |||
| length = [len(s) for s in text] | |||
| text = ''.join(text) | |||
| text = [self.dict[char] for char in text] | |||
| return np.array(text), np.array(length) | |||
| def decode(self, text_index, length): | |||
| """ convert text-index into text-label. """ | |||
| texts = [] | |||
| index = 0 | |||
| for l in length: | |||
| t = text_index[index:index + l] | |||
| char_list = [] | |||
| for i in range(l): | |||
| if t[i] != self.dict['[blank]'] and ( | |||
| not (i > 0 and t[i - 1] == t[i])): # skip blanks and collapse repeated characters | |||
| char_list.append(self.character[t[i]]) | |||
| text = ''.join(char_list) | |||
| texts.append(text) | |||
| index += l | |||
| return texts | |||
| def reverse_encode(self, text_index, length): | |||
| """ convert text-index into text-label. """ | |||
| texts = [] | |||
| index = 0 | |||
| for l in length: | |||
| t = text_index[index:index + l] | |||
| char_list = [] | |||
| for i in range(l): | |||
| if t[i] != self.dict['[blank]']: # remove blanks only; repeated characters are kept | |||
| char_list.append(self.character[t[i]]) | |||
| text = ''.join(char_list) | |||
| texts.append(text) | |||
| index += l | |||
| return texts | |||
| @@ -0,0 +1,100 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """cnnctc train""" | |||
| import argparse | |||
| import ast | |||
| import mindspore | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.dataset import GeneratorDataset | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train.model import Model | |||
| from mindspore.communication.management import init | |||
| from mindspore.common import set_seed | |||
| from src.config import Config_CNNCTC | |||
| from src.callback import LossCallBack | |||
| from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para | |||
| from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell | |||
| set_seed(1) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, | |||
| save_graphs_path=".", enable_auto_mixed_precision=False) | |||
| def dataset_creator(run_distribute): | |||
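| """Build the ST+MJ GeneratorDataset, using the parallel generator variant when running distributed.""" | |||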
| if run_distribute: | |||
| st_dataset = ST_MJ_Generator_batch_fixed_length_para() | |||
| else: | |||
| st_dataset = ST_MJ_Generator_batch_fixed_length() | |||
| ds = GeneratorDataset(st_dataset, | |||
| ['img', 'label_indices', 'text', 'sequence_length'], | |||
| num_parallel_workers=8) | |||
| return ds | |||
| def train(args_opt, config): | |||
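| """Build the CNN+CTC network, optionally load a pretrained checkpoint, and train it with RMSProp under a fixed loss scale.""" | |||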
| if args_opt.run_distribute: | |||
| init() | |||
| context.set_auto_parallel_context(parallel_mode="data_parallel") | |||
| ds = dataset_creator(args_opt.run_distribute) | |||
| net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH) | |||
| net.set_train(True) | |||
| if config.CKPT_PATH != '': | |||
| param_dict = load_checkpoint(config.CKPT_PATH) | |||
| load_param_into_net(net, param_dict) | |||
| print('parameters loaded!') | |||
| else: | |||
| print('train from scratch...') | |||
| criterion = ctc_loss() | |||
| opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA, | |||
| momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE) | |||
| net = WithLossCell(net, criterion) | |||
| loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False) | |||
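| # amp_level "O2" enables mixed precision; the fixed loss scale above guards float16 gradients against underflow | |||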
| model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2") | |||
| callback = LossCallBack() | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP, | |||
| keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM) | |||
| ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH) | |||
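| # only device 0 saves checkpoints so distributed runs do not write duplicate files | |||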
| if args_opt.device_id == 0: | |||
| model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False) | |||
| else: | |||
| model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False) | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser(description='CNNCTC arg') | |||
| parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.") | |||
| parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.") | |||
| parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, | |||
| help="Run distribute, default is false.") | |||
| args_cfg = parser.parse_args() | |||
| cfg = Config_CNNCTC() | |||
| if args_cfg.ckpt_path != "": | |||
| cfg.CKPT_PATH = args_cfg.ckpt_path | |||
| train(args_cfg, cfg) | |||
| @@ -11,7 +11,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" :=========================================================================== | |||
| # =========================================================================== | |||
| """ | |||
| network config setting, will be used in train.py and eval.py | |||
| """ | |||