@@ -0,0 +1,196 @@ README.md
# Contents

- [Contents](#contents)
- [CRNN-Seq2Seq-OCR Description](#crnn-seq2seq-ocr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
        - [Training Script Parameters](#training-script-parameters)
        - [Parameters Configuration](#parameters-configuration)
    - [Dataset Preparation](#dataset-preparation)
- [Training Process](#training-process)
    - [Training](#training)
    - [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
    - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
## [CRNN-Seq2Seq-OCR Description](#contents)

CRNN-Seq2Seq-OCR is a neural network model for image-based sequence recognition tasks such as scene text recognition and optical character recognition (OCR). Its architecture combines a CNN with a sequence-to-sequence model that uses an attention mechanism.
## [Model Architecture](#contents)

CRNN-Seq2Seq-OCR uses a VGG-style network to extract features from the processed images, followed by an attention-based encoder and decoder, and computes the loss with negative log-likelihood (NLL). See src/attention_ocr.py for details.
## [Dataset](#contents)

For training and evaluation, we use the French Street Name Signs (FSNS) dataset released by Google, which contains approximately 1 million training images and their corresponding ground-truth words.
## [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You will be able to have access to related resources once approved.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Quick Start](#contents)

- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:
- Running on Ascend

```shell
# distributed training example on Ascend
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]

# evaluation example on Ascend
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]

# standalone training example on Ascend
bash run_standalone_train.sh [DATASET_PATH]
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance. Please follow the instructions in the link below: [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
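For example, a rank table for all 8 devices on the local host can be generated with the hccl_tools script (exact flags may differ across MindSpore versions; treat this as a sketch):

```shell
python hccl_tools.py --device_num "[0,8)"
```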
## [Script Description](#contents)

### [Script and Sample Code](#contents)

```shell
crnn-seq2seq-ocr
├── README.md                          # Descriptions of CRNN-Seq2Seq-OCR
├── eval.py                            # Evaluation script
├── train.py                           # Training script
├── scripts
│   ├── run_distribute_train.sh        # Launch distributed training on Ascend (8 pcs)
│   ├── run_eval_ascend.sh             # Launch Ascend evaluation
│   └── run_standalone_train.sh        # Launch standalone training on Ascend (1 pcs)
└── src
    ├── attention_ocr.py               # CRNN-Seq2Seq-OCR training wrapper
    ├── cnn.py                         # VGG network
    ├── config.py                      # Parameter configuration
    ├── create_mindrecord_files.py     # Create MindRecord files from images and ground truth
    ├── dataset.py                     # Data preprocessing for training and evaluation
    ├── gru.py                         # GRU cell wrapper
    ├── logger.py                      # Logger configuration
    ├── lstm.py                        # LSTM cell wrapper
    ├── seq2seq.py                     # CRNN-Seq2Seq-OCR model structure
    ├── utils.py                       # Utility functions for training and data preprocessing
    └── weight_init.py                 # Weight initialization of LSTM and GRU
```
### [Script Parameters](#contents)

#### Training Script Parameters

```shell
# distributed training on Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]

# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH]
```
#### Parameters Configuration

Parameters for both training and evaluation can be set in `src/config.py`.
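For example, a hyperparameter can be overridden at the top of a custom script instead of editing the file; a minimal sketch using keys that exist in `src/config.py`:

```python
from src.config import config

config.batch_size = 64   # default: 32
config.lr = 0.0004       # default: 0.0008
config.num_epochs = 30   # default: 20
```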
### [Dataset Preparation](#contents)

- Use `src/create_mindrecord_files.py` to convert the FSNS images and annotations into MindRecord files (the input paths are set in `src/config.py`), or prepare a text-image dataset of your own in the same format.
## [Training Process](#contents)

- Set options in `config.py`, including the learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about datasets.
### [Training](#contents)

- Run `run_standalone_train.sh` for non-distributed training of the CRNN-Seq2Seq-OCR model; only Ascend is supported for now.

```shell
bash run_standalone_train.sh [DATASET_PATH]
```
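For example, assuming the training MindRecord file was generated under `/data/fsns/train/` (hypothetical path):

```shell
bash run_standalone_train.sh /data/fsns/train/fsns.mindrecord0
```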
### [Distributed Training](#contents)

- Run `run_distribute_train.sh` for distributed training of the CRNN-Seq2Seq-OCR model on Ascend.

```shell
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
```

Check `train_parallel0/log.txt` and you will get outputs like the following:

```shell
epoch: 20 step: 4080, loss is 1.56112
epoch: 20 step: 4081, loss is 1.6368448
epoch time: 1559886.096 ms, per step time: 382.231 ms
```
## [Evaluation Process](#contents)

### [Evaluation](#contents)

- Run `run_eval_ascend.sh` for evaluation on Ascend.

```shell
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
```

Check `eval/log` and you will get outputs like the following:

```shell
character precision = 0.967522
Annotation precision = 0.635204
```
## [Model Description](#contents)

### [Performance](#contents)

#### Training Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory, 755 GB          |
| Uploaded Date              | 02/11/2021 (month/day/year)                                  |
| MindSpore Version          | 1.2.0                                                        |
| Dataset                    | FSNS                                                         |
| Training Parameters        | epoch=20, batch_size=32                                      |
| Optimizer                  | SGD                                                          |
| Loss Function              | Negative Log-Likelihood                                      |
| Speed                      | 1 pc: 355 ms/step; 8 pcs: 385 ms/step                        |
| Total time                 | 1 pc: 64 hours; 8 pcs: 9 hours                               |
| Parameters (M)             | 12                                                           |
| Scripts                    | [crnn_seq2seq_ocr script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn_seq2seq_ocr) |
#### Evaluation Performance

| Parameters          | Ascend                                                       |
| ------------------- | ------------------------------------------------------------ |
| Model Version       | V1                                                           |
| Resource            | Ascend 910                                                   |
| Uploaded Date       | 02/11/2021 (month/day/year)                                  |
| MindSpore Version   | 1.2.0                                                        |
| Dataset             | FSNS                                                         |
| batch_size          | 32                                                           |
| outputs             | Annotation Precision, Character Precision                    |
| Accuracy            | Annotation Precision = 63.52%, Character Precision = 96.75%  |
| Model for inference | 12M (.ckpt file)                                             |
@@ -0,0 +1,181 @@ eval.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR Evaluation.
"""
import os
import codecs
import argparse

import numpy as np

import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer

set_seed(1)
def text_standardization(text_in):
    """
    Normalize whitespace and replace some full-width characters
    with their half-width (ASCII) counterparts.
    """
    stand_text = text_in.strip()
    stand_text = ' '.join(stand_text.split())
    stand_text = stand_text.replace(u'（', u'(')
    stand_text = stand_text.replace(u'）', u')')
    stand_text = stand_text.replace(u'：', u':')
    return stand_text
def LCS_length(str1, str2):
    """
    Calculate the length of the longest common subsequence of str1 and str2.
    Only two rows of the DP table are kept to save memory.
    """
    if str1 is None or str2 is None:
        return 0
    len1 = len(str1)
    len2 = len(str2)
    if len1 == 0 or len2 == 0:
        return 0
    lcs = [[0 for _ in range(len2 + 1)] for _ in range(2)]
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            if str1[i - 1] == str2[j - 1]:
                lcs[i % 2][j] = lcs[(i - 1) % 2][j - 1] + 1
            else:
                if lcs[i % 2][j - 1] >= lcs[(i - 1) % 2][j]:
                    lcs[i % 2][j] = lcs[i % 2][j - 1]
                else:
                    lcs[i % 2][j] = lcs[(i - 1) % 2][j]
    return lcs[len1 % 2][-1]
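# For example, LCS_length("abcde", "ace") returns 3 (the longest common
# subsequence is "ace"), and LCS_length("abc", "") returns 0.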
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Test dataset path.")
    parser.add_argument("--checkpoint_path", type=str, default=None,
                        help="Checkpoint of AttentionOCR (default: None).")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device where the code will be run, default is Ascend.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

    prefix = "fsns.mindrecord"
    mindrecord_dir = args.dataset_path
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("mindrecord_file", mindrecord_file)
    dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
    data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    print("Dataset creation Done!")

    # Network
    network = AttentionOCRInfer(config.eval_batch_size,
                                int(config.img_width / 4),
                                config.encoder_hidden_size,
                                config.decoder_hidden_size,
                                config.decoder_output_size,
                                config.max_length,
                                config.dropout_p)
    ckpt = load_checkpoint(args.checkpoint_path)
    load_param_into_net(network, ckpt)
    network.set_train(False)
    print("Checkpoint loading Done!")

    vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
    eos_id = config.characters_dictionary.get("eos_id")
    sos_id = config.characters_dictionary.get("go_id")

    num_correct_char = 0
    num_total_char = 0
    num_correct_word = 0
    num_total_word = 0

    correct_file = 'result_correct.txt'
    incorrect_file = 'result_incorrect.txt'

    with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
            codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
        for data in data_loader:
            images = Tensor(data["image"])
            decoder_inputs = Tensor(data["decoder_input"])
            decoder_targets = Tensor(data["decoder_target"])

            decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
                                             dtype=np.float16), mstype.float16)
            decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
            encoder_outputs = network.encoder(images)
            batch_decoded_label = []

            for di in range(decoder_inputs.shape[1]):
                decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
                topi = P.Argmax()(decoder_output)
                ni = P.ExpandDims()(topi, 1)
                decoder_input = ni
                topi_id = topi.asnumpy()
                batch_decoded_label.append(topi_id)

            for b in range(config.eval_batch_size):
                text = data["annotation"][b].decode("utf8")
                text = text_standardization(text)
                decoded_label = list(np.array(batch_decoded_label)[:, b])
                decoded_words = []
                for idx in decoded_label:
                    if idx == eos_id:
                        break
                    else:
                        decoded_words.append(rev_vocab[idx])
                predict = text_standardization("".join(decoded_words))

                if predict == text:
                    num_correct_word += 1
                    fp_output_correct.write('\t\t' + text + '\n')
                    fp_output_correct.write('\t\t' + predict + '\n\n')
                    print('correctly predicted : pred: {}, gt: {}'.format(predict, text))
                else:
                    fp_output_incorrect.write('\t\t' + text + '\n')
                    fp_output_incorrect.write('\t\t' + predict + '\n\n')
                    print('incorrectly predicted : pred: {}, gt: {}'.format(predict, text))

                num_total_word += 1
                num_correct_char += 2 * LCS_length(text, predict)
                num_total_char += len(text) + len(predict)
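                # Character precision aggregates 2 * LCS(gt, pred) over
                # len(gt) + len(pred); e.g. gt="rue", pred="ruelle" contributes
                # 2 * 3 = 6 correct characters out of 3 + 6 = 9 total.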
    print('\nnum of correct characters = %d' % num_correct_char)
    print('\nnum of total characters = %d' % num_total_char)
    print('\nnum of correct words = %d' % num_correct_word)
    print('\nnum of total words = %d' % num_total_word)
    print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
    print('\nAnnotation precision = %f' % (float(num_correct_word) / num_total_word))
@@ -0,0 +1,66 @@ scripts/run_distribute_train.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 2 ]
then
    echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -f $PATH2 ]
then
    echo "error: DATASET_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for((i=0; i<${DEVICE_NUM}; i++))
do
    export RANK_ID=$i
    export DEVICE_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distributed=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
    cd ..
done
@@ -0,0 +1,64 @@ scripts/run_eval_ascend.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]
then
    echo "Usage: bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2

if [ ! -d $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a folder"
    exit 1
fi

if [ ! -f $PATH2 ]
then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..
@@ -0,0 +1,58 @@ scripts/run_standalone_train.sh
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 1 ]
then
    echo "Usage: bash run_standalone_train.sh [DATASET_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
    echo "error: DATASET_PATH=$PATH1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=1
export RANK_ID=0
export RANK_SIZE=1

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
cd ..
@@ -0,0 +1,178 @@ src/attention_ocr.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR model.
"""
import numpy as np

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context

from src.seq2seq import Encoder, Decoder
class NLLLoss(_Loss):
    """Negative log-likelihood loss over log-probability logits."""
    def __init__(self, reduction='mean'):
        super(NLLLoss, self).__init__(reduction)
        self.one_hot = P.OneHot()
        self.reduce_sum = P.ReduceSum()

    def construct(self, logits, label):
        label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
        loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
        return self.get_loss(loss)
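# A worked example (assumed shapes): for logits of shape (batch, num_classes)
# holding log-probabilities and integer labels of shape (batch,), the per-sample
# loss is -logits[label]; e.g. logits = log([0.7, 0.2, 0.1]) with label 0 gives
# -log(0.7) ~= 0.357, and the reduction then averages these over the batch.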
class AttentionOCRInfer(nn.Cell):
    """Inference network: runs the encoder once, then one decoder step per call."""
    def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
                 decoder_output_size, max_length, dropout_p=0.1):
        super(AttentionOCRInfer, self).__init__()
        self.encoder = Encoder(batch_size=batch_size,
                               conv_out_dim=conv_out_dim,
                               hidden_size=encoder_hidden_size)
        self.decoder = Decoder(hidden_size=decoder_hidden_size,
                               output_size=decoder_output_size,
                               max_length=max_length,
                               dropout_p=dropout_p)

    def construct(self, img, decoder_input, decoder_hidden):
        """Get token output."""
        encoder_outputs = self.encoder(img)
        decoder_output, decoder_hidden, decoder_attention = self.decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        return decoder_output, decoder_hidden, decoder_attention
class AttentionOCR(nn.Cell):
    """Training network: unrolls the decoder for max_length steps with optional teacher forcing."""
    def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
                 decoder_output_size, max_length, dropout_p=0.1):
        super(AttentionOCR, self).__init__()
        self.encoder = Encoder(batch_size=batch_size,
                               conv_out_dim=conv_out_dim,
                               hidden_size=encoder_hidden_size)
        self.decoder = Decoder(hidden_size=decoder_hidden_size,
                               output_size=decoder_output_size,
                               max_length=max_length,
                               dropout_p=dropout_p)
        self.init_decoder_hidden = Tensor(np.zeros((1, batch_size, decoder_hidden_size),
                                                   dtype=np.float16), mstype.float16)
        self.shape = P.Shape()
        self.split = P.Split(axis=1, output_num=max_length)
        self.concat = P.Concat()
        self.expand_dims = P.ExpandDims()
        self.argmax = P.Argmax()
        self.select = P.Select()

    def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
        encoder_outputs = self.encoder(img)
        _, text_len = self.shape(decoder_inputs)
        decoder_outputs = ()
        decoder_input_tuple = self.split(decoder_inputs)
        decoder_target_tuple = self.split(decoder_targets)
        decoder_input = decoder_input_tuple[0]
        decoder_hidden = self.init_decoder_hidden
        for i in range(text_len):
            decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            topi = self.argmax(decoder_output)
            decoder_input_top = self.expand_dims(topi, 1)
            # Teacher forcing: feed the ground-truth token when teacher_force is
            # True, otherwise feed the argmax of the previous prediction.
            decoder_input = self.select(teacher_force, decoder_target_tuple[i], decoder_input_top)
            decoder_output = self.expand_dims(decoder_output, 0)
            decoder_outputs += (decoder_output,)
        outputs = self.concat(decoder_outputs)
        return outputs
class AttentionOCRWithLossCell(nn.Cell):
    """AttentionOCR with loss: sums the per-step NLL losses over the sequence."""
    def __init__(self, network, max_length):
        super(AttentionOCRWithLossCell, self).__init__()
        self.network = network
        self.loss = NLLLoss()
        self.shape = P.Shape()
        self.add = P.AddN()
        self.mean = P.ReduceMean()
        self.split = P.Split(axis=0, output_num=max_length)
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()

    def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
        decoder_outputs = self.network(img, decoder_inputs, decoder_targets, teacher_force)
        decoder_outputs = self.cast(decoder_outputs, mstype.float32)
        _, text_len = self.shape(decoder_targets)
        loss_total = ()
        decoder_output_tuple = self.split(decoder_outputs)
        for i in range(text_len):
            loss = self.loss(self.squeeze(decoder_output_tuple[i]), decoder_targets[:, i])
            loss = self.mean(loss)
            loss_total += (loss,)
        loss_output = self.add(loss_total)
        return loss_output
grad_scale = C.MultitypeFuncGraph("grad_scale")


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * P.Reciprocal()(scale)


class TrainingWrapper(nn.Cell):
    """Wraps the loss network with gradient computation and the optimizer step."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None

        # Set up the gradient reducer for data-parallel training.
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        self.hyper_map = C.HyperMap()

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
@@ -0,0 +1,195 @@ src/cnn.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR CNN model.
"""
import math
import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res
def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, gain_param=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Kaiming normal initialization for a weight of shape inputs_shape."""
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, gain_param)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
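# For example, kaiming_normal((64, 3, 3, 3)) draws a (64, 3, 3, 3) float32 array
# with std = sqrt(2 / (1 + 0.01**2)) / sqrt(3 * 3 * 3): fan_in is
# in_channels * kernel_h * kernel_w = 27, and the default gain is the
# leaky_relu gain with negative_slope 0.01.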
class ConvRelu(nn.Cell):
    """
    Convolution layer followed by a ReLU layer.
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1):
        super(ConvRelu, self).__init__()
        shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size,
                              stride,
                              weight_init=Tensor(kaiming_normal(shape)))
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
class ConvBNRelu(nn.Cell):
    """
    Convolution layer followed by batch normalization and a ReLU layer.
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same'):
        super(ConvBNRelu, self).__init__()
        shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size, stride,
                              pad_mode=pad_mode,
                              weight_init=Tensor(kaiming_normal(shape)))
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class CNN(nn.Cell):
    """
    CNN feature extractor for OCR.
    """
    def __init__(self, conv_out_dim):
        super(CNN, self).__init__()
        self.convRelu1 = ConvRelu(3, 64, (3, 3))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.convRelu2 = ConvRelu(64, 128, (3, 3))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.convBNRelu1 = ConvBNRelu(128, 256, (3, 3))
        self.convRelu3 = ConvRelu(256, 256, (3, 3))
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.convBNRelu2 = ConvBNRelu(256, 384, (3, 3))
        self.convRelu4 = ConvRelu(384, 384, (3, 3))
        self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.convBNRelu3 = ConvBNRelu(384, 384, (3, 3))
        self.convRelu5 = ConvRelu(384, 384, (3, 3))
        self.maxpool5 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.convBNRelu4 = ConvBNRelu(384, 384, (3, 3))
        self.convRelu6 = ConvRelu(384, 384, (3, 3))
        self.maxpool6 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 1)))
        self.convBNRelu5 = ConvBNRelu(384, conv_out_dim, (2, 2), pad_mode='valid')
        self.dropout = nn.Dropout(keep_prob=0.5)
        self.squeeze = P.Squeeze(2)
        self.cast = P.Cast()

    def construct(self, x):
        x = self.convRelu1(x)
        x = self.maxpool1(x)
        x = self.convRelu2(x)
        x = self.maxpool2(x)
        x = self.convBNRelu1(x)
        x = self.convRelu3(x)
        x = self.maxpool3(x)
        x = self.convBNRelu2(x)
        x = self.convRelu4(x)
        x = self.maxpool4(x)
        x = self.convBNRelu3(x)
        x = self.convRelu5(x)
        x = self.maxpool5(x)
        x = self.convBNRelu4(x)
        x = self.convRelu6(x)
        x = self.maxpool6(x)
        x = self.pad(x)
        x = self.convBNRelu5(x)
        x = self.dropout(x)
        # Collapse the height dimension so the output is (N, C, W) for the encoder.
        x = self.squeeze(x)
        return x
@@ -0,0 +1,61 @@ src/config.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for CRNN-Seq2Seq-OCR model."""
from easydict import EasyDict as ed
config = ed({
    # dataset-related
    "mindrecord_dir": "",
    "data_root": "",
    "annotation_file": "",
    "train_annotation_file": "",  # assumed key: referenced by src/create_mindrecord_files.py
    "val_data_root": "",
    "val_annotation_file": "",
    "data_json": "",
    "characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3},
    # Annotations to skip during MindRecord creation. The first entry is garbled
    # in the source; the "不识别" (unrecognizable) entries are reconstructed.
    "labels_not_use": [u'%#�?%', u'%#背景#%', u'%#不识别#%', u'#%不识别#%', u'%#模糊#%', u'%#模糊#%'],
    "print_no_train_label": False,  # assumed key: referenced by src/create_mindrecord_files.py
    "vocab_path": "./general_chars.txt",

    # model-related
    "img_width": 512,
    "img_height": 128,
    "channel_size": 3,
    "conv_out_dim": 384,
    "encoder_hidden_size": 128,
    "decoder_hidden_size": 128,
    "decoder_output_size": 10000,  # equals the vocab size (characters_class + 1); index 9999 is the space
    "dropout_p": 0.1,
    "max_length": 64,
    "max_text_len": 64,  # assumed key/value: maximum serialized label length, referenced by src/create_mindrecord_files.py
    "go_shift": 1,       # assumed key/value: GO-token offset used when building decoder masks
    "attn_num_layers": 1,
    "teacher_force_ratio": 0.5,

    # optimizer-related
    "lr": 0.0008,
    "adam_beta1": 0.5,
    "adam_beta2": 0.999,
    "loss_scale": 1024,

    # train-related
    "batch_size": 32,
    "num_epochs": 20,
    "keep_checkpoint_max": 20,

    # eval-related
    "eval_batch_size": 32
})
@@ -0,0 +1,245 @@ src/create_mindrecord_files.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create FSNS MindRecord files."""
import os
import numpy as np

from mindspore.mindrecord import FileWriter

from src.config import config
from src.utils import initialize_vocabulary
def serialize_annotation(img_path, lex, vocab):
    """Convert an annotation string into a [GO, char ids..., EOS] int32 array."""
    go_id = config.characters_dictionary.get("go_id")
    eos_id = config.characters_dictionary.get("eos_id")

    word = [go_id]
    for special_label in config.labels_not_use:
        if lex == special_label:
            if config.print_no_train_label:
                print("label for image: %s is a special label, related label is: %s, skip ..." % (img_path, lex))
            return None

    for c in lex:
        if c not in vocab:
            return None
        c_idx = vocab.get(c)
        word.append(c_idx)
    word.append(eos_id)
    word = np.array(word, dtype=np.int32)
    return word
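# For example, with go_id=1 and eos_id=2, serialize_annotation(path, "ab", vocab)
# returns np.array([1, vocab['a'], vocab['b'], 2], dtype=np.int32), or None if
# "ab" matches a skipped label or contains out-of-vocabulary characters.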
def create_fsns_label(image_dir, anno_file_dirs):
    """Get image paths and annotations."""
    if not os.path.isdir(image_dir):
        raise ValueError(f'Cannot find {image_dir} dataset path.')

    image_files_dict = {}
    image_anno_dict = {}
    images = []
    img_id = 0

    for anno_file_dir in anno_file_dirs:
        with open(anno_file_dir, 'r') as f:
            anno_file = f.readlines()
        for line in anno_file:
            file_name = line.split('\t')[0]
            labels = line.split('\t')[1].split('\n')[0]
            image_path = os.path.join(image_dir, file_name)
            if not os.path.isfile(image_path):
                print(f'Cannot find image {image_path} according to annotations.')
                continue
            if labels:
                images.append(img_id)
                image_files_dict[img_id] = image_path
                image_anno_dict[img_id] = labels
                img_id += 1
    return images, image_files_dict, image_anno_dict
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
    """Create MindRecord files for the training pipeline."""
    anno_file_dirs = [config.train_annotation_file]
    images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
                                                                 anno_file_dirs=anno_file_dirs)
    vocab, _ = initialize_vocabulary(config.vocab_path)

    data_schema = {"image": {"type": "bytes"},
                   "label": {"type": "int32", "shape": [-1]},
                   "decoder_input": {"type": "int32", "shape": [-1]},
                   "decoder_mask": {"type": "int32", "shape": [-1]},
                   "decoder_target": {"type": "int32", "shape": [-1]},
                   "annotation": {"type": "string"}}

    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    writer.add_schema(data_schema, "ocr")

    for img_id in images:
        image_path = image_path_dict[img_id]
        annotation = image_anno_dict[img_id]
        label_max_len = config.max_text_len
        text_max_len = config.max_text_len - 2

        if len(annotation) > text_max_len:
            continue

        label = serialize_annotation(image_path, annotation, vocab)
        if label is None:
            continue

        label_len = len(label)
        decoder_input_len = label_max_len

        if label_len <= decoder_input_len:
            # Pad the label to a fixed length and build the target mask.
            label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
            one_mask_len = label_len - config.go_shift
            target_weight = np.concatenate((np.ones(one_mask_len, dtype=np.float32),
                                            np.zeros(decoder_input_len - one_mask_len, dtype=np.float32)))
        else:
            continue

        decoder_input = (np.array(label).T).astype(np.int32)
        target_weight = (np.array(target_weight).T).astype(np.int32)
        if not len(decoder_input) == len(target_weight):
            continue

        # The decoder target is the decoder input shifted left by one token.
        target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
        target = (np.array(target)).astype(np.int32)

        with open(image_path, 'rb') as f:
            img = f.read()

        row = {"image": img,
               "label": label,
               "decoder_input": decoder_input,
               "decoder_mask": target_weight,
               "decoder_target": target,
               "annotation": str(annotation)}
        writer.write_raw_data([row])
    writer.commit()
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
    """Create MindRecord files for the validation pipeline."""
    anno_file_dirs = [config.val_annotation_file]
    images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.val_data_root,
                                                                 anno_file_dirs=anno_file_dirs)
    vocab, _ = initialize_vocabulary(config.vocab_path)

    data_schema = {"image": {"type": "bytes"},
                   "decoder_input": {"type": "int32", "shape": [-1]},
                   "decoder_target": {"type": "int32", "shape": [-1]},
                   "annotation": {"type": "string"}}

    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    writer.add_schema(data_schema, "ocr")

    for img_id in images:
        image_path = image_path_dict[img_id]
        annotation = image_anno_dict[img_id]
        label_max_len = config.max_text_len
        text_max_len = config.max_text_len - 2

        if len(annotation) > text_max_len:
            continue

        label = serialize_annotation(image_path, annotation, vocab)
        if label is None:
            continue

        label_len = len(label)
        decoder_input_len = label_max_len

        if label_len <= decoder_input_len:
            label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
        else:
            continue

        decoder_input = (np.array(label).T).astype(np.int32)
        target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
        target = (np.array(target)).astype(np.int32)

        with open(image_path, 'rb') as f:
            img = f.read()

        row = {"image": img,
               "decoder_input": decoder_input,
               "decoder_target": target,
               "annotation": str(annotation)}
        writer.write_raw_data([row])
    writer.commit()
def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True):
    """Create MindRecord files if they do not exist yet and return their paths."""
    print("Start creating dataset!")
    if is_training:
        mindrecord_dir = os.path.join(config.mindrecord_dir, "train")
        mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
        if not os.path.exists(mindrecord_files[0]):
            if not os.path.isdir(mindrecord_dir):
                os.makedirs(mindrecord_dir)
            if dataset == "fsns":
                if os.path.isdir(config.data_root):
                    print("Create FSNS MindRecord files for the train pipeline.")
                    fsns_train_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix, file_num=8)
                    print("Create FSNS MindRecord files for the train pipeline done, at {}".format(mindrecord_dir))
                else:
                    print("{} does not exist!".format(config.data_root))
            else:
                print("{} dataset is not defined!".format(dataset))
    if not is_training:
        mindrecord_dir = os.path.join(config.mindrecord_dir, "val")
        mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
        if not os.path.exists(mindrecord_files[0]):
            if not os.path.isdir(mindrecord_dir):
                os.makedirs(mindrecord_dir)
            if dataset == "fsns":
                if os.path.isdir(config.val_data_root):
                    print("Create FSNS MindRecord files for the val pipeline.")
                    fsns_val_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix)
                    print("Create FSNS MindRecord files for the val pipeline done, at {}".format(mindrecord_dir))
                else:
                    print("{} does not exist!".format(config.val_data_root))
            else:
                print("{} dataset is not defined!".format(dataset))
    return mindrecord_files
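# Example usage (paths are taken from src/config.py):
#   mindrecord_files = create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True)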
@@ -0,0 +1,144 @@ src/dataset.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FSNS dataset."""
import cv2
import numpy as np
from PIL import Image

import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.vision.py_transforms as P
import mindspore.dataset.transforms.c_transforms as ops
import mindspore.common.dtype as mstype

from src.config import config
class AugmentationOps():
    """Random crop, resize, and color-distortion augmentation for the four FSNS tiles."""
    def __init__(self, min_area_ratio=0.8, aspect_ratio_range=(0.8, 1.2), brightness=32./255.,
                 contrast=0.5, saturation=0.5, hue=0.2, img_tile_shape=(150, 150)):
        self.min_area_ratio = min_area_ratio
        self.aspect_ratio_range = aspect_ratio_range
        self.img_tile_shape = img_tile_shape
        self.random_image_distortion_ops = P.RandomColorAdjust(brightness=brightness,
                                                               contrast=contrast,
                                                               saturation=saturation,
                                                               hue=hue)

    def __call__(self, img):
        img_h = self.img_tile_shape[0]
        img_w = self.img_tile_shape[1]
        img_new = np.zeros([128, 512, 3])

        for i in range(4):
            img_tile = img[:, (i*150):((i+1)*150), :]
            # Random crop cut from the street sign image, resized to the same size.
            # Assures that the crop covers at least 0.8 of the input image area.
            # The aspect ratio of the cropped image is within the [0.8, 1.2] range.
            h = img_h + 1
            w = img_w + 1
            while (w >= img_w) or (h >= img_h):
                aspect_ratio = np.random.uniform(self.aspect_ratio_range[0],
                                                 self.aspect_ratio_range[1])
                h_low = np.ceil(np.sqrt(self.min_area_ratio * img_h * img_w / aspect_ratio))
                h_high = np.floor(np.sqrt(img_h * img_w / aspect_ratio))
                h = np.random.randint(h_low, h_high)
                w = int(h * aspect_ratio)
            y = np.random.randint(img_w - w)
            x = np.random.randint(img_h - h)
            img_tile = img_tile[x:(x+h), y:(y+w), :]
            # Randomly choose one of the 4 interpolation methods for resizing.
            interpolation = np.random.choice([cv2.INTER_LINEAR,
                                              cv2.INTER_CUBIC,
                                              cv2.INTER_AREA,
                                              cv2.INTER_NEAREST])
            img_tile = cv2.resize(img_tile, (128, 128), interpolation=interpolation)
            # Random color distortion ops.
            img_tile_pil = Image.fromarray(img_tile)
            img_tile_pil = self.random_image_distortion_ops(img_tile_pil)
            img_tile = np.array(img_tile_pil)
            img_new[:, (i*128):((i+1)*128), :] = img_tile

        # Rescale pixel values to [-1, 1].
        img_new = 2 * (img_new / 255.) - 1
        return img_new
class ImageResizeWithRescale():
    """Resize an image to the standard size and rescale pixel values to [-1, 1]."""
    def __init__(self, standard_img_height, standard_img_width, channel_size=3):
        self.standard_img_height = standard_img_height
        self.standard_img_width = standard_img_width
        self.channel_size = channel_size

    def __call__(self, img):
        img = cv2.resize(img, (self.standard_img_width, self.standard_img_height))
        img = 2 * (img / 255.) - 1
        return img
def random_teacher_force(images, source_ids, target_ids):
    """Attach a per-sample boolean that decides whether to use teacher forcing."""
    teacher_force = np.random.random() < config.teacher_force_ratio
    teacher_force_array = np.array([teacher_force], dtype=bool)
    return images, source_ids, target_ids, teacher_force_array
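# Each record thus carries a boolean drawn once per sample: with probability
# config.teacher_force_ratio (0.5 by default) the decoder is fed ground-truth
# tokens for that sample instead of its own previous predictions.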
def create_ocr_train_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
                             is_training=True, num_parallel_workers=4, use_multiprocessing=True):
    """Create the training dataset: decode, augment, pad, and attach the teacher-force flag."""
    ds = de.MindDataset(mindrecord_file,
                        columns_list=["image", "decoder_input", "decoder_target"],
                        num_shards=rank_size,
                        shard_id=rank_id,
                        num_parallel_workers=num_parallel_workers,
                        shuffle=is_training)
    aug_ops = AugmentationOps()
    transforms = [C.Decode(),
                  aug_ops,
                  C.HWC2CHW()]
    ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"])
    ds = ds.map(operations=random_teacher_force, input_columns=["image", "decoder_input", "decoder_target"],
                output_columns=["image", "decoder_input", "decoder_target", "teacher_force"],
                column_order=["image", "decoder_input", "decoder_target", "teacher_force"])
    type_cast_op_bool = ops.TypeCast(mstype.bool_)
    ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
    print("Train dataset size= %s" % (int(ds.get_dataset_size())))
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
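# Example (hypothetical path): build the training pipeline from a MindRecord
# file produced by src/create_mindrecord_files.py.
#   ds = create_ocr_train_dataset("/data/fsns/train/fsns.mindrecord0", batch_size=32)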
def create_ocr_val_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
                           num_parallel_workers=4, use_multiprocessing=True):
    """Create the evaluation dataset: decode, resize/rescale, and pad."""
    ds = de.MindDataset(mindrecord_file,
                        columns_list=["image", "annotation", "decoder_input", "decoder_target"],
                        num_shards=rank_size,
                        shard_id=rank_id,
                        num_parallel_workers=num_parallel_workers,
                        shuffle=False)
    resize_rescale_op = ImageResizeWithRescale(standard_img_height=128, standard_img_width=512)
    transforms = [C.Decode(),
                  resize_rescale_op,
                  C.HWC2CHW()]
    ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
                num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"],
                python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
    ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_input"],
                python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
    ds = ds.batch(batch_size, drop_remainder=True)
    print("Val dataset size= %s" % (str(int(ds.get_dataset_size()) * batch_size)))
    return ds
@@ -0,0 +1,55 @@ src/gru.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GRU cell.
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype

from src.weight_init import gru_default_state
class GRU(nn.Cell):
    '''
    GRU model.

    Args:
        input_size: The number of expected features in the input.
        hidden_size: The number of features in the hidden state.
    '''
    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_i, self.weight_h, self.bias_i, self.bias_h = gru_default_state(self.input_size, self.hidden_size)
        self.rnn = P.DynamicGRUV2()
        self.cast = P.Cast()

    def construct(self, x, h):
        '''
        GRU construction.

        Args:
            x(Tensor): GRU input.
            h(Tensor): GRU hidden state.

        Returns:
            output(Tensor): rnn output.
            hidden(Tensor): hidden state.
        '''
        x = self.cast(x, mstype.float16)
        h = self.cast(h, mstype.float16)
        y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, h)
        return y1, h1
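# Note (assumed shapes): DynamicGRUV2 expects float16 inputs on Ascend, which is
# why x and h are cast above; x is (num_step, batch_size, input_size) and the
# returned y1/h1 hold the per-step outputs and hidden states of shape
# (num_step, batch_size, hidden_size).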
@@ -0,0 +1,80 @@ src/logger.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
    """
    Logger.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        self.rank = rank
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """Set up the logging file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)
| def get_logger(path, rank): | |||
| """Get Logger.""" | |||
| logger = LOGGER('crnn-seq2seq-ocr', rank) | |||
| logger.setup_logging_file(path, rank) | |||
| return logger | |||
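| # Usage sketch (illustrative only; the path and rank are placeholders): | |||
| # logger = get_logger('./outputs', rank=0) | |||
| # logger.info('epoch: %d, loss: %f', 1, 0.5) | |||
| # logger.important_info('training start') | |||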
| @@ -0,0 +1,196 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """lstm""" | |||
| import math | |||
| import numpy as np | |||
| from mindspore import nn, context, Tensor, Parameter, ParameterTuple | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.ops import operations as P | |||
| @constexpr | |||
| def _create_sequence_length(shape): | |||
| num_step, batch_size, _ = shape | |||
| sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32) | |||
| return sequence_length | |||
| class LSTM(nn.Cell): | |||
| """ | |||
| Stacked LSTM (Long Short-Term Memory) layers. | |||
| Args: | |||
| input_size (int): Number of features of input. | |||
| hidden_size (int): Number of features of hidden layer. | |||
| num_layers (int): Number of layers of stacked LSTM. Default: 1. | |||
| has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True. | |||
| batch_first (bool): Specifies whether the first dimension of input is batch_size. Default: False. | |||
| dropout (float, int): If not 0, appends a `Dropout` layer on the outputs of each | |||
| LSTM layer except the last layer. The value must be in the range [0.0, 1.0]. Default: 0. | |||
| bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False. | |||
| Inputs: | |||
| - **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or | |||
| (batch_size, seq_len, `input_size`). | |||
| - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or | |||
| mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`). | |||
| Data type of `hx` must be the same as `input`. | |||
| Outputs: | |||
| Tuple, a tuple contains (`output`, (`h_n`, `c_n`)). | |||
| - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`). | |||
| - **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape | |||
| (num_directions * `num_layers`, batch_size, `hidden_size`). | |||
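| Example (an illustrative sketch, not part of the original sources; sizes are | |||
| placeholders, and the initial states can come from src.weight_init.lstm_default_state): | |||
| >>> lstm = LSTM(input_size=512, hidden_size=256, bidirectional=True) | |||
| >>> # h0, c0 = lstm_default_state(batch_size=32, hidden_size=256, bidirectional=True) | |||
| >>> # input: (seq_len, 32, 512) -> output: (seq_len, 32, 512), hn/cn: (2, 32, 256) | |||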
| """ | |||
| def __init__(self, | |||
| input_size, | |||
| hidden_size, | |||
| num_layers=1, | |||
| has_bias=True, | |||
| batch_first=False, | |||
| dropout=0, | |||
| bidirectional=False): | |||
| super(LSTM, self).__init__() | |||
| self.is_ascend = context.get_context("device_target") == "Ascend" | |||
| self.batch_first = batch_first | |||
| self.transpose = P.Transpose() | |||
| self.num_layers = num_layers | |||
| self.bidirectional = bidirectional | |||
| self.dropout = dropout | |||
| self.lstm = P.LSTM(input_size=input_size, | |||
| hidden_size=hidden_size, | |||
| num_layers=num_layers, | |||
| has_bias=has_bias, | |||
| bidirectional=bidirectional, | |||
| dropout=float(dropout)) | |||
| weight_size = 0 | |||
| gate_size = 4 * hidden_size | |||
| stdv = 1 / math.sqrt(hidden_size) | |||
| num_directions = 2 if bidirectional else 1 | |||
| if self.is_ascend: | |||
| self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0) | |||
| self.concat = P.Concat(axis=0) | |||
| self.concat_2dim = P.Concat(axis=2) | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| if dropout < 0 or dropout > 1: | |||
| raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout)) | |||
| if dropout == 1: | |||
| self.dropout_op = P.ZerosLike() | |||
| else: | |||
| self.dropout_op = nn.Dropout(float(1 - dropout)) | |||
| b0 = np.zeros(gate_size, dtype=np.float32) | |||
| self.w_list = [] | |||
| self.b_list = [] | |||
| self.rnns_fw = P.DynamicRNN(forget_bias=0.0) | |||
| self.rnns_bw = P.DynamicRNN(forget_bias=0.0) | |||
| for layer in range(num_layers): | |||
| w_shape = input_size if layer == 0 else (num_directions * hidden_size) | |||
| w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32) | |||
| self.w_list.append(Parameter( | |||
| initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer))) | |||
| if has_bias: | |||
| b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float32) | |||
| self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer))) | |||
| else: | |||
| self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer))) | |||
| if bidirectional: | |||
| w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32) | |||
| self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]), | |||
| name='weight_bw' + str(layer))) | |||
| b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float32) if has_bias else b0 | |||
| self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]), | |||
| name='bias_bw' + str(layer))) | |||
| self.w_list = ParameterTuple(self.w_list) | |||
| self.b_list = ParameterTuple(self.b_list) | |||
| else: | |||
| for layer in range(num_layers): | |||
| input_layer_size = input_size if layer == 0 else hidden_size * num_directions | |||
| increment_size = gate_size * input_layer_size | |||
| increment_size += gate_size * hidden_size | |||
| if has_bias: | |||
| increment_size += 2 * gate_size | |||
| weight_size += increment_size * num_directions | |||
| w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32) | |||
| self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight') | |||
| def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias): | |||
| """stacked bidirectional dynamic_rnn""" | |||
| x_shape = self.shape(x) | |||
| sequence_length = _create_sequence_length(x_shape) | |||
| pre_layer = x | |||
| hn = () | |||
| cn = () | |||
| output = x | |||
| for i in range(self.num_layers): | |||
| offset = i * 2 | |||
| weight_fw, weight_bw = weight[offset], weight[offset + 1] | |||
| bias_fw, bias_bw = bias[offset], bias[offset + 1] | |||
| init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :] | |||
| init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :] | |||
| bw_x = self.reverse_seq(pre_layer, sequence_length) | |||
| y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw) | |||
| y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw) | |||
| y_bw = self.reverse_seq(y_bw, sequence_length) | |||
| output = self.concat_2dim((y, y_bw)) | |||
| pre_layer = self.dropout_op(output) if self.dropout else output | |||
| hn += (h[-1:, :, :],) | |||
| hn += (h_bw[-1:, :, :],) | |||
| cn += (c[-1:, :, :],) | |||
| cn += (c_bw[-1:, :, :],) | |||
| status_h = self.concat(hn) | |||
| status_c = self.concat(cn) | |||
| return output, status_h, status_c | |||
| def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias): | |||
| """stacked mutil_layer dynamic_rnn""" | |||
| pre_layer = x | |||
| hn = () | |||
| cn = () | |||
| y = 0 | |||
| for i in range(self.num_layers): | |||
| weight_fw, bias_fw = weight[i], bias[i] | |||
| init_h_fw, init_c_fw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :] | |||
| y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw) | |||
| pre_layer = self.dropout_op(y) if self.dropout else y | |||
| hn += (h[-1:, :, :],) | |||
| cn += (c[-1:, :, :],) | |||
| status_h = self.concat(hn) | |||
| status_c = self.concat(cn) | |||
| return y, status_h, status_c | |||
| def construct(self, x, hx): | |||
| if self.batch_first: | |||
| x = self.transpose(x, (1, 0, 2)) | |||
| h, c = hx | |||
| if self.is_ascend: | |||
| x = self.cast(x, mstype.float16) | |||
| h = self.cast(h, mstype.float16) | |||
| c = self.cast(c, mstype.float16) | |||
| if self.bidirectional: | |||
| x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list) | |||
| else: | |||
| x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list) | |||
| else: | |||
| x, h, c, _, _ = self.lstm(x, h, c, self.weight) | |||
| if self.batch_first: | |||
| x = self.transpose(x, (1, 0, 2)) | |||
| return x, (h, c) | |||
| @@ -0,0 +1,165 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """ | |||
| Seq2Seq_OCR model. | |||
| """ | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as P | |||
| import mindspore.common.dtype as mstype | |||
| from src.cnn import CNN | |||
| from src.gru import GRU | |||
| from src.lstm import LSTM | |||
| from src.weight_init import lstm_default_state | |||
| class BidirectionalLSTM(nn.Cell): | |||
| """Bidirectional LSTM with a Dense layer | |||
| Args: | |||
| batch_size(int): batch size of input data | |||
| input_size(int): Size of time sequence | |||
| hidden_size(int): the hidden size of LSTM layers | |||
| output_size(int): the output size of the dense layer | |||
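| Example (an illustrative sketch, not part of the original sources; sizes are placeholders): | |||
| >>> rnn = BidirectionalLSTM(batch_size=32, input_size=512, hidden_size=256, output_size=256) | |||
| >>> # inputs: (T, 32, 512) -> output: (T, 32, 256) | |||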
| """ | |||
| def __init__(self, batch_size, input_size, hidden_size, output_size): | |||
| super(BidirectionalLSTM, self).__init__() | |||
| self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True).to_float(mstype.float16) | |||
| self.h, self.c = lstm_default_state(batch_size, hidden_size, bidirectional=True) | |||
| self.embedding = nn.Dense(hidden_size * 2, output_size).to_float(mstype.float16) | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| self.cast = P.Cast() | |||
| def construct(self, inputs): | |||
| inputs = self.cast(inputs, mstype.float16) | |||
| recurrent, _ = self.rnn(inputs, (self.h, self.c)) | |||
| T, b, h = self.shape(recurrent) | |||
| t_rec = self.reshape(recurrent, (T * b, h)) | |||
| output = self.embedding(t_rec) | |||
| output = self.reshape(output, (T, b, -1)) | |||
| return output | |||
| class AttnDecoderRNN(nn.Cell): | |||
| """Attention Decoder Structure with a one-layer GRU | |||
| Args: | |||
| hidden_size(int): the hidden size | |||
| output_size(int): the output size | |||
| max_length(int): max time step of the decoder | |||
| dropout_p(float): dropout probability, default is 0.1 | |||
| """ | |||
| def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1): | |||
| super(AttnDecoderRNN, self).__init__() | |||
| self.hidden_size = hidden_size | |||
| self.output_size = output_size | |||
| self.dropout_p = dropout_p | |||
| self.max_length = max_length | |||
| self.embedding = nn.Embedding(self.output_size, self.hidden_size) | |||
| self.attn = nn.Dense(in_channels=self.hidden_size * 2, out_channels=self.max_length).to_float(mstype.float16) | |||
| self.attn_combine = nn.Dense(in_channels=self.hidden_size * 2, | |||
| out_channels=self.hidden_size).to_float(mstype.float16) | |||
| self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p) | |||
| self.gru = GRU(hidden_size, hidden_size).to_float(mstype.float16) | |||
| self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16) | |||
| self.transpose = P.Transpose() | |||
| self.concat = P.Concat(axis=2) | |||
| self.concat1 = P.Concat(axis=1) | |||
| self.softmax = P.Softmax(axis=1) | |||
| self.relu = P.ReLU() | |||
| self.log_softmax = P.LogSoftmax(axis=1) | |||
| self.bmm = P.BatchMatMul() | |||
| self.unsqueeze = P.ExpandDims() | |||
| self.squeeze = P.Squeeze(1) | |||
| self.squeeze1 = P.Squeeze(0) | |||
| self.cast = P.Cast() | |||
| def construct(self, inputs, hidden, encoder_outputs): | |||
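| # Shape walkthrough (B = batch_size, H = hidden_size, T = max_length): | |||
| # embedded, hidden: (1, B, H); their concat is squeezed to (B, 2H) | |||
| # attn_weights = softmax(attn(...)): (B, T), expanded to (B, 1, T) | |||
| # attn_applied = attn_weights @ encoder_outputs(B, T, H): (B, 1, H) -> (B, H) | |||
| # output = relu(attn_combine([embedded; attn_applied])): (1, B, H), fed to the GRU | |||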
| embedded = self.embedding(inputs) | |||
| embedded = self.transpose(embedded, (1, 0, 2)) | |||
| embedded = self.dropout(embedded) | |||
| embedded = self.cast(embedded, mstype.float16) | |||
| embedded_concat = self.concat((embedded, hidden)) | |||
| embedded_concat = self.squeeze1(embedded_concat) | |||
| attn_weights = self.softmax(self.attn(embedded_concat)) | |||
| attn_weights = self.unsqueeze(attn_weights, 1) | |||
| perm_encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2)) | |||
| attn_applied = self.bmm(attn_weights, perm_encoder_outputs) | |||
| attn_applied = self.squeeze(attn_applied) | |||
| embedded_squeeze = self.squeeze1(embedded) | |||
| output = self.concat1((embedded_squeeze, attn_applied)) | |||
| output = self.attn_combine(output) | |||
| output = self.unsqueeze(output, 0) | |||
| output = self.relu(output) | |||
| gru_hidden = self.squeeze1(hidden) | |||
| output, hidden = self.gru(output, gru_hidden)  # src.gru.GRU returns (output, hidden) | |||
| output = self.squeeze1(output) | |||
| output = self.log_softmax(self.out(output)) | |||
| return output, hidden, attn_weights | |||
| class Encoder(nn.Cell): | |||
| """Encoder with a CNN and two BidirectionalLSTM layers | |||
| Args: | |||
| batch_size(int): batch size of input data | |||
| conv_out_dim(int): the output dimension of the cnn layer | |||
| hidden_size(int): the hidden size of LSTM layers | |||
| """ | |||
| def __init__(self, batch_size, conv_out_dim, hidden_size): | |||
| super(Encoder, self).__init__() | |||
| self.cnn = CNN(int(conv_out_dim/4)) | |||
| self.lstm1 = BidirectionalLSTM(batch_size, conv_out_dim, hidden_size, hidden_size).to_float(mstype.float16) | |||
| self.lstm2 = BidirectionalLSTM(batch_size, hidden_size, hidden_size, hidden_size).to_float(mstype.float16) | |||
| self.transpose = P.Transpose() | |||
| self.cast = P.Cast() | |||
| self.split = P.Split(axis=3, output_num=4) | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, inputs): | |||
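| # Split the image into four patches along the width axis, run each | |||
| # patch through the shared CNN, then concatenate the per-patch feature | |||
| # maps along the channel axis before feeding the stacked LSTMs. | |||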
| inputs = self.cast(inputs, mstype.float32) | |||
| (x1, x2, x3, x4) = self.split(inputs) | |||
| conv1 = self.cnn(x1) | |||
| conv2 = self.cnn(x2) | |||
| conv3 = self.cnn(x3) | |||
| conv4 = self.cnn(x4) | |||
| conv = self.concat((conv1, conv2, conv3, conv4)) | |||
| conv = self.transpose(conv, (2, 0, 1)) | |||
| output = self.lstm1(conv) | |||
| output = self.lstm2(output) | |||
| return output | |||
| class Decoder(nn.Cell): | |||
| """Decoder | |||
| Args: | |||
| hidden_size(int): the hidden size | |||
| output_size(int): the output size | |||
| max_length(int): max time step of the decoder | |||
| dropout_p(float): dropout probability, default is 0.1 | |||
| """ | |||
| def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1): | |||
| super(Decoder, self).__init__() | |||
| self.decoder = AttnDecoderRNN(hidden_size, output_size, max_length, dropout_p) | |||
| def construct(self, inputs, hidden, encoder_outputs): | |||
| return self.decoder(inputs, hidden, encoder_outputs) | |||
| @@ -0,0 +1,51 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Util class or function.""" | |||
| from __future__ import absolute_import, division, print_function, unicode_literals | |||
| import os | |||
| import codecs | |||
| import logging | |||
| def initialize_vocabulary(vocabulary_path): | |||
| """ | |||
| Initialize vocabulary from file. | |||
| Assumes the vocabulary is stored one item per line. | |||
| """ | |||
| characters_class = 9999 | |||
| if os.path.exists(vocabulary_path): | |||
| rev_vocab = [] | |||
| with codecs.open(vocabulary_path, 'r', encoding='utf-8') as voc_file: | |||
| rev_vocab = [line.strip() for line in voc_file] | |||
| vocab = {x: y for (y, x) in enumerate(rev_vocab)} | |||
| reserved_char_size = characters_class - len(rev_vocab) | |||
| if reserved_char_size < 0: | |||
| raise ValueError("Number of characters in vocabulary is equal or larger than config.characters_class") | |||
| for _ in range(reserved_char_size): | |||
| rev_vocab.append('') | |||
| # put space at the last position | |||
| vocab[' '] = len(rev_vocab) | |||
| rev_vocab.append(' ') | |||
| logging.info("Initializing vocabulary ends: %s", vocabulary_path) | |||
| return vocab, rev_vocab | |||
| raise ValueError("Initializing vocabulary ends: %s" % vocabulary_path) | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| #" ============================================================================ | |||
| """ | |||
| weights initialization | |||
| """ | |||
| import math | |||
| import numpy as np | |||
| from mindspore import Tensor, Parameter | |||
| def lstm_default_state(batch_size, hidden_size, bidirectional, num_layers=1): | |||
| """init default input.""" | |||
| num_directions = 2 if bidirectional else 1 | |||
| h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) | |||
| return h, c | |||
| def gru_default_state(input_size, hidden_size): | |||
| """Weight and bias initialization for the self-defined GRU cell.""" | |||
| stdv = 1 / math.sqrt(hidden_size) | |||
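| # DynamicGRUV2 expects weight_input of shape (input_size, 3*hidden_size) | |||
| # and weight_hidden of shape (hidden_size, 3*hidden_size). | |||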
| weight_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)), | |||
| name='weight_i') | |||
| weight_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (hidden_size, 3*hidden_size)).astype(np.float32)), | |||
| name='weight_h') | |||
| bias_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), | |||
| name='bias_i') | |||
| bias_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)), | |||
| name='bias_h') | |||
| return weight_i, weight_h, bias_i, bias_h | |||
| @@ -0,0 +1,158 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| CRNN-Seq2Seq-OCR train. | |||
| """ | |||
| import os | |||
| import argparse | |||
| import datetime | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.train.model import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.common import set_seed | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.communication.management import init | |||
| from mindspore.train.callback import ModelCheckpoint | |||
| from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.config import config | |||
| from src.dataset import create_ocr_train_dataset | |||
| from src.logger import get_logger | |||
| from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper | |||
| set_seed(1) | |||
| def parse_args(): | |||
| """Parse train arguments.""" | |||
| parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training') | |||
| # device related | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="device where the code will be implemented.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=0, | |||
| help='Distribute train or not, 1 for yes, 0 for no. Default: 0') | |||
| parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0') | |||
| parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1') | |||
| # dataset related | |||
| parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.') | |||
| # logging related | |||
| parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100') | |||
| parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/') | |||
| parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Pre-trained checkpoint file path.') | |||
| parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None') | |||
| parser.add_argument('--is_save_on_master', type=int, default=0, | |||
| help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0') | |||
| args, _ = parser.parse_known_args() | |||
| # logger | |||
| args.outputs_dir = os.path.join(args.ckpt_path, | |||
| datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) | |||
| return args | |||
| def train(): | |||
| """Train function.""" | |||
| args = parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||
| if args.is_distributed: | |||
| rank = args.rank_id | |||
| device_num = args.device_num | |||
| context.set_auto_parallel_context(device_num=device_num, | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| init() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| # Logger | |||
| args.logger = get_logger(args.outputs_dir, rank) | |||
| args.rank_save_ckpt_flag = 0 | |||
| if args.is_save_on_master: | |||
| if rank == 0: | |||
| args.rank_save_ckpt_flag = 1 | |||
| else: | |||
| args.rank_save_ckpt_flag = 1 | |||
| # DATASET | |||
| dataset = create_ocr_train_dataset(args.mindrecord_file, | |||
| config.batch_size, | |||
| rank_size=device_num, | |||
| rank_id=rank) | |||
| args.steps_per_epoch = dataset.get_dataset_size() | |||
| args.logger.info('Finish loading dataset') | |||
| if not args.ckpt_interval: | |||
| args.ckpt_interval = args.steps_per_epoch | |||
| args.logger.save_args(args) | |||
| network = AttentionOCR(config.batch_size, | |||
| int(config.img_width / 4), | |||
| config.encoder_hidden_size, | |||
| config.decoder_hidden_size, | |||
| config.decoder_output_size, | |||
| config.max_length, | |||
| config.dropout_p) | |||
| if args.pre_checkpoint_path: | |||
| param_dict = load_checkpoint(args.pre_checkpoint_path) | |||
| load_param_into_net(network, param_dict) | |||
| network = AttentionOCRWithLossCell(network, config.max_length) | |||
| lr = Tensor(config.lr, mstype.float32) | |||
| opt = nn.Adam(network.trainable_params(), lr, beta1=config.adam_beta1, beta2=config.adam_beta2, | |||
| loss_scale=config.loss_scale) | |||
| network = TrainingWrapper(network, opt, sens=config.loss_scale) | |||
| args.logger.info('Finished get network') | |||
| callback = [TimeMonitor(data_size=1), LossMonitor()] | |||
| if args.rank_save_ckpt_flag: | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/') | |||
| ckpt_cb = ModelCheckpoint(config=ckpt_config, | |||
| directory=save_ckpt_path, | |||
| prefix="crnn_seq2seq_ocr") | |||
| callback.append(ckpt_cb) | |||
| model = Model(network) | |||
| model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False) | |||
| args.logger.info('==========Training Done===============') | |||
| if __name__ == "__main__": | |||
| train() | |||
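| # Example invocation (illustrative only; paths are placeholders): | |||
| # python train.py --device_target=Ascend --device_id=0 \ | |||
| #     --mindrecord_file=/path/to/fsns_train.mindrecord --ckpt_path=outputs/ | |||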