|  | |||
| <!-- TOC --> | |||
| # CTPN for Ascend | |||
- [CTPN Description](#ctpn-description)
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Features](#features) | |||
| - [Mixed Precision](#mixed-precision) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Training Process](#training-process) | |||
| - [Evaluation Process](#evaluation-process) | |||
| - [Evaluation](#evaluation) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| # [CTPN Description](#contents) | |||
CTPN is a text detection model based on an object detection framework. It improves on Faster R-CNN by combining it with a bidirectional LSTM, which makes CTPN very effective for horizontal text detection. Another highlight of CTPN is that it transforms the text detection task into a series of small-scale text box detections. This idea was proposed in the paper "Detecting Text in Natural Image with Connectionist Text Proposal Network".
| [Paper](https://arxiv.org/pdf/1609.03605.pdf) Zhi Tian, Weilin Huang, Tong He, Pan He, Yu Qiao, "Detecting Text in Natural Image with Connectionist Text Proposal Network", ArXiv, vol. abs/1609.03605, 2016. | |||
| # [Model architecture](#contents) | |||
The overall network uses VGG16 as the backbone, applies a bidirectional LSTM to extract context features of the small-scale text boxes, and then uses an RPN (Region Proposal Network) to predict the bounding boxes and their probabilities.
| [Link](https://arxiv.org/pdf/1605.07314v1.pdf) | |||
| # [Dataset](#contents) | |||
Here we use 6 datasets for training and 1 dataset for evaluation.
| - Dataset1: ICDAR 2013: Focused Scene Text | |||
| - Train: 142MB, 229 images | |||
| - Test: 110MB, 233 images | |||
| - Dataset2: ICDAR 2011: Born-Digital Images | |||
| - Train: 27.7MB, 410 images | |||
| - Dataset3: ICDAR 2015: | |||
| - Train:89MB, 1000 images | |||
| - Dataset4: SCUT-FORU: Flickr OCR Universal Database | |||
| - Train: 388MB, 1715 images | |||
| - Dataset5: CocoText v2(Subset of MSCOCO2017): | |||
| - Train: 13GB, 63686 images | |||
| - Dataset6: SVT(The Street View Dataset) | |||
| - Train: 115MB, 349 images | |||
| # [Features](#contents) | |||
| # [Environment Requirements](#contents) | |||
| - Hardware(Ascend) | |||
| - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| # [Script description](#contents) | |||
| ## [Script and sample code](#contents) | |||
| ```shell | |||
| . | |||
| └─ctpn | |||
| ├── README.md # network readme | |||
| ├── eval.py # eval net | |||
| ├── scripts | |||
| │ ├── eval_res.sh # calculate precision and recall | |||
| │ ├── run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p) | |||
| │ ├── run_eval_ascend.sh # launch evaluating with ascend platform | |||
| │ └── run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p) | |||
| ├── src | |||
| │ ├── CTPN | |||
| │ │ ├── BoundingBoxDecode.py # bounding box decode | |||
| │ │ ├── BoundingBoxEncode.py # bounding box encode | |||
| │ │ ├── __init__.py # package init file | |||
| │ │ ├── anchor_generator.py # anchor generator | |||
| │ │ ├── bbox_assign_sample.py # proposal layer | |||
│ │ ├── proposal_generator.py # proposal generator
| │ │ ├── rpn.py # region-proposal network | |||
| │ │ └── vgg16.py # backbone | |||
| │ ├── config.py # training configuration | |||
| │ ├── convert_icdar2015.py # convert icdar2015 dataset label | |||
| │ ├── convert_svt.py # convert svt label | |||
| │ ├── create_dataset.py # create mindrecord dataset | |||
| │ ├── ctpn.py # ctpn network definition | |||
│ ├── dataset.py # data preprocessing
| │ ├── lr_schedule.py # learning rate scheduler | |||
| │ ├── network_define.py # network definition | |||
| │ └── text_connector | |||
| │ ├── __init__.py # package init file | |||
| │ ├── connect_text_lines.py # connect text lines | |||
| │ ├── detector.py # detect box | |||
| │ ├── get_successions.py # get succession proposal | |||
│ ├── utils.py # commonly used helper functions
| └── train.py # train net | |||
| ``` | |||
| ## [Training process](#contents) | |||
| ### Dataset | |||
To create the dataset, download the data first and preprocess it. We provide src/convert_svt.py and src/convert_icdar2015.py to convert the SVT and ICDAR2015 dataset labels. For the SVT dataset, you can process it as below:
| ```shell | |||
| python convert_svt.py --dataset_path=/path/img --xml_file=/path/train.xml --location_dir=/path/location | |||
| ``` | |||
For the ICDAR2015 dataset, you can process it as below:
| ```shell | |||
| python convert_icdar2015.py --src_label_path=/path/train_label --target_label_path=/path/label | |||
| ``` | |||
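An ICDAR2015 label line annotates a text quadrilateral as `x1,y1,x2,y2,x3,y3,x4,y4,text`. As a rough illustration of what the conversion involves (a minimal sketch with a hypothetical helper, not the exact logic of src/convert_icdar2015.py), the quadrilateral can be collapsed into the axis-aligned box that CTPN trains on:

```python
# Hypothetical helper: collapse one ICDAR2015 label line
# ("x1,y1,x2,y2,x3,y3,x4,y4,text") into an axis-aligned bounding box.
def quad_to_bbox(line):
    parts = line.strip().lstrip('\ufeff').split(',')
    xs = [int(parts[i]) for i in (0, 2, 4, 6)]
    ys = [int(parts[i]) for i in (1, 3, 5, 7)]
    return min(xs), min(ys), max(xs), max(ys)

print(quad_to_bbox("377,117,463,117,465,130,378,120,Genaxis"))  # (377, 117, 465, 130)
```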
Then modify src/config.py to add the dataset paths. For each dataset, add its IMAGE_PATH and LABEL_PATH as a list in the config. An example is shown below:
| ```python | |||
| # create dataset | |||
| "coco_root": "/path/coco", | |||
| "coco_train_data_type": "train2017", | |||
| "cocotext_json": "/path/cocotext.v2.json", | |||
| "icdar11_train_path": ["/path/image/", "/path/label"], | |||
| "icdar13_train_path": ["/path/image/", "/path/label"], | |||
| "icdar15_train_path": ["/path/image/", "/path/label"], | |||
| "icdar13_test_path": ["/path/image/", "/path/label"], | |||
| "flick_train_path": ["/path/image/", "/path/label"], | |||
| "svt_train_path": ["/path/image/", "/path/label"], | |||
| "pretrain_dataset_path": "", | |||
| "finetune_dataset_path": "", | |||
| "test_dataset_path": "", | |||
| ``` | |||
Then you can create the MindRecord dataset with src/create_dataset.py:
| ```shell | |||
| python src/create_dataset.py | |||
| ``` | |||
| ### Usage | |||
| - Ascend: | |||
| ```bash | |||
| # distribute training example(8p) | |||
| sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH] | |||
| # standalone training | |||
| sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] | |||
| # evaluation: | |||
| sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ``` | |||
The `pretrained_path` should be a checkpoint of VGG16 trained on ImageNet2012. The weight names in the checkpoint dict must match the network exactly, and batch norm must be enabled when training VGG16; otherwise later steps will fail. To get the VGG16 backbone, use the network structure defined in src/CTPN/vgg16.py: copy src/CTPN/vgg16.py under modelzoo/official/cv/vgg16/src/, and modify vgg16/train.py to suit the new structure. You can fix it as below:
| ```python | |||
| ... | |||
| from src.vgg16 import VGG16 | |||
| ... | |||
| network = VGG16() | |||
| ... | |||
| ``` | |||
| Then you can train it with ImageNet2012. | |||
| > Notes: | |||
> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models, it is better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compilation time increases with model size.
| > | |||
> The scripts bind processor cores according to `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train_ascend.sh`.
| > | |||
> TASK_TYPE can be Pretraining or Finetune. For Pretraining, we use ICDAR2013, ICDAR2015, SVT, SCUT-FORU and CocoText v2. For Finetune, we use ICDAR2011, ICDAR2013 and SCUT-FORU to improve precision and recall, and the checkpoint produced by Pretraining is used as PRETRAINED_PATH.
| > COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text). | |||
| > | |||
| ### Launch | |||
| ```bash | |||
| # training example | |||
| shell: | |||
| Ascend: | |||
| # distribute training example(8p) | |||
| sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH] | |||
| # standalone training | |||
| sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] | |||
| ``` | |||
| ### Result | |||
Training results will be stored in the example path. Checkpoints are stored at `ckpt_path` by default, the training log is redirected to `./log`, and the loss is written to `./loss_0.log` as follows:
```text
| 377 epoch: 1 step: 229 ,rpn_loss: 0.00355, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00103, | |||
| 399 epoch: 2 step: 229 ,rpn_loss: 0.00327,rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00093, | |||
| 424 epoch: 3 step: 229 ,rpn_loss: 0.00910, rpn_cls_loss: 0.00385, rpn_reg_loss: 0.00175, | |||
| ``` | |||
## [Evaluation process](#contents)
| ### Usage | |||
You can start evaluation using python or shell scripts. The usage of shell scripts is as follows:
| - Ascend: | |||
| ```bash | |||
| sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ``` | |||
After eval, you will get several archive files named submit_ctpn-xx_xxxx.zip, whose names contain the name of your checkpoint file. To evaluate them, you can use the scripts provided by the ICDAR2013 challenge; you can download the Deteval scripts from this [link](https://rrc.cvc.uab.es/?com=downloads&action=download&ch=2&f=aHR0cHM6Ly9ycmMuY3ZjLnVhYi5lcy9zdGFuZGFsb25lcy9zY3JpcHRfdGVzdF9jaDJfdDFfZTItMTU3Nzk4MzA2Ny56aXA=).
After downloading the scripts, unzip the archive, put it under ctpn/scripts, and use eval_res.sh to get the result. You will get files as below:
| ```text | |||
| gt.zip | |||
| readme.txt | |||
| rrc_evalulation_funcs_1_1.py | |||
| script.py | |||
| ``` | |||
Then you can run scripts/eval_res.sh to calculate the evaluation result.
```bash
| bash eval_res.sh | |||
| ``` | |||
| ### Result | |||
Evaluation results will be stored in the example path; you can find results like the following in `log`.
| ```text | |||
| {"precision": 0.90791, "recall": 0.86118, "hmean": 0.88393} | |||
| ``` | |||
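The reported `hmean` is the harmonic mean (F-measure) of precision and recall, 2PR/(P+R), so the three numbers can be cross-checked directly:

```python
precision, recall = 0.90791, 0.86118
hmean = 2 * precision * recall / (precision + recall)
print(round(hmean, 5))  # 0.88393
```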
| # [Model description](#contents) | |||
| ## [Performance](#contents) | |||
| ### Training Performance | |||
| | Parameters | Ascend | | |||
| | -------------------------- | ------------------------------------------------------------ | | |||
| | Model Version | CTPN | | |||
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB |
| Uploaded Date | 02/06/2021 |
| | MindSpore Version | 1.1.1 | | |||
| | Dataset | 16930 images | | |||
| | Batch_size | 2 | | |||
| | Training Parameters | src/config.py | | |||
| | Optimizer | Momentum | | |||
| | Loss Function | SoftmaxCrossEntropyWithLogits for classification, SmoothL2Loss for bbox regression| | |||
| | Loss | ~0.04 | | |||
| | Total time (8p) | 6h | | |||
| | Scripts | [ctpn script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ctpn) | | |||
### Inference Performance
| | Parameters | Ascend | | |||
| | ------------------- | --------------------------- | | |||
| | Model Version | CTPN | | |||
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB |
| Uploaded Date | 02/06/2021 |
| | MindSpore Version | 1.1.1 | | |||
| | Dataset | 229 images | | |||
| | Batch_size | 1 | | |||
| Accuracy | precision=0.9079, recall=0.8611, F-measure=0.8839 |
| | Total time | 1 min | | |||
| | Model for inference | 135M (.ckpt file) | | |||
| #### Training performance results | |||
| **Ascend** | train performance |
| :--------: | :---------------: |
| 1p | 10 img/s |
| 8p | 84 img/s |
| # [Description of Random Situation](#contents) | |||
We set the random seed to 1 in train.py and eval.py via `set_seed`.
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
# Unless required by applicable law or agreed to in writing, software
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation for CTPN""" | |||
| import os | |||
| import argparse | |||
| import time | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.common import set_seed | |||
| from src.ctpn import CTPN | |||
| from src.config import config | |||
| from src.dataset import create_ctpn_dataset | |||
| from src.text_connector.detector import detect | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description="CTPN evaluation") | |||
| parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.") | |||
| parser.add_argument("--image_path", type=str, default="", help="Image path.") | |||
| parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''): | |||
| """ctpn infer.""" | |||
| print("ckpt path is {}".format(ckpt_path)) | |||
| ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False) | |||
| config.batch_size = config.test_batch_size | |||
| total = ds.get_dataset_size() | |||
| print("*************total dataset size is {}".format(total)) | |||
| net = CTPN(config, is_training=False) | |||
| param_dict = load_checkpoint(ckpt_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| eval_iter = 0 | |||
| print("\n========================================\n") | |||
| print("Processing, please wait a moment.") | |||
| img_basenames = [] | |||
| output_dir = os.path.join(os.getcwd(), "submit") | |||
| if not os.path.exists(output_dir): | |||
| os.mkdir(output_dir) | |||
| for file in os.listdir(img_dir): | |||
| img_basenames.append(os.path.basename(file)) | |||
| for data in ds.create_dict_iterator(): | |||
| img_data = data['image'] | |||
| img_metas = data['image_shape'] | |||
| gt_bboxes = data['box'] | |||
| gt_labels = data['label'] | |||
| gt_num = data['valid_num'] | |||
| start = time.time() | |||
| # run net | |||
| output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num) | |||
| gt_bboxes = gt_bboxes.asnumpy() | |||
| gt_labels = gt_labels.asnumpy() | |||
| gt_num = gt_num.asnumpy().astype(bool) | |||
| end = time.time() | |||
| proposal = output[0] | |||
| proposal_mask = output[1] | |||
| print("start to draw pic") | |||
| for j in range(config.test_batch_size): | |||
| img = img_basenames[config.test_batch_size * eval_iter + j] | |||
| all_box_tmp = proposal[j].asnumpy() | |||
| all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1) | |||
| using_boxes_mask = all_box_tmp * all_mask_tmp | |||
| textsegs = using_boxes_mask[:, 0:4].astype(np.float32) | |||
| scores = using_boxes_mask[:, 4].astype(np.float32) | |||
| shape = img_metas.asnumpy()[0][:2].astype(np.int32) | |||
| bboxes = detect(textsegs, scores[:, np.newaxis], shape) | |||
| from PIL import Image, ImageDraw | |||
| im = Image.open(img_dir + '/' + img) | |||
| draw = ImageDraw.Draw(im) | |||
| image_h = img_metas.asnumpy()[j][2] | |||
| image_w = img_metas.asnumpy()[j][3] | |||
| gt_boxs = gt_bboxes[j][gt_num[j], :] | |||
| for gt_box in gt_boxs: | |||
| gt_x1 = gt_box[0] / image_w | |||
| gt_y1 = gt_box[1] / image_h | |||
| gt_x2 = gt_box[2] / image_w | |||
| gt_y2 = gt_box[3] / image_h | |||
| draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\ | |||
| fill='green', width=2) | |||
| file_name = "res_" + img.replace("jpg", "txt") | |||
| output_file = os.path.join(output_dir, file_name) | |||
| f = open(output_file, 'w') | |||
| for bbox in bboxes: | |||
| x1 = bbox[0] / image_w | |||
| y1 = bbox[1] / image_h | |||
| x2 = bbox[2] / image_w | |||
| y2 = bbox[3] / image_h | |||
| draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2) | |||
| str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2)) | |||
| f.write(str_tmp) | |||
| f.write("\n") | |||
| f.close() | |||
| im.save(img) | |||
| percent = round(eval_iter / total * 100, 2) | |||
| eval_iter = eval_iter + 1 | |||
| print("Iter {} cost time {}".format(eval_iter, end - start)) | |||
| print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r') | |||
| if __name__ == '__main__': | |||
| ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path) | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| for submit_file in "submit"*.zip | |||
| do | |||
| echo "eval result for ${submit_file}" | |||
python script.py -g=gt.zip -s=${submit_file} -o=./
| echo -e ".\n" | |||
| done | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# -ne 3 ] | |||
| then | |||
| echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]" | |||
| exit 1 | |||
| fi | |||
| get_real_path(){ | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| PATH1=$(get_real_path $1) | |||
| echo $PATH1 | |||
| if [ ! -f $PATH1 ] | |||
| then | |||
| echo "error: RANK_TABLE_FILE=$PATH1 is not a file" | |||
| exit 1 | |||
| fi | |||
| TASK_TYPE=$2 | |||
| PATH2=$(get_real_path $3) | |||
| echo $PATH2 | |||
| if [ ! -f $PATH2 ] | |||
| then | |||
| echo "error: PRETRAINED_PATH=$PATH2 is not a file" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=8 | |||
| export RANK_SIZE=8 | |||
| export RANK_TABLE_FILE=$PATH1 | |||
| for((i=0; i<${DEVICE_NUM}; i++)) | |||
| do | |||
| export DEVICE_ID=$i | |||
| export RANK_ID=$i | |||
| rm -rf ./train_parallel$i | |||
| mkdir ./train_parallel$i | |||
| cp ../*.py ./train_parallel$i | |||
| cp *.sh ./train_parallel$i | |||
| cp -r ../src ./train_parallel$i | |||
| cd ./train_parallel$i || exit | |||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log & | |||
| cd .. | |||
| done | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# != 3 ] | |||
| then | |||
| echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]" | |||
| exit 1 | |||
| fi | |||
| get_real_path(){ | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| IMAGE_PATH=$(get_real_path $1) | |||
| DATASET_PATH=$(get_real_path $2) | |||
| CHECKPOINT_PATH=$(get_real_path $3) | |||
| echo $IMAGE_PATH | |||
| echo $DATASET_PATH | |||
| echo $CHECKPOINT_PATH | |||
| if [ ! -d $IMAGE_PATH ] | |||
| then | |||
echo "error: IMAGE_PATH=$IMAGE_PATH is not a directory"
| exit 1 | |||
| fi | |||
| if [ ! -f $DATASET_PATH ] | |||
| then | |||
echo "error: DATASET_PATH=$DATASET_PATH is not a file"
| exit 1 | |||
| fi | |||
| if [ ! -d $CHECKPOINT_PATH ] | |||
| then | |||
| echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=1 | |||
| export RANK_SIZE=$DEVICE_NUM | |||
| export DEVICE_ID=0 | |||
| export RANK_ID=0 | |||
| for file in "${CHECKPOINT_PATH}"/*.ckpt | |||
| do | |||
| if [ -d "eval" ]; | |||
| then | |||
| rm -rf ./eval | |||
| fi | |||
| mkdir ./eval | |||
| cp ../*.py ./eval | |||
| cp *.sh ./eval | |||
| cp -r ../src ./eval | |||
| cd ./eval | |||
| env > env.log | |||
| CHECKPOINT_FILE_PATH=$file | |||
| echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" | |||
| python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log | |||
| echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" | |||
| cd ./submit | |||
| file_base_name=$(basename $file) | |||
| zip -r ../../submit_${file_base_name%.*}.zip *.txt | |||
| cd ../../ | |||
| done | |||
| #!/bin/bash | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# -ne 2 ] | |||
| then | |||
echo "Usage: sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]"
| exit 1 | |||
| fi | |||
| get_real_path(){ | |||
| if [ "${1:0:1}" == "/" ]; then | |||
| echo "$1" | |||
| else | |||
| echo "$(realpath -m $PWD/$1)" | |||
| fi | |||
| } | |||
| TASK_TYPE=$1 | |||
| PRETRAINED_PATH=$(get_real_path $2) | |||
| echo $PRETRAINED_PATH | |||
| if [ ! -f $PRETRAINED_PATH ] | |||
| then | |||
| echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file" | |||
| exit 1 | |||
| fi | |||
| ulimit -u unlimited | |||
| export DEVICE_NUM=1 | |||
| export DEVICE_ID=0 | |||
| export RANK_ID=0 | |||
| export RANK_SIZE=1 | |||
| rm -rf ./train | |||
| mkdir ./train | |||
| cp ../*.py ./train | |||
| cp *.sh ./train | |||
| cp -r ../src ./train | |||
| cd ./train || exit | |||
| echo "start training for device $DEVICE_ID" | |||
| env > env.log | |||
| python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log & | |||
| cd .. | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| class BoundingBoxDecode(nn.Cell): | |||
| """ | |||
BoundingBox Decoder.
Returns:
    pred_box(Tensor): decoded bounding boxes.
| """ | |||
| def __init__(self): | |||
| super(BoundingBoxDecode, self).__init__() | |||
| self.split = P.Split(axis=1, output_num=4) | |||
| self.ones = 1.0 | |||
| self.half = 0.5 | |||
| self.log = P.Log() | |||
| self.exp = P.Exp() | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, bboxes, deltas): | |||
| """ | |||
bboxes(Tensor): bounding boxes.
deltas(Tensor): deltas between bounding boxes and anchors.
| """ | |||
| x1, y1, x2, y2 = self.split(bboxes) | |||
| width = x2 - x1 + self.ones | |||
| height = y2 - y1 + self.ones | |||
| ctr_x = x1 + self.half * width | |||
| ctr_y = y1 + self.half * height | |||
| _, dy, _, dh = self.split(deltas) | |||
| pred_ctr_x = ctr_x | |||
| pred_ctr_y = dy * height + ctr_y | |||
| pred_w = width | |||
| pred_h = self.exp(dh) * height | |||
| x1 = pred_ctr_x - self.half * pred_w | |||
| y1 = pred_ctr_y - self.half * pred_h | |||
| x2 = pred_ctr_x + self.half * pred_w | |||
| y2 = pred_ctr_y + self.half * pred_h | |||
| pred_box = self.concat((x1, y1, x2, y2)) | |||
| return pred_box | |||
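# The decode above only regresses the vertical center (dy) and height (dh):
# CTPN anchors share a fixed x-center and width. A NumPy sanity check of the
# same math (a hedged sketch added for illustration, not part of the network):
if __name__ == "__main__":
    import numpy as np

    def _np_decode(bboxes, deltas):
        x1, y1, x2, y2 = np.split(bboxes.astype(np.float64), 4, axis=1)
        width = x2 - x1 + 1.0
        height = y2 - y1 + 1.0
        ctr_x = x1 + 0.5 * width
        ctr_y = y1 + 0.5 * height
        _, dy, _, dh = np.split(deltas.astype(np.float64), 4, axis=1)
        pred_ctr_y = dy * height + ctr_y  # only the y-center moves
        pred_h = np.exp(dh) * height      # only the height is scaled
        return np.concatenate([ctr_x - 0.5 * width, pred_ctr_y - 0.5 * pred_h,
                               ctr_x + 0.5 * width, pred_ctr_y + 0.5 * pred_h], axis=1)

    # zero deltas keep the box where it is: a 16x16 box spans [0, 16] per axis
    out = _np_decode(np.array([[0.0, 0.0, 15.0, 15.0]]), np.zeros((1, 4)))
    assert np.allclose(out, [[0.0, 0.0, 16.0, 16.0]])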
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| class BoundingBoxEncode(nn.Cell): | |||
| """ | |||
BoundingBox Encoder.
Returns:
    deltas(Tensor): encoded deltas between anchor boxes and ground-truth boxes.
| """ | |||
| def __init__(self): | |||
| super(BoundingBoxEncode, self).__init__() | |||
| self.split = P.Split(axis=1, output_num=4) | |||
| self.ones = 1.0 | |||
| self.half = 0.5 | |||
| self.log = P.Log() | |||
| self.concat = P.Concat(axis=1) | |||
| def construct(self, anchor_box, gt_box): | |||
| """ | |||
anchor_box(Tensor): anchor boxes.
gt_box(Tensor): ground truth boxes.
| """ | |||
| x1, y1, x2, y2 = self.split(anchor_box) | |||
| width = x2 - x1 + self.ones | |||
| height = y2 - y1 + self.ones | |||
| ctr_x = x1 + self.half * width | |||
| ctr_y = y1 + self.half * height | |||
| gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box) | |||
| gt_width = gt_x2 - gt_x1 + self.ones | |||
| gt_height = gt_y2 - gt_y1 + self.ones | |||
| ctr_gt_x = gt_x1 + self.half * gt_width | |||
| ctr_gt_y = gt_y1 + self.half * gt_height | |||
| target_dx = (ctr_gt_x - ctr_x) / width | |||
| target_dy = (ctr_gt_y - ctr_y) / height | |||
| dw = gt_width / width | |||
| dh = gt_height / height | |||
| target_dw = self.log(dw) | |||
| target_dh = self.log(dh) | |||
| deltas = self.concat((target_dx, target_dy, target_dw, target_dh)) | |||
| return deltas | |||
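# Encoding is the inverse of BoundingBoxDecode. A quick NumPy sanity check
# (a hedged sketch added for illustration, not part of the network): encoding
# a box against itself must yield all-zero deltas, since log(1) = 0.
if __name__ == "__main__":
    import numpy as np

    def _np_encode(anchor, gt):
        ax1, ay1, ax2, ay2 = np.split(anchor.astype(np.float64), 4, axis=1)
        gx1, gy1, gx2, gy2 = np.split(gt.astype(np.float64), 4, axis=1)
        aw, ah = ax2 - ax1 + 1.0, ay2 - ay1 + 1.0
        gw, gh = gx2 - gx1 + 1.0, gy2 - gy1 + 1.0
        dx = ((gx1 + 0.5 * gw) - (ax1 + 0.5 * aw)) / aw
        dy = ((gy1 + 0.5 * gh) - (ay1 + 0.5 * ah)) / ah
        return np.concatenate([dx, dy, np.log(gw / aw), np.log(gh / ah)], axis=1)

    box = np.array([[4.0, 8.0, 19.0, 39.0]])
    assert np.allclose(_np_encode(box, box), 0.0)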
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FasterRcnn anchor generator.""" | |||
| import numpy as np | |||
| class AnchorGenerator(): | |||
| """Anchor generator for FasterRcnn.""" | |||
| def __init__(self, config): | |||
| """Anchor generator init method.""" | |||
| self.base_size = config.anchor_base | |||
| self.num_anchor = config.num_anchors | |||
| self.anchor_height = config.anchor_height | |||
| self.anchor_width = config.anchor_width | |||
| self.size = self.gen_anchor_size() | |||
| self.base_anchors = self.gen_base_anchors() | |||
| def gen_base_anchors(self): | |||
| """Generate a single anchor.""" | |||
| base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32) | |||
| anchors = np.zeros((len(self.size), 4), np.int32) | |||
| index = 0 | |||
| for h, w in self.size: | |||
| anchors[index] = self.scale_anchor(base_anchor, h, w) | |||
| index += 1 | |||
| return anchors | |||
| def gen_anchor_size(self): | |||
| """Generate a list of anchor size""" | |||
| size = [] | |||
| for width in self.anchor_width: | |||
| for height in self.anchor_height: | |||
| size.append((height, width)) | |||
| return size | |||
| def scale_anchor(self, anchor, h, w): | |||
| x_ctr = (anchor[0] + anchor[2]) * 0.5 | |||
| y_ctr = (anchor[1] + anchor[3]) * 0.5 | |||
| scaled_anchor = anchor.copy() | |||
| scaled_anchor[0] = x_ctr - w / 2 # xmin | |||
| scaled_anchor[2] = x_ctr + w / 2 # xmax | |||
| scaled_anchor[1] = y_ctr - h / 2 # ymin | |||
| scaled_anchor[3] = y_ctr + h / 2 # ymax | |||
| return scaled_anchor | |||
| def _meshgrid(self, x, y): | |||
| """Generate grid.""" | |||
| xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1) | |||
| yy = np.repeat(y, len(x)) | |||
| return xx, yy | |||
| def grid_anchors(self, featmap_size, stride=16): | |||
| """Generate anchor list.""" | |||
| base_anchors = self.base_anchors | |||
| feat_h, feat_w = featmap_size | |||
| shift_x = np.arange(0, feat_w) * stride | |||
| shift_y = np.arange(0, feat_h) * stride | |||
| shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) | |||
| shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) | |||
| shifts = shifts.astype(base_anchors.dtype) | |||
| all_anchors = base_anchors[None, :, :] + shifts[:, None, :] | |||
| all_anchors = all_anchors.reshape(-1, 4) | |||
| return all_anchors | |||
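Putting `gen_base_anchors` and `grid_anchors` together: one base anchor per `(h, w)` pair is re-centred on a `base_size` box, then broadcast across every feature-map cell at `stride` pixels. A plain NumPy reproduction (the heights/widths below are illustrative CTPN-style values, not necessarily this repo's config):

```python
import numpy as np

base_anchor = np.array([0, 0, 15, 15])      # assumed 16x16 base box
heights, widths = [11, 16, 23], [16]        # CTPN-style fixed-width anchors (illustrative)

# gen_base_anchors: one anchor per (h, w), centred on the base box centre.
x_ctr = (base_anchor[0] + base_anchor[2]) * 0.5
y_ctr = (base_anchor[1] + base_anchor[3]) * 0.5
base_anchors = np.array([[x_ctr - w / 2, y_ctr - h / 2,
                          x_ctr + w / 2, y_ctr + h / 2]
                         for w in widths for h in heights])

# grid_anchors: shift every base anchor to every feature-map location.
feat_h, feat_w, stride = 2, 3, 16
shift_x = np.tile(np.arange(feat_w) * stride, feat_h)    # matches _meshgrid's xx
shift_y = np.repeat(np.arange(feat_h) * stride, feat_w)  # matches _meshgrid's yy
shifts = np.stack([shift_x, shift_y, shift_x, shift_y], axis=-1)
all_anchors = (base_anchors[None, :, :] + shifts[:, None, :]).reshape(-1, 4)
```

The result has `feat_h * feat_w * num_anchors` rows, which is exactly the `num_bboxes` the sampler below expects.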
| @@ -0,0 +1,152 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FasterRcnn positive and negative sample screening for RPN.""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from src.CTPN.BoundingBoxEncode import BoundingBoxEncode | |||
| class BboxAssignSample(nn.Cell): | |||
| """ | |||
| Bbox assigner and sampler definition. | |||
| Args: | |||
| config (dict): Config. | |||
| batch_size (int): Batch size. | |||
| num_bboxes (int): Number of anchor boxes. | |||
| add_gt_as_proposals (bool): Whether to add gt bboxes as proposals. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| bbox_targets: bbox location, (batch_size, num_bboxes, 4) | |||
| bbox_weights: bbox weights, (batch_size, num_bboxes, 1) | |||
| labels: label for each bbox, (batch_size, num_bboxes, 1) | |||
| label_weights: label weight for each bbox, (batch_size, num_bboxes, 1) | |||
| Examples: | |||
| BboxAssignSample(config, 2, 1024, True) | |||
| """ | |||
| def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): | |||
| super(BboxAssignSample, self).__init__() | |||
| cfg = config | |||
| self.batch_size = batch_size | |||
| self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16) | |||
| self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16) | |||
| self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16) | |||
| self.zero_thr = Tensor(0.0, mstype.float16) | |||
| self.num_bboxes = num_bboxes | |||
| self.num_gts = cfg.num_gts | |||
| self.num_expected_pos = cfg.num_expected_pos | |||
| self.num_expected_neg = cfg.num_expected_neg | |||
| self.add_gt_as_proposals = add_gt_as_proposals | |||
| if self.add_gt_as_proposals: | |||
| self.label_inds = Tensor(np.arange(1, self.num_gts + 1)) | |||
| self.concat = P.Concat(axis=0) | |||
| self.max_gt = P.ArgMaxWithValue(axis=0) | |||
| self.max_anchor = P.ArgMaxWithValue(axis=1) | |||
| self.sum_inds = P.ReduceSum() | |||
| self.iou = P.IOU() | |||
| self.greaterequal = P.GreaterEqual() | |||
| self.greater = P.Greater() | |||
| self.select = P.Select() | |||
| self.gatherND = P.GatherNd() | |||
| self.squeeze = P.Squeeze() | |||
| self.cast = P.Cast() | |||
| self.logicaland = P.LogicalAnd() | |||
| self.less = P.Less() | |||
| self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) | |||
| self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | |||
| self.reshape = P.Reshape() | |||
| self.equal = P.Equal() | |||
| self.bounding_box_encode = BoundingBoxEncode() | |||
| self.scatterNdUpdate = P.ScatterNdUpdate() | |||
| self.scatterNd = P.ScatterNd() | |||
| self.logicalnot = P.LogicalNot() | |||
| self.tile = P.Tile() | |||
| self.zeros_like = P.ZerosLike() | |||
| self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
| self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) | |||
| self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) | |||
| self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) | |||
| self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) | |||
| self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool_)) | |||
| self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) | |||
| self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) | |||
| self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) | |||
| self.print = P.Print() | |||
| def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): | |||
| gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ | |||
| (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one) | |||
| bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ | |||
| (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two) | |||
| overlaps = self.iou(bboxes, gt_bboxes_i) | |||
| max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) | |||
| _, max_overlaps_w_ac = self.max_anchor(overlaps) | |||
| neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \ | |||
| self.less(max_overlaps_w_gt, self.neg_iou_thr)) | |||
| assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) | |||
| pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr) | |||
| assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ | |||
| max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) | |||
| assigned_gt_inds4 = assigned_gt_inds3 | |||
| for j in range(self.num_gts): | |||
| max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] | |||
| overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::]) | |||
| pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \ | |||
| self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j)) | |||
| assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4) | |||
| assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores) | |||
| pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) | |||
| pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) | |||
| pos_check_valid = self.sum_inds(pos_check_valid, -1) | |||
| valid_pos_index = self.less(self.range_pos_size, pos_check_valid) | |||
| pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) | |||
| pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones | |||
| pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32) | |||
| pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1)) | |||
| neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) | |||
| num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16) | |||
| num_pos = self.sum_inds(num_pos, -1) | |||
| unvalid_pos_index = self.less(self.range_pos_size, num_pos) | |||
| valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) | |||
| pos_bboxes_ = self.gatherND(bboxes, pos_index) | |||
| pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) | |||
| pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) | |||
| pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) | |||
| valid_pos_index = self.cast(valid_pos_index, mstype.int32) | |||
| valid_neg_index = self.cast(valid_neg_index, mstype.int32) | |||
| bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4)) | |||
| bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,)) | |||
| labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,)) | |||
| total_index = self.concat((pos_index, neg_index)) | |||
| total_valid_index = self.concat((valid_pos_index, valid_neg_index)) | |||
| label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,)) | |||
| return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \ | |||
| labels_total, self.cast(label_weights_total, mstype.bool_) | |||
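The assignment in `construct` boils down to three IoU rules: low-overlap anchors become negatives, high-overlap anchors become positives for their best ground truth, and every ground truth claims its best-overlapping anchor. A toy NumPy version (threshold values are hypothetical defaults, not this repo's config):

```python
import numpy as np

def assign_labels(overlaps, pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3):
    """overlaps: (num_gts, num_anchors) IoU matrix.
    Returns per-anchor assignment: -1 ignore, 0 negative, k > 0 -> gt index k - 1."""
    num_gts, num_anchors = overlaps.shape
    assigned = -np.ones(num_anchors, dtype=np.int32)
    max_iou_gt_idx = overlaps.argmax(axis=0)   # best gt per anchor
    max_iou = overlaps.max(axis=0)
    # Rule 1: anchors whose best IoU is below neg_iou_thr are negatives.
    assigned[max_iou < neg_iou_thr] = 0
    # Rule 2: anchors whose best IoU reaches pos_iou_thr are positives.
    pos = max_iou >= pos_iou_thr
    assigned[pos] = max_iou_gt_idx[pos] + 1
    # Rule 3: each gt claims its best-overlapping anchor(s), as in the j-loop above.
    for j in range(num_gts):
        best = overlaps[j].max()
        if best >= min_pos_iou:
            assigned[overlaps[j] == best] = j + 1
    return assigned
```

The real cell then subsamples the positives and negatives with `RandomChoiceWithMask` to fixed expected counts; this sketch only covers the labelling step.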
| @@ -0,0 +1,190 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """FasterRcnn proposal generator.""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from src.CTPN.BoundingBoxDecode import BoundingBoxDecode | |||
| class Proposal(nn.Cell): | |||
| """ | |||
| Proposal subnet. | |||
| Args: | |||
| config (dict): Config. | |||
| batch_size (int): Batch size. | |||
| num_classes (int) - Class number. | |||
| use_sigmoid_cls (bool) - Select sigmoid or softmax function. | |||
| target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0). | |||
| target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0). | |||
| Returns: | |||
| Tuple, tuple of output tensor,(proposal, mask). | |||
| Examples: | |||
| Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \ | |||
| target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)) | |||
| """ | |||
| def __init__(self, | |||
| config, | |||
| batch_size, | |||
| num_classes, | |||
| use_sigmoid_cls, | |||
| target_means=(.0, .0, .0, .0), | |||
| target_stds=(1.0, 1.0, 1.0, 1.0) | |||
| ): | |||
| super(Proposal, self).__init__() | |||
| cfg = config | |||
| self.batch_size = batch_size | |||
| self.num_classes = num_classes | |||
| self.target_means = target_means | |||
| self.target_stds = target_stds | |||
| self.use_sigmoid_cls = config.use_sigmoid_cls | |||
| if self.use_sigmoid_cls: | |||
| self.cls_out_channels = 1 | |||
| self.activation = P.Sigmoid() | |||
| self.reshape_shape = (-1, 1) | |||
| else: | |||
| self.cls_out_channels = num_classes | |||
| self.activation = P.Softmax(axis=1) | |||
| self.reshape_shape = (-1, 2) | |||
| if self.cls_out_channels <= 0: | |||
| raise ValueError('num_classes={} is too small'.format(num_classes)) | |||
| self.num_pre = cfg.rpn_proposal_nms_pre | |||
| self.min_box_size = cfg.rpn_proposal_min_bbox_size | |||
| self.nms_thr = cfg.rpn_proposal_nms_thr | |||
| self.nms_post = cfg.rpn_proposal_nms_post | |||
| self.nms_across_levels = cfg.rpn_proposal_nms_across_levels | |||
| self.max_num = cfg.rpn_proposal_max_num | |||
| # Op Define | |||
| self.squeeze = P.Squeeze() | |||
| self.reshape = P.Reshape() | |||
| self.cast = P.Cast() | |||
| self.feature_shapes = cfg.feature_shapes | |||
| self.transpose_shape = (1, 2, 0) | |||
| self.decode = BoundingBoxDecode() | |||
| self.nms = P.NMSWithMask(self.nms_thr) | |||
| self.concat_axis0 = P.Concat(axis=0) | |||
| self.concat_axis1 = P.Concat(axis=1) | |||
| self.split = P.Split(axis=1, output_num=5) | |||
| self.min = P.Minimum() | |||
| self.gatherND = P.GatherNd() | |||
| self.slice = P.Slice() | |||
| self.select = P.Select() | |||
| self.greater = P.Greater() | |||
| self.transpose = P.Transpose() | |||
| self.tile = P.Tile() | |||
| self.set_train_local(config, training=True) | |||
| self.multi_10 = Tensor(10.0, mstype.float16) | |||
| def set_train_local(self, config, training=False): | |||
| """Set training flag.""" | |||
| self.training_local = training | |||
| cfg = config | |||
| self.topK_stage1 = () | |||
| self.topK_shape = () | |||
| total_max_topk_input = 0 | |||
| if not self.training_local: | |||
| self.num_pre = cfg.rpn_nms_pre | |||
| self.min_box_size = cfg.rpn_min_bbox_min_size | |||
| self.nms_thr = cfg.rpn_nms_thr | |||
| self.nms_post = cfg.rpn_nms_post | |||
| self.max_num = cfg.rpn_max_num | |||
| k_num = self.num_pre | |||
| total_max_topk_input = k_num | |||
| self.topK_stage1 = k_num | |||
| self.topK_shape = (k_num, 1) | |||
| self.topKv2 = P.TopK(sorted=True) | |||
| self.topK_shape_stage2 = (self.max_num, 1) | |||
| self.min_float_num = -65536.0 | |||
| self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) | |||
| self.shape = P.Shape() | |||
| self.print = P.Print() | |||
| def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): | |||
| proposals_tuple = () | |||
| masks_tuple = () | |||
| for img_id in range(self.batch_size): | |||
| rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::]) | |||
| rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::]) | |||
| proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list) | |||
| proposals_tuple += (proposals,) | |||
| masks_tuple += (masks,) | |||
| return proposals_tuple, masks_tuple | |||
| def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors): | |||
| """Get proposal boundingbox.""" | |||
| mlvl_proposals = () | |||
| mlvl_mask = () | |||
| rpn_cls_score = self.transpose(cls_scores, self.transpose_shape) | |||
| rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape) | |||
| anchors = mlvl_anchors | |||
| # (H, W, A*2) | |||
| rpn_cls_score_shape = self.shape(rpn_cls_score) | |||
| rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \ | |||
| rpn_cls_score_shape[1], -1, self.cls_out_channels)) | |||
| rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) | |||
| rpn_cls_score = self.activation(rpn_cls_score) | |||
| if self.use_sigmoid_cls: | |||
| rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16) | |||
| else: | |||
| rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16) | |||
| rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) | |||
| scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre) | |||
| topk_inds = self.reshape(topk_inds, self.topK_shape) | |||
| bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) | |||
| anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) | |||
| proposals_decode = self.decode(anchors_sorted, bboxes_sorted) | |||
| proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape))) | |||
| proposals, _, mask_valid = self.nms(proposals_decode) | |||
| mlvl_proposals = mlvl_proposals + (proposals,) | |||
| mlvl_mask = mlvl_mask + (mask_valid,) | |||
| proposals = self.concat_axis0(mlvl_proposals) | |||
| masks = self.concat_axis0(mlvl_mask) | |||
| _, _, _, _, scores = self.split(proposals) | |||
| scores = self.squeeze(scores) | |||
| topk_mask = self.cast(self.topK_mask, mstype.float16) | |||
| scores_using = self.select(masks, scores, topk_mask) | |||
| _, topk_inds = self.topKv2(scores_using, self.max_num) | |||
| topk_inds = self.reshape(topk_inds, self.topK_shape_stage2) | |||
| proposals = self.gatherND(proposals, topk_inds) | |||
| masks = self.gatherND(masks, topk_inds) | |||
| return proposals, masks | |||
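`get_bboxes_single` relies on the fused `P.NMSWithMask` op. For reference, a plain NumPy non-maximum suppression implementing the same greedy rule (no `+ 1` area convention assumed here):

```python
import numpy as np

def nms(boxes, scores, iou_thr=0.7):
    """Greedy NMS over (N, 4) corner-format boxes; returns kept indices."""
    order = scores.argsort()[::-1]      # process highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of the kept box with all remaining candidates.
        xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        inter = np.maximum(xx2 - xx1, 0) * np.maximum(yy2 - yy1, 0)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        areas = (boxes[order[1:], 2] - boxes[order[1:], 0]) * \
                (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (area_i + areas - inter)
        # Drop candidates overlapping the kept box above the threshold.
        order = order[1:][iou <= iou_thr]
    return keep
```

Suppressed proposals are not removed by the fused op; it instead returns a validity mask, which is why the code above carries `masks` alongside `proposals` into the final top-k selection.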
| @@ -0,0 +1,228 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RPN for fasterRCNN""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from mindspore.ops import functional as F | |||
| from src.CTPN.bbox_assign_sample import BboxAssignSample | |||
| class RpnRegClsBlock(nn.Cell): | |||
| """ | |||
| RPN regression and classification block for the RPN layer. | |||
| Args: | |||
| config(EasyDict) - Network construction config. | |||
| in_channels (int) - Input channels of shared convolution. | |||
| feat_channels (int) - Output channels of shared convolution. | |||
| num_anchors (int) - The anchor number. | |||
| cls_out_channels (int) - Output channels of classification convolution. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, | |||
| config, | |||
| in_channels, | |||
| feat_channels, | |||
| num_anchors, | |||
| cls_out_channels): | |||
| super(RpnRegClsBlock, self).__init__() | |||
| self.reshape = P.Reshape() | |||
| self.shape = (-1, 2*config.hidden_size) | |||
| self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16) | |||
| self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16) | |||
| self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16) | |||
| self.shape1 = (config.num_step, config.rnn_batch_size, -1) | |||
| self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step) | |||
| self.transpose = P.Transpose() | |||
| self.print = P.Print() | |||
| self.dropout = nn.Dropout(0.8) | |||
| def construct(self, x): | |||
| x = self.reshape(x, self.shape) | |||
| x = self.lstm_fc(x) | |||
| x1 = self.rpn_cls(x) | |||
| x1 = self.reshape(x1, self.shape1) | |||
| x1 = self.transpose(x1, (2, 1, 0)) | |||
| x1 = self.reshape(x1, self.shape2) | |||
| x1 = self.transpose(x1, (1, 0, 2, 3)) | |||
| x2 = self.rpn_reg(x) | |||
| x2 = self.reshape(x2, self.shape1) | |||
| x2 = self.transpose(x2, (2, 1, 0)) | |||
| x2 = self.reshape(x2, self.shape2) | |||
| x2 = self.transpose(x2, (1, 0, 2, 3)) | |||
| return x1, x2 | |||
| class RPN(nn.Cell): | |||
| """ | |||
| Region proposal network. | |||
| Args: | |||
| config (dict) - Config. | |||
| batch_size (int) - Batch size. | |||
| in_channels (int) - Input channels of shared convolution. | |||
| feat_channels (int) - Output channels of shared convolution. | |||
| num_anchors (int) - The anchor number. | |||
| cls_out_channels (int) - Output channels of classification convolution. | |||
| Returns: | |||
| Tuple, tuple of output tensor. | |||
| Examples: | |||
| RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, | |||
| num_anchors=3, cls_out_channels=512) | |||
| """ | |||
| def __init__(self, | |||
| config, | |||
| batch_size, | |||
| in_channels, | |||
| feat_channels, | |||
| num_anchors, | |||
| cls_out_channels): | |||
| super(RPN, self).__init__() | |||
| cfg_rpn = config | |||
| self.cfg = config | |||
| self.num_bboxes = cfg_rpn.num_bboxes | |||
| self.feature_anchor_shape = cfg_rpn.feature_shapes | |||
| self.feature_anchor_shape = self.feature_anchor_shape[0] * \ | |||
| self.feature_anchor_shape[1] * num_anchors * batch_size | |||
| self.num_anchors = num_anchors | |||
| self.batch_size = batch_size | |||
| self.test_batch_size = cfg_rpn.test_batch_size | |||
| self.num_layers = 1 | |||
| self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) | |||
| self.use_sigmoid_cls = config.use_sigmoid_cls | |||
| if config.use_sigmoid_cls: | |||
| self.reshape_shape_cls = (-1,) | |||
| self.loss_cls = P.SigmoidCrossEntropyWithLogits() | |||
| cls_out_channels = 1 | |||
| else: | |||
| self.reshape_shape_cls = (-1, cls_out_channels) | |||
| self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none") | |||
| self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\ | |||
| num_anchors, cls_out_channels) | |||
| self.transpose = P.Transpose() | |||
| self.reshape = P.Reshape() | |||
| self.concat = P.Concat(axis=0) | |||
| self.fill = P.Fill() | |||
| self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) | |||
| self.trans_shape = (0, 2, 3, 1) | |||
| self.reshape_shape_reg = (-1, 4) | |||
| self.softmax = nn.Softmax() | |||
| self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) | |||
| self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) | |||
| self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) | |||
| self.num_bboxes = cfg_rpn.num_bboxes | |||
| self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) | |||
| self.CheckValid = P.CheckValid() | |||
| self.sum_loss = P.ReduceSum() | |||
| self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0) | |||
| self.squeeze = P.Squeeze() | |||
| self.cast = P.Cast() | |||
| self.tile = P.Tile() | |||
| self.zeros_like = P.ZerosLike() | |||
| self.loss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
| self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
| self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) | |||
| self.print = P.Print() | |||
| def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): | |||
| """ | |||
| Make the RPN layer for the region proposal network. | |||
| Args: | |||
| num_layers (int) - layer num. | |||
| in_channels (int) - Input channels of shared convolution. | |||
| feat_channels (int) - Output channels of shared convolution. | |||
| num_anchors (int) - The anchor number. | |||
| cls_out_channels (int) - Output channels of classification convolution. | |||
| Returns: | |||
| RpnRegClsBlock, the RPN regression and classification block. | |||
| """ | |||
| rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels) | |||
| return rpn_layer | |||
| def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): | |||
| ''' | |||
| inputs(Tensor): Inputs tensor from lstm. | |||
| img_metas(Tensor): Image shape. | |||
| anchor_list(Tensor): Total anchor list. | |||
| gt_bboxes(Tensor): Ground truth bounding boxes. | |||
| gt_labels(Tensor): Ground truth labels. | |||
| gt_valids(Tensor): Whether ground truth is valid. | |||
| ''' | |||
| rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs) | |||
| rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape) | |||
| rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls) | |||
| rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape) | |||
| rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg) | |||
| output = () | |||
| bbox_targets = () | |||
| bbox_weights = () | |||
| labels = () | |||
| label_weights = () | |||
| if self.training: | |||
| for i in range(self.batch_size): | |||
| valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\ | |||
| mstype.int32) | |||
| gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) | |||
| gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) | |||
| gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) | |||
| bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, | |||
| gt_labels_i, | |||
| self.cast(valid_flag_list, | |||
| mstype.bool_), | |||
| anchor_list, gt_valids_i) | |||
| bbox_weight = self.cast(bbox_weight, mstype.float16) | |||
| label_weight = self.cast(label_weight, mstype.float16) | |||
| bbox_targets += (bbox_target,) | |||
| bbox_weights += (bbox_weight,) | |||
| labels += (label,) | |||
| label_weights += (label_weight,) | |||
| bbox_target_with_batchsize = self.concat(bbox_targets) | |||
| bbox_weight_with_batchsize = self.concat(bbox_weights) | |||
| label_with_batchsize = self.concat(labels) | |||
| label_weight_with_batchsize = self.concat(label_weights) | |||
| bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) | |||
| bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) | |||
| label_ = F.stop_gradient(label_with_batchsize) | |||
| label_weight_ = F.stop_gradient(label_weight_with_batchsize) | |||
| rpn_cls_score = self.cast(rpn_cls_score, mstype.float32) | |||
| if self.use_sigmoid_cls: | |||
| label_ = self.cast(label_, mstype.float32) | |||
| loss_cls = self.loss_cls(rpn_cls_score, label_) | |||
| loss_cls = loss_cls * label_weight_ | |||
| loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total | |||
| rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32) | |||
| bbox_target_ = self.cast(bbox_target_, mstype.float32) | |||
| loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_) | |||
| bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4)) | |||
| loss_reg = loss_reg * bbox_weight_ | |||
| loss_reg = self.sum_loss(loss_reg, (1,)) | |||
| loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total | |||
| loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg | |||
| output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg) | |||
| else: | |||
| output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1) | |||
| return output | |||
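The regression term uses `P.SmoothL1Loss(beta=1.0/9.0)`: quadratic near zero, linear for large errors. Its piecewise definition, sketched in NumPy:

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0 / 9.0):
    """Element-wise smooth L1 loss, matching P.SmoothL1Loss(beta=1/9) above."""
    diff = np.abs(pred - target)
    # Quadratic branch inside |diff| < beta, linear branch outside;
    # the two branches meet with equal value and slope at |diff| == beta.
    return np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
```

As in the cell above, the per-element losses are masked by the bbox weights and normalised by `num_expected_total` before being combined with the classification loss.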
| @@ -0,0 +1,177 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.common.dtype as mstype | |||
| def _weight_variable(shape, factor=0.01): | |||
| """Weight initialize.""" | |||
| init_value = np.random.randn(*shape).astype(np.float32) * factor | |||
| return Tensor(init_value) | |||
| def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False): | |||
| """Batchnorm2D wrapper.""" | |||
| gamma_init = Tensor(np.ones(out_chls).astype(np.float32)) | |||
| beta_init = Tensor(np.zeros(out_chls).astype(np.float32)) | |||
| moving_mean_init = Tensor(np.zeros(out_chls).astype(np.float32)) | |||
| moving_var_init = Tensor(np.ones(out_chls).astype(np.float32)) | |||
| return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, | |||
| beta_init=beta_init, moving_mean_init=moving_mean_init, | |||
| moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) | |||
| def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True): | |||
| """Conv2D wrapper.""" | |||
| weights = 'ones' | |||
| layers = [] | |||
| conv = nn.Conv2d(in_channels, out_channels, | |||
| kernel_size=kernel_size, stride=stride, padding=padding, | |||
| pad_mode=pad_mode, weight_init=weights, has_bias=False) | |||
| if not weights_update: | |||
| conv.weight.requires_grad = False | |||
| layers += [conv] | |||
| layers += [_BatchNorm2dInit(out_channels)] | |||
| return nn.SequentialCell(layers) | |||
| def _fc(in_channels, out_channels): | |||
| '''Fully connected layer.''' | |||
| weight = _weight_variable((out_channels, in_channels)) | |||
| bias = _weight_variable((out_channels,)) | |||
| return nn.Dense(in_channels, out_channels, weight, bias) | |||
| class VGG16FeatureExtraction(nn.Cell): | |||
| def __init__(self, weights_update=False): | |||
| """ | |||
| VGG16 feature extraction | |||
| Args: | |||
| weights_update(bool): whether to update weights of the top two convolution blocks. Default: False. | |||
| """ | |||
| super(VGG16FeatureExtraction, self).__init__() | |||
| self.relu = nn.ReLU() | |||
| self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") | |||
| self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) | |||
| self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\ | |||
| padding=1, weights_update=weights_update) | |||
| self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\ | |||
| padding=1, weights_update=weights_update) | |||
| self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\ | |||
| padding=1, weights_update=weights_update) | |||
| self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\ | |||
| padding=1, weights_update=weights_update) | |||
| self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1) | |||
| self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1) | |||
| self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1) | |||
| self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1) | |||
| self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) | |||
| self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) | |||
| self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) | |||
| self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) | |||
| self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) | |||
| self.cast = P.Cast() | |||
    def construct(self, x):
        """
        :param x: shape=(B, 3, 224, 224)
        :return: feature map, shape=(B, 512, 14, 14)
        """
        x = self.cast(x, mstype.float32)
        x = self.conv1_1(x)
        x = self.relu(x)
        x = self.conv1_2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv2_1(x)
        x = self.relu(x)
        x = self.conv2_2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv3_1(x)
        x = self.relu(x)
        x = self.conv3_2(x)
        x = self.relu(x)
        x = self.conv3_3(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.relu(x)
        x = self.conv4_3(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv5_1(x)
        x = self.relu(x)
        x = self.conv5_2(x)
        x = self.relu(x)
        x = self.conv5_3(x)
        x = self.relu(x)
        return x
class VGG16Classfier(nn.Cell):
    def __init__(self):
        """VGG16 classifier structure"""
        super(VGG16Classfier, self).__init__()
        self.flatten = P.Flatten()
        self.relu = nn.ReLU()
        self.fc1 = _fc(in_channels=7*7*512, out_channels=4096)
        self.fc2 = _fc(in_channels=4096, out_channels=4096)
        self.batch_size = 32
        self.reshape = P.Reshape()

    def construct(self, x):
        """
        :param x: shape=(B, 512, 7, 7)
        :return: shape=(B, 4096)
        """
        # Use -1 so flattening works for any batch size, not only self.batch_size.
        x = self.reshape(x, (-1, 7*7*512))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x
class VGG16(nn.Cell):
    def __init__(self):
        """VGG16 network for training the backbone"""
        super(VGG16, self).__init__()
        self.feature_extraction = VGG16FeatureExtraction(weights_update=True)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = VGG16Classfier()
        self.fc3 = _fc(in_channels=4096, out_channels=1000)

    def construct(self, x):
        """
        :param x: shape=(B, 3, 224, 224)
        :return: logits, shape=(B, 1000)
        """
        feature_maps = self.feature_extraction(x)
        x = self.max_pool(feature_maps)
        x = self.classifier(x)
        x = self.fc3(x)
        return x
@@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict

pretrain_config = EasyDict({
    # LR
    "base_lr": 0.0009,
    "warmup_step": 30000,
    "warmup_ratio": 1/3.0,
    "total_epoch": 100,
})

finetune_config = EasyDict({
    # LR
    "base_lr": 0.0005,
    "warmup_step": 300,
    "warmup_ratio": 1/3.0,
    "total_epoch": 50,
})

# use for low case number
config = EasyDict({
    "img_width": 960,
    "img_height": 576,
    "keep_ratio": False,
    "flip_ratio": 0.0,
    "photo_ratio": 0.0,
    "expand_ratio": 1.0,

    # anchor
    "feature_shapes": (36, 60),
    "num_anchors": 14,
    "anchor_base": 16,
    "anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406],
    "anchor_width": [16],

    # rpn
    "rpn_in_channels": 256,
    "rpn_feat_channels": 512,
    "rpn_loss_cls_weight": 1.0,
    "rpn_loss_reg_weight": 3.0,
    "rpn_cls_out_channels": 2,

    # bbox_assign_sampler
    "neg_iou_thr": 0.5,
    "pos_iou_thr": 0.7,
    "min_pos_iou": 0.001,
    "num_bboxes": 30240,
    "num_gts": 256,
    "num_expected_neg": 512,
    "num_expected_pos": 256,

    # proposal
    "activate_num_classes": 2,
    "use_sigmoid_cls": False,

    # train proposal
    "rpn_proposal_nms_across_levels": False,
    "rpn_proposal_nms_pre": 2000,
    "rpn_proposal_nms_post": 1000,
    "rpn_proposal_max_num": 1000,
    "rpn_proposal_nms_thr": 0.7,
    "rpn_proposal_min_bbox_size": 8,

    # rnn structure
    "input_size": 512,
    "num_step": 60,
    "rnn_batch_size": 36,
    "hidden_size": 128,

    # training
    "warmup_mode": "linear",
    "batch_size": 1,
    "momentum": 0.9,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 10,
    "keep_checkpoint_max": 5,
    "save_checkpoint_path": "./",
    "use_dropout": False,
    "loss_scale": 1,
    "weight_decay": 1e-4,

    # test proposal
    "rpn_nms_pre": 2000,
    "rpn_nms_post": 1000,
    "rpn_max_num": 1000,
    "rpn_nms_thr": 0.7,
    "rpn_min_bbox_min_size": 8,
    "test_iou_thr": 0.7,
    "test_max_per_img": 100,
    "test_batch_size": 1,
    "use_python_proposal": False,

    # text proposal connection
    "max_horizontal_gap": 60,
    "text_proposals_min_scores": 0.7,
    "text_proposals_nms_thresh": 0.2,
    "min_v_overlaps": 0.7,
    "min_size_sim": 0.7,
    "min_ratio": 0.5,
    "line_min_score": 0.9,
    "text_proposals_width": 16,
    "min_num_proposals": 2,

    # create dataset
    "coco_root": "",
    "coco_train_data_type": "",
    "cocotext_json": "",
    "icdar11_train_path": [],
    "icdar13_train_path": [],
    "icdar15_train_path": [],
    "icdar13_test_path": [],
    "flick_train_path": [],
    "svt_train_path": [],
    "pretrain_dataset_path": "",
    "finetune_dataset_path": "",
    "test_dataset_path": "",

    # training dataset
    "pretraining_dataset_file": "",
    "finetune_dataset_file": ""
})
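As a sanity check on the anchor settings above, here is a minimal sketch (plain NumPy, independent of the repo's `AnchorGenerator`, so names and centring convention are assumptions) that builds the 14 fixed-width anchors centred on one 16-pixel base cell — all share `anchor_width` 16 and vary only in height:

```python
import numpy as np

# Values copied from the config above.
anchor_base = 16
anchor_heights = [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406]
anchor_width = 16

def base_anchors():
    """Return (14, 4) anchors as (x1, y1, x2, y2), centred on the base cell."""
    cx = cy = (anchor_base - 1) / 2.0
    anchors = []
    for h in anchor_heights:
        anchors.append([cx - (anchor_width - 1) / 2.0, cy - (h - 1) / 2.0,
                        cx + (anchor_width - 1) / 2.0, cy + (h - 1) / 2.0])
    return np.array(anchors)

print(base_anchors().shape)  # (14, 4)
```

Tiling these 14 anchors over the (36, 60) feature map gives 36 * 60 * 14 = 30240 boxes, which matches `num_bboxes` above.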
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ICDAR2015 dataset labels."""
import os
import argparse


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-s', '--src_label_path', type=str, default='./',
                        help='Directory containing the ICDAR2015 training labels')
    parser.add_argument('-t', '--target_label_path', type=str, default='test.xml',
                        help='Directory to save the converted ICDAR2015 labels')
    return parser.parse_args()
def convert():
    args = init_args()
    anno_file = os.listdir(args.src_label_path)
    annos = {}
    # read
    for file in anno_file:
        with open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig') as f:
            gt = f.read().splitlines()
        label_list = []
        label_name = os.path.basename(file)
        for each_label in gt:
            spt = each_label.split(',')
            # Skip "don't care" regions, which ICDAR2015 marks with "###".
            if "###" in spt[8]:
                continue
            # Fold the 8-point quadrilateral into an axis-aligned box.
            x1 = min(int(spt[0]), int(spt[6]))
            y1 = min(int(spt[1]), int(spt[3]))
            x2 = max(int(spt[2]), int(spt[4]))
            y2 = max(int(spt[5]), int(spt[7]))
            label_list.append([x1, y1, x2, y2])
        annos[label_name] = label_list
    # write
    if not os.path.exists(args.target_label_path):
        os.makedirs(args.target_label_path)
    for label_file, pos in annos.items():
        tgt_anno_file = os.path.join(args.target_label_path, label_file)
        with open(tgt_anno_file, 'w', encoding='UTF-8-sig') as f:
            for tgt_label in pos:
                f.write(','.join(str(p) for p in tgt_label))
                f.write("\n")


if __name__ == "__main__":
    convert()
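The min/max folding above can be checked on a single ICDAR2015-style ground-truth line, which lists the four quadrilateral corners clockwise followed by the transcription (the coordinates below are made up for illustration):

```python
# One hypothetical gt line: x1,y1,x2,y2,x3,y3,x4,y4,transcription
line = "377,117,463,117,465,130,378,130,Genaxis"
spt = line.split(',')
# Same folding as the converter: take the extreme left/top/right/bottom corners.
x1 = min(int(spt[0]), int(spt[6]))
y1 = min(int(spt[1]), int(spt[3]))
x2 = max(int(spt[2]), int(spt[4]))
y2 = max(int(spt[5]), int(spt[7]))
print([x1, y1, x2, y2])  # [377, 117, 465, 130]
```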
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert SVT dataset labels."""
import os
import argparse
from xml.etree import ElementTree as ET
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='Directory containing the SVT images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='Path of the SVT annotation xml file')
    parser.add_argument('-o', '--location_dir', type=str, default='./location',
                        help='Directory to save the converted location files')
    return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []
    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            elif ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)
    if save_file:
        np.save("annotation_train.npy", imgs_labels)
    return imgs_labels
def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))
    if not os.path.exists(args.location_dir):
        os.makedirs(args.location_dir)
    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    print("Converting annotation, {} images in total".format(num_images))
    for img_label in ims_labels_dict:
        image_name = img_label['imageName']
        rects = img_label['rect']
        print("processing image: {}".format(image_name))
        location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt"))
        with open(location_file_name, 'w') as f:
            for rect in rects:
                location = rect['location']
                h = int(location['height'])
                w = int(location['width'])
                x = int(location['x'])
                y = int(location['y'])
                # (x, y, w, h) -> (x1, y1, x2, y2)
                pos = [x, y, x + w, y + h]
                f.write(','.join(str(p) for p in pos))
                f.write("\n")


if __name__ == "__main__":
    convert()
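To see what `xml_to_dict` walks over, here is a minimal SVT-style XML fragment (hand-written for illustration, not a real annotation) parsed with the same `ElementTree` calls; the `taggedRectangle` attributes convert to the `(x1, y1, x2, y2)` line the script writes:

```python
from xml.etree import ElementTree as ET

# Hypothetical SVT annotation fragment: one image with one tagged rectangle.
xml = """<tagset><image>
  <imageName>img/14_03.jpg</imageName>
  <taggedRectangles>
    <taggedRectangle height="75" width="236" x="375" y="253"><tag>LIVING</tag></taggedRectangle>
  </taggedRectangles>
</image></tagset>"""

root = ET.fromstring(xml)
rect = root.find('image/taggedRectangles/taggedRectangle')
loc = rect.attrib
x, y = int(loc['x']), int(loc['y'])
w, h = int(loc['width']), int(loc['height'])
# Same (x, y, w, h) -> (x1, y1, x2, y2) conversion as the script.
print([x, y, x + w, y + h])  # [375, 253, 611, 328]
```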
@@ -0,0 +1,177 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create MindRecord files for CTPN training and evaluation."""
from __future__ import division
import os
import numpy as np
from PIL import Image
from mindspore.mindrecord import FileWriter
from src.config import config


def create_coco_label():
    """Create image label."""
    image_files = []
    image_anno_dict = {}
    coco_root = config.coco_root
    data_type = config.coco_train_data_type
    from src.coco_text import COCO_Text
    anno_json = config.cocotext_json
    ct = COCO_Text(anno_json)
    image_ids = ct.getImgIds(imgIds=ct.train,
                             catIds=[('legibility', 'legible')])
    for img_id in image_ids:
        image_info = ct.loadImgs(img_id)[0]
        # Strip the directory prefix kept in the COCO-Text file_name field.
        file_name = image_info['file_name'][15:]
        anno_ids = ct.getAnnIds(imgIds=img_id)
        anno = ct.loadAnns(anno_ids)
        image_path = os.path.join(coco_root, data_type, file_name)
        annos = []
        im = Image.open(image_path)
        width, _ = im.size
        for label in anno:
            bbox = label["bbox"]
            bbox_width = int(bbox[2])
            # Drop boxes narrower than 1/60 of the image width.
            if 60 * bbox_width < width:
                continue
            x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2])
            y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3])
            annos.append([x1, y1, x2, y2] + [1])
        if annos:
            image_anno_dict[image_path] = np.array(annos)
            image_files.append(image_path)
    return image_files, image_anno_dict
def create_anno_dataset_label(train_img_dirs, train_txt_dirs):
    image_files = []
    image_anno_dict = {}
    # read
    img_basenames = []
    for file in os.listdir(train_img_dirs):
        # Filter out gif files.
        if 'gif' not in file:
            img_basenames.append(os.path.basename(file))
    img_names = []
    for item in img_basenames:
        temp1, _ = os.path.splitext(item)
        img_names.append((temp1, item))
    for img, img_basename in img_names:
        image_path = train_img_dirs + '/' + img_basename
        annos = []
        if len(img) == 6 and '_' not in img_basename:
            gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines()
            if img.isdigit() and int(img) > 1200:
                continue
            for img_each_label in gt:
                spt = img_each_label.replace(',', '').split(' ')
                if ' ' not in img_each_label:
                    spt = img_each_label.split(',')
                annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1])
        if annos:
            image_anno_dict[image_path] = np.array(annos)
            image_files.append(image_path)
    return image_files, image_anno_dict
def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix):
    image_files = []
    image_anno_dict = {}
    img_basenames = []
    for file_name in os.listdir(train_img_dir):
        if 'gif' not in file_name:
            img_basenames.append(os.path.basename(file_name))
    img_names = []
    for item in img_basenames:
        temp1, _ = os.path.splitext(item)
        img_names.append((temp1, item))
    for img, img_basename in img_names:
        image_path = train_img_dir + '/' + img_basename
        annos = []
        file_name = prefix + img + ".txt"
        file_path = os.path.join(train_txt_dir, file_name)
        gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines()
        if not gt:
            continue
        for img_each_label in gt:
            spt = img_each_label.replace(',', '').split(' ')
            if ' ' not in img_each_label:
                spt = img_each_label.split(',')
            annos.append([spt[0], spt[1], spt[2], spt[3]] + [1])
        if annos:
            image_anno_dict[image_path] = np.array(annos)
            image_files.append(image_path)
    return image_files, image_anno_dict
def create_train_dataset(dataset_type):
    image_files = []
    image_anno_dict = {}
    if dataset_type == "pretraining":
        # pretraining: coco, flick, icdar2013 train, icdar2015, svt
        coco_image_files, coco_anno_dict = create_coco_label()
        flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
                                                                       config.flick_train_path[1])
        icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
                                                                        config.icdar13_train_path[1], "gt_img_")
        icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0],
                                                                        config.icdar15_train_path[1], "gt_")
        svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "")
        image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files
        image_anno_dict = {**coco_anno_dict, **flick_anno_dict,
                           **icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict}
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path,
                                      prefix="ctpn_pretrain.mindrecord", file_num=8)
    elif dataset_type == "finetune":
        # finetune: icdar2011, icdar2013 train, flick
        flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0],
                                                                       config.flick_train_path[1])
        icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0],
                                                                        config.icdar11_train_path[1], "gt_")
        icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0],
                                                                        config.icdar13_train_path[1], "gt_img_")
        image_files = flick_image_files + icdar11_image_files + icdar13_image_files
        image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict}
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path,
                                      prefix="ctpn_finetune.mindrecord", file_num=8)
    elif dataset_type == "test":
        # test: icdar2013 test
        icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],
                                                                              config.icdar13_test_path[1], "")
        image_files = icdar_test_image_files
        image_anno_dict = icdar_test_anno_dict
        data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path,
                                      prefix="ctpn_test.mindrecord", file_num=1)
    else:
        print("dataset_type should be pretraining, finetune or test")
def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1):
    """Create MindRecord file."""
    mindrecord_path = os.path.join(dst_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    ctpn_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 5]},
    }
    writer.add_schema(ctpn_json, "ctpn_json")
    for image_name in image_files:
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        print("img name is {}, anno is {}".format(image_name, annos))
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


if __name__ == "__main__":
    create_train_dataset("pretraining")
    create_train_dataset("finetune")
    create_train_dataset("test")
@@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTPN network definition."""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from src.CTPN.rpn import RPN
from src.CTPN.anchor_generator import AnchorGenerator
from src.CTPN.proposal_generator import Proposal
from src.CTPN.vgg16 import VGG16FeatureExtraction
class BiLSTM(nn.Cell):
    """
    Define a BiLSTM network which runs a forward and a backward LSTM over the sequence
    and concatenates their outputs.

    Args:
        config(EasyDict): network configuration; provides input_size, hidden_size,
            num_step, batch_size and rnn_batch_size.
        is_training(bool): whether the network is built for training, default is True.
    """
    def __init__(self, config, is_training=True):
        super(BiLSTM, self).__init__()
        self.is_training = is_training
        self.batch_size = config.batch_size * config.rnn_batch_size
        print("batch size is {} ".format(self.batch_size))
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_step = config.num_step
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn_bw = P.DynamicRNN(forget_bias=0.0)
        self.w1 = Parameter(np.random.uniform(
            -k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1")
        self.w1_bw = Parameter(np.random.uniform(
            -k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw")
        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
        self.reverse_seq = P.ReverseV2(axis=[0])
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.7)
        self.use_dropout = config.use_dropout
    def construct(self, x):
        if self.use_dropout:
            x = self.dropout(x)
        x = self.cast(x, mstype.float16)
        bw_x = self.reverse_seq(x)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq(y1_bw)
        output = self.concat1((y1, y1_bw))
        return output
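Shape-wise, `construct` maps a `(num_step, batch, input_size)` sequence to `(num_step, batch, 2 * hidden_size)`. A NumPy stand-in for the reverse-run-reverse-concat pattern (here `fake_rnn` only mimics `DynamicRNN`'s output shape, it does no LSTM arithmetic):

```python
import numpy as np

T, B, input_size, hidden = 57, 38, 512, 128  # shapes from the comments in CTPN.construct

def fake_rnn(x):
    # Stand-in for P.DynamicRNN: hidden states of shape (T, B, hidden).
    return np.zeros((x.shape[0], x.shape[1], hidden), dtype=np.float32)

x = np.random.randn(T, B, input_size).astype(np.float32)
y_fw = fake_rnn(x)
y_bw = fake_rnn(x[::-1])[::-1]              # reverse input, run, reverse output back
out = np.concatenate([y_fw, y_bw], axis=2)  # concat forward/backward features
print(out.shape)  # (57, 38, 256)
```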
class CTPN(nn.Cell):
    """
    Define the CTPN network: VGG16 feature extraction, BiLSTM context encoding
    and an RPN head for text proposals.

    Args:
        config(EasyDict): network configuration.
        is_training(bool): whether the network is built for training, default is True.
    """
    def __init__(self, config, is_training=True):
        super(CTPN, self).__init__()
        self.config = config
        self.is_training = is_training
        self.num_step = config.num_step
        self.input_size = config.input_size
        self.batch_size = config.batch_size
        self.hidden_size = config.hidden_size
        self.vgg16_feature_extractor = VGG16FeatureExtraction()
        self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
        self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.cast = P.Cast()
        # rpn block
        self.rpn_with_loss = RPN(config,
                                 self.batch_size,
                                 config.rpn_in_channels,
                                 config.rpn_feat_channels,
                                 config.num_anchors,
                                 config.rpn_cls_out_channels)
        self.anchor_generator = AnchorGenerator(config)
        self.featmap_size = config.feature_shapes
        self.anchor_list = self.get_anchors(self.featmap_size)
        self.proposal_generator_test = Proposal(config,
                                                config.test_batch_size,
                                                config.activate_num_classes,
                                                config.use_sigmoid_cls)
        self.proposal_generator_test.set_train_local(config, False)
    def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
        # img_data: (B, 3, H, W), e.g. (1, 3, 600, 900)
        x = self.vgg16_feature_extractor(img_data)
        x = self.conv(x)
        x = self.cast(x, mstype.float16)
        # (1, 512, 38, 57)
        x = self.transpose(x, (0, 2, 1, 3))
        x = self.reshape(x, (-1, self.input_size, self.num_step))
        x = self.transpose(x, (2, 0, 1))
        # (57, 38, 512)
        x = self.rnn(x)
        # (57, 38, 256)
        rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x,
                                                                                        img_metas,
                                                                                        self.anchor_list,
                                                                                        gt_bboxes,
                                                                                        gt_labels,
                                                                                        gt_valids)
        if self.training:
            return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss
        proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list)
        return proposal, proposal_mask

    def get_anchors(self, featmap_size):
        anchors = self.anchor_generator.grid_anchors(featmap_size)
        return Tensor(anchors, mstype.float16)
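The transpose/reshape sequence in `construct` turns the `(B, C, H, W)` feature map into a `(W, B*H, C)` sequence, so the feature-map width becomes the LSTM time axis and every image row is a separate sequence in the batch. A NumPy re-run of the same three steps with the config's 576x960 shapes (feature map 36x60):

```python
import numpy as np

B, C, H, W = 1, 512, 36, 60   # feature map for a 576x960 input (config values)
x = np.zeros((B, C, H, W), dtype=np.float32)

x = x.transpose(0, 2, 1, 3)   # (B, H, C, W)
x = x.reshape(-1, C, W)       # (B*H, C, W)
x = x.transpose(2, 0, 1)      # (W, B*H, C): width is now the time axis
print(x.shape)  # (60, 36, 512)
```

This is why `num_step` (60) and `rnn_batch_size` (36) in the config match the feature-map width and height.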
@@ -0,0 +1,342 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTPN dataset."""
from __future__ import division
import os
import numpy as np
from numpy import random
import mmcv
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as CC
import mindspore.common.dtype as mstype
from mindspore.mindrecord import FileWriter
from src.config import config
class PhotoMetricDistortion:
    """Photometric distortion: random brightness, contrast, saturation, hue and channel swap."""
    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, img, boxes, labels):
        img = img.astype('float32')
        # random brightness
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta, self.brightness_delta)
            img += delta
        # mode == 1: apply contrast before the HSV distortions; mode == 0: after
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha
        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)
        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)
        # random hue
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360
        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)
        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha
        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]
        return img, boxes, labels
class Expand:
    """Expand the image onto a larger mean-filled canvas and shift the boxes accordingly."""
    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        if to_rgb:
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        if random.randint(2):
            return img, boxes, labels
        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        expand_img = np.full((int(h * ratio), int(w * ratio), c),
                             self.mean).astype(img.dtype)
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img
        img = expand_img
        boxes += np.tile((left, top), 2)
        return img, boxes, labels
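The canvas-paste arithmetic in `Expand.__call__` can be replayed on a tiny image with fixed (rather than random) ratio and offsets; the box simply shifts by the paste position:

```python
import numpy as np

img = np.ones((2, 3, 3), dtype=np.uint8)   # tiny 2x3 BGR image
boxes = np.array([[1.0, 0.0, 2.0, 1.0]])   # one box as (x1, y1, x2, y2)
ratio, left, top = 2.0, 1, 1               # fixed values for the replay

h, w, c = img.shape
canvas = np.zeros((int(h * ratio), int(w * ratio), c), dtype=img.dtype)
canvas[top:top + h, left:left + w] = img   # paste the image at (left, top)
boxes = boxes + np.tile((left, top), 2)    # shift both box corners by the offset
print(canvas.shape, boxes[0].tolist())  # (4, 6, 3) [2.0, 1.0, 3.0, 2.0]
```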
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """rescale operation for image"""
    img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
    if img_data.shape[0] > config.img_height:
        img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
        scale_factor = scale_factor * scale_factor2
    img_shape = np.append(img_shape, scale_factor)
    img_shape = np.asarray(img_shape, dtype=np.float32)
    gt_bboxes = gt_bboxes * scale_factor
    gt_bboxes = split_gtbox_label(gt_bboxes)
    if gt_bboxes.shape[0] != 0:
        gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
        gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """resize operation for image"""
    img_data = img
    img_data, w_scale, h_scale = mmcv.imresize(
        img_data, (config.img_width, config.img_height), return_scale=True)
    scale_factor = np.array(
        [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    img_shape = (config.img_height, config.img_width, 1.0)
    img_shape = np.asarray(img_shape, dtype=np.float32)
    gt_bboxes = gt_bboxes * scale_factor
    gt_bboxes = split_gtbox_label(gt_bboxes)
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
    """resize operation for image of eval"""
    img_data = img
    img_data, w_scale, h_scale = mmcv.imresize(
        img_data, (config.img_width, config.img_height), return_scale=True)
    scale_factor = np.array(
        [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
| img_shape = (config.img_height, config.img_width) | |||
| img_shape = np.append(img_shape, (h_scale, w_scale)) | |||
| img_shape = np.asarray(img_shape, dtype=np.float32) | |||
| gt_bboxes = gt_bboxes * scale_factor | |||
| shape = gt_bboxes.shape | |||
| label_column = np.ones((shape[0], 1), dtype=int) | |||
| gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1) | |||
| gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) | |||
| gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) | |||
| return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
| def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
| """flipped generation""" | |||
| img_data = img | |||
| flipped = gt_bboxes.copy() | |||
| _, w, _ = img_data.shape | |||
| flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 | |||
| flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 | |||
| return (img_data, img_shape, flipped, gt_label, gt_num) | |||
| def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
| img_data = img[:, :, ::-1] | |||
| return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
| def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
| """photo crop operation for image""" | |||
| random_photo = PhotoMetricDistortion() | |||
| img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label) | |||
| return (img_data, img_shape, gt_bboxes, gt_label, gt_num) | |||
| def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num): | |||
| """expand operation for image""" | |||
| expand = Expand() | |||
| img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label) | |||
| return (img, img_shape, gt_bboxes, gt_label, gt_num) | |||
| def split_gtbox_label(gt_bbox_total): | |||
| """split ground truth box label""" | |||
| gtbox_list = [] | |||
| box_num, _ = gt_bbox_total.shape | |||
| for i in range(box_num): | |||
| gt_bbox = gt_bbox_total[i] | |||
| if gt_bbox[0] % 16 != 0: | |||
| gt_bbox[0] = (gt_bbox[0] // 16) * 16 | |||
| if gt_bbox[2] % 16 != 0: | |||
| gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16 | |||
| x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16) | |||
| for x0 in x0_array: | |||
| gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1]) | |||
| return np.array(gtbox_list) | |||
| def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid): | |||
| """pad ground truth label""" | |||
| pad_max_number = 256 | |||
| gt_label = gt_bboxes[:, 4] | |||
| gt_valid = gt_bboxes[:, 4] | |||
if gt_bboxes.shape[0] < pad_max_number:
| gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \ | |||
| mode="constant", constant_values=0) | |||
| gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1) | |||
| gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0) | |||
| else: | |||
print("WARNING: label num is higher than 256")
| gt_box = gt_bboxes[0:pad_max_number] | |||
| gt_label = gt_label[0:pad_max_number] | |||
| gt_valid = gt_valid[0:pad_max_number] | |||
| return (img, img_shape, gt_box[:, :4], gt_label, gt_valid) | |||
| def preprocess_fn(image, box, is_training): | |||
| """Preprocess function for dataset.""" | |||
| def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid): | |||
| image_shape = image_shape[:2] | |||
| input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid | |||
| if config.keep_ratio: | |||
| input_data = rescale_column(*input_data) | |||
| else: | |||
| input_data = resize_column_test(*input_data) | |||
| input_data = pad_label(*input_data) | |||
| input_data = image_bgr_rgb(*input_data) | |||
| output_data = input_data | |||
| return output_data | |||
| def _data_aug(image, box, is_training): | |||
| """Data augmentation function.""" | |||
| image_bgr = image.copy() | |||
| image_bgr[:, :, 0] = image[:, :, 2] | |||
| image_bgr[:, :, 1] = image[:, :, 1] | |||
| image_bgr[:, :, 2] = image[:, :, 0] | |||
| image_shape = image_bgr.shape[:2] | |||
| gt_box = box[:, :4] | |||
| gt_label = box[:, 4] | |||
| gt_valid = box[:, 4] | |||
| input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid | |||
| if not is_training: | |||
| return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid) | |||
| expand = (np.random.rand() < config.expand_ratio) | |||
| if expand: | |||
| input_data = expand_column(*input_data) | |||
| input_data = photo_crop_column(*input_data) | |||
| if config.keep_ratio: | |||
| input_data = rescale_column(*input_data) | |||
| else: | |||
| input_data = resize_column(*input_data) | |||
| input_data = pad_label(*input_data) | |||
| input_data = image_bgr_rgb(*input_data) | |||
| output_data = input_data | |||
| return output_data | |||
| return _data_aug(image, box, is_training) | |||
| def anno_parser(annos_str): | |||
| """Parse annotation from string to list.""" | |||
| annos = [] | |||
| for anno_str in annos_str: | |||
| anno = list(map(int, anno_str.strip().split(','))) | |||
| annos.append(anno) | |||
| return annos | |||
| def filter_valid_data(image_dir, anno_path): | |||
| """Filter valid image file, which both in image_dir and anno_path.""" | |||
| image_files = [] | |||
| image_anno_dict = {} | |||
| if not os.path.isdir(image_dir): | |||
| raise RuntimeError("Path given is not valid.") | |||
| if not os.path.isfile(anno_path): | |||
| raise RuntimeError("Annotation file is not valid.") | |||
| with open(anno_path, "rb") as f: | |||
| lines = f.readlines() | |||
| for line in lines: | |||
| line_str = line.decode("utf-8").strip() | |||
| line_split = str(line_str).split(' ') | |||
| file_name = line_split[0] | |||
| image_path = os.path.join(image_dir, file_name) | |||
| if os.path.isfile(image_path): | |||
| image_anno_dict[image_path] = anno_parser(line_split[1:]) | |||
| image_files.append(image_path) | |||
| return image_files, image_anno_dict | |||
| def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8): | |||
| """Create MindRecord file.""" | |||
| mindrecord_dir = config.mindrecord_dir | |||
| mindrecord_path = os.path.join(mindrecord_dir, prefix) | |||
| writer = FileWriter(mindrecord_path, file_num) | |||
| image_files, image_anno_dict = create_icdar_test_label() | |||
| ctpn_json = { | |||
| "image": {"type": "bytes"}, | |||
| "annotation": {"type": "int32", "shape": [-1, 6]}, | |||
| } | |||
| writer.add_schema(ctpn_json, "ctpn_json") | |||
| for image_name in image_files: | |||
| with open(image_name, 'rb') as f: | |||
| img = f.read() | |||
| annos = np.array(image_anno_dict[image_name], dtype=np.int32) | |||
| row = {"image": img, "annotation": annos} | |||
| writer.write_raw_data([row]) | |||
| writer.commit() | |||
| def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0, | |||
| is_training=True, num_parallel_workers=4): | |||
"""Create CTPN dataset with MindDataset."""
| ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\ | |||
| num_parallel_workers=8, shuffle=is_training) | |||
| decode = C.Decode() | |||
| ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1) | |||
| compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) | |||
| hwc_to_chw = C.HWC2CHW() | |||
| normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) | |||
| type_cast0 = CC.TypeCast(mstype.float32) | |||
| type_cast1 = CC.TypeCast(mstype.float16) | |||
| type_cast2 = CC.TypeCast(mstype.int32) | |||
| type_cast3 = CC.TypeCast(mstype.bool_) | |||
| if is_training: | |||
| ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], | |||
| output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
| column_order=["image", "image_shape", "box", "label", "valid_num"], | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], | |||
| num_parallel_workers=12) | |||
| ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"], | |||
| num_parallel_workers=12) | |||
| else: | |||
| ds = ds.map(operations=compose_map_func, | |||
| input_columns=["image", "annotation"], | |||
| output_columns=["image", "image_shape", "box", "label", "valid_num"], | |||
| column_order=["image", "image_shape", "box", "label", "valid_num"], | |||
| num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], | |||
| num_parallel_workers=24) | |||
# cast the remaining columns to the dtypes expected by the network
| ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) | |||
| ds = ds.map(operations=[type_cast1], input_columns=["box"]) | |||
| ds = ds.map(operations=[type_cast2], input_columns=["label"]) | |||
| ds = ds.map(operations=[type_cast3], input_columns=["valid_num"]) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_num) | |||
| return ds | |||
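The 16-pixel proposal splitting performed by `split_gtbox_label` above is the core of CTPN's ground-truth preparation; this self-contained sketch reproduces its logic on one hypothetical box so the stride snapping is easy to verify.

```python
import numpy as np

def split_gtbox(gt_bbox_total):
    """Snap each box's x-range to the 16-px stride (floor the left edge,
    ceil the right edge) and emit one fixed-width proposal per stride."""
    gtbox_list = []
    for gt_bbox in gt_bbox_total:
        x0, y0, x1, y1 = gt_bbox[:4]
        if x0 % 16 != 0:
            x0 = (x0 // 16) * 16          # floor left edge to stride
        if x1 % 16 != 0:
            x1 = (x1 // 16 + 1) * 16      # ceil right edge to stride
        for x in np.arange(x0, x1, 16):
            gtbox_list.append([x, y0, x + 15, y1, 1])
    return np.array(gtbox_list)

boxes = np.array([[10, 5, 40, 20]])
out = split_gtbox(boxes)
print(out.tolist())  # [[0, 5, 15, 20, 1], [16, 5, 31, 20, 1], [32, 5, 47, 20, 1]]
```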
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
"""lr generator for ctpn"""
| import math | |||
| def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||
| lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||
| learning_rate = float(init_lr) + lr_inc * current_step | |||
| return learning_rate | |||
| def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||
| base = float(current_step - warmup_steps) / float(decay_steps) | |||
| learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||
| return learning_rate | |||
| def dynamic_lr(config, base_step): | |||
| """dynamic learning rate generator""" | |||
| base_lr = config.base_lr | |||
| total_steps = int(base_step * config.total_epoch) | |||
| warmup_steps = config.warmup_step | |||
| lr = [] | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | |||
| else: | |||
| lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | |||
| return lr | |||
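As a self-contained sketch (hypothetical hyperparameter values), the linear-warmup plus cosine schedule that `dynamic_lr` builds above behaves like this:

```python
import math

def warmup_cosine_lr(base_lr, total_steps, warmup_steps, warmup_ratio):
    """Linear warmup from base_lr*warmup_ratio up to base_lr, then cosine
    decay, mirroring dynamic_lr above."""
    init_lr = base_lr * warmup_ratio
    lr = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr.append(init_lr + (base_lr - init_lr) * i / warmup_steps)
        else:
            # dynamic_lr above passes total_steps as decay_steps, so the
            # cosine phase uses (i - warmup_steps) / total_steps and never
            # quite reaches zero by the final step
            lr.append((1 + math.cos((i - warmup_steps) / total_steps * math.pi)) / 2 * base_lr)
    return lr

lr = warmup_cosine_lr(base_lr=0.1, total_steps=100, warmup_steps=10, warmup_ratio=1 / 3)
print(round(lr[0], 4), round(lr[10], 4), round(lr[-1], 4))
```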
| @@ -0,0 +1,153 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
"""CTPN training network wrapper."""
| import time | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore import ParameterTuple | |||
| from mindspore.train.callback import Callback | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| time_stamp_init = False | |||
| time_stamp_first = 0 | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
If the loss is NAN or INF, terminate training.
Note:
If per_print_times is 0, do not print the loss.
Args:
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
| """ | |||
| def __init__(self, per_print_times=1, rank_id=0): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
raise ValueError("per_print_times must be an int and >= 0.")
| self._per_print_times = per_print_times | |||
| self.count = 0 | |||
| self.rpn_loss_sum = 0 | |||
| self.rpn_cls_loss_sum = 0 | |||
| self.rpn_reg_loss_sum = 0 | |||
| self.rank_id = rank_id | |||
| global time_stamp_init, time_stamp_first | |||
| if not time_stamp_init: | |||
| time_stamp_first = time.time() | |||
| time_stamp_init = True | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| rpn_loss = cb_params.net_outputs[0].asnumpy() | |||
| rpn_cls_loss = cb_params.net_outputs[1].asnumpy() | |||
| rpn_reg_loss = cb_params.net_outputs[2].asnumpy() | |||
| self.count += 1 | |||
| self.rpn_loss_sum += float(rpn_loss) | |||
| self.rpn_cls_loss_sum += float(rpn_cls_loss) | |||
| self.rpn_reg_loss_sum += float(rpn_reg_loss) | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| if self.count >= 1: | |||
| global time_stamp_first | |||
| time_stamp_current = time.time() | |||
| rpn_loss = self.rpn_loss_sum / self.count | |||
| rpn_cls_loss = self.rpn_cls_loss_sum / self.count | |||
| rpn_reg_loss = self.rpn_reg_loss_sum / self.count | |||
with open("./loss_{}.log".format(self.rank_id), "a+") as loss_file:
loss_file.write("%lu epoch: %s step: %s, rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f\n" %
(time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
rpn_loss, rpn_cls_loss, rpn_reg_loss))
| class LossNet(nn.Cell): | |||
"""CTPN loss method"""
| def construct(self, x1, x2, x3): | |||
| return x1 | |||
| class WithLossCell(nn.Cell): | |||
| """ | |||
| Wrap the network with loss function to compute loss. | |||
| Args: | |||
| backbone (Cell): The target network to wrap. | |||
| loss_fn (Cell): The loss function used to compute loss. | |||
| """ | |||
| def __init__(self, backbone, loss_fn): | |||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||
| self._backbone = backbone | |||
| self._loss_fn = loss_fn | |||
| def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||
| rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||
| return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss) | |||
| @property | |||
| def backbone_network(self): | |||
| """ | |||
| Get the backbone network. | |||
| Returns: | |||
| Cell, return backbone network. | |||
| """ | |||
| return self._backbone | |||
| class TrainOneStepCell(nn.Cell): | |||
| """ | |||
| Network training package class. | |||
| Append an optimizer to the training network after that the construct function | |||
| can be called to create the backward graph. | |||
| Args: | |||
| network (Cell): The training network. | |||
| network_backbone (Cell): The forward network. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
| sens (Number): The adjust parameter. Default value is 1.0. | |||
| reduce_flag (bool): The reduce flag. Default value is False. | |||
| mean (bool): Allreduce method. Default value is False. | |||
| degree (int): Device number. Default value is None. | |||
| """ | |||
| def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): | |||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| self.backbone = network_backbone | |||
| self.weights = ParameterTuple(network.trainable_params()) | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation(get_by_list=True, | |||
| sens_param=True) | |||
| self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32)) | |||
| self.reduce_flag = reduce_flag | |||
| if reduce_flag: | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): | |||
| weights = self.weights | |||
| rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num) | |||
| grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens) | |||
| if self.reduce_flag: | |||
| grads = self.grad_reducer(grads) | |||
| return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================import numpy as np | |||
| import numpy as np | |||
| from src.text_connector.utils import clip_boxes, fit_y | |||
| from src.text_connector.get_successions import get_successions | |||
| def connect_text_lines(text_proposals, scores, size): | |||
| """ | |||
| Connect text lines | |||
| Args: | |||
| text_proposals(numpy.array): Predict text proposals. | |||
| scores(numpy.array): Bbox predicts scores. | |||
| size(numpy.array): Image size. | |||
| Returns: | |||
text_recs(numpy.array): Text boxes after connection.
| """ | |||
| graph = get_successions(text_proposals, scores, size) | |||
| text_lines = np.zeros((len(graph), 5), np.float32) | |||
| for index, indices in enumerate(graph): | |||
| text_line_boxes = text_proposals[list(indices)] | |||
| x0 = np.min(text_line_boxes[:, 0]) | |||
| x1 = np.max(text_line_boxes[:, 2]) | |||
| offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 | |||
| lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset) | |||
| lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset) | |||
| # the score of a text line is the average score of the scores | |||
| # of all text proposals contained in the text line | |||
| score = scores[list(indices)].sum() / float(len(indices)) | |||
| text_lines[index, 0] = x0 | |||
| text_lines[index, 1] = min(lt_y, rt_y) | |||
| text_lines[index, 2] = x1 | |||
| text_lines[index, 3] = max(lb_y, rb_y) | |||
| text_lines[index, 4] = score | |||
| text_lines = clip_boxes(text_lines, size) | |||
text_recs = np.zeros((len(text_lines), 9), np.float64)
| index = 0 | |||
| for line in text_lines: | |||
| xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3] | |||
| text_recs[index, 0] = xmin | |||
| text_recs[index, 1] = ymin | |||
| text_recs[index, 2] = xmax | |||
| text_recs[index, 3] = ymax | |||
| text_recs[index, 4] = line[4] | |||
| index = index + 1 | |||
| return text_recs | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| from src.config import config | |||
| from src.text_connector.utils import nms | |||
| from src.text_connector.connect_text_lines import connect_text_lines | |||
| def filter_proposal(proposals, scores): | |||
| """ | |||
Filter text proposals by score and apply NMS.
Args:
proposals(numpy.array): Text proposals.
scores(numpy.array): Text proposal scores.
Returns:
keep_proposals(numpy.array): Proposals kept after filtering.
keep_scores(numpy.array): Their scores.
| """ | |||
| inds = np.where(scores > config.text_proposals_min_scores)[0] | |||
| keep_proposals = proposals[inds] | |||
| keep_scores = scores[inds] | |||
| sorted_inds = np.argsort(keep_scores.ravel())[::-1] | |||
| keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds] | |||
| nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh) | |||
| keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds] | |||
| return keep_proposals, keep_scores | |||
| def filter_boxes(boxes): | |||
| """ | |||
| Filter text boxes | |||
| Args: | |||
| boxes(numpy.array): Text boxes. | |||
| Returns: | |||
inds(numpy.array): Indices of the boxes that pass the filter.
| """ | |||
heights = np.zeros((len(boxes), 1), np.float64)
widths = np.zeros((len(boxes), 1), np.float64)
scores = np.zeros((len(boxes), 1), np.float64)
| index = 0 | |||
| for box in boxes: | |||
| widths[index] = abs(box[2] - box[0]) | |||
| heights[index] = abs(box[3] - box[1]) | |||
| scores[index] = abs(box[4]) | |||
| index += 1 | |||
| return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\ | |||
| (widths > (config.text_proposals_width * config.min_num_proposals)))[0] | |||
| def detect(text_proposals, scores, size): | |||
| """ | |||
| Detect text boxes | |||
| Args: | |||
| text_proposals(numpy.array): Predict text proposals. | |||
| scores(numpy.array): Bbox predicts scores. | |||
| size(numpy.array): Image size. | |||
| Returns: | |||
boxes(numpy.array): Text boxes after connection.
| """ | |||
| keep_proposals, keep_scores = filter_proposal(text_proposals, scores) | |||
| connect_boxes = connect_text_lines(keep_proposals, keep_scores, size) | |||
| boxes = connect_boxes[filter_boxes(connect_boxes)] | |||
| return boxes | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| from src.config import config | |||
| from src.text_connector.utils import overlaps_v, size_similarity | |||
| def get_successions(text_proposals, scores, im_size): | |||
| """ | |||
| Get successions text boxes. | |||
| Args: | |||
| text_proposals(numpy.array): Predict text proposals. | |||
| scores(numpy.array): Bbox predicts scores. | |||
im_size(numpy.array): Image size.
| Returns: | |||
| sub_graph(list): Proposals graph. | |||
| """ | |||
| bboxes_table = [[] for _ in range(int(im_size[1]))] | |||
| for index, box in enumerate(text_proposals): | |||
| bboxes_table[int(box[0])].append(index) | |||
graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)
| for index, box in enumerate(text_proposals): | |||
| successions_left = [] | |||
| for left in range(int(box[0]) + 1, min(int(box[0]) + config.max_horizontal_gap + 1, im_size[1])): | |||
| adj_box_indices = bboxes_table[left] | |||
| for adj_box_index in adj_box_indices: | |||
| if meet_v_iou(text_proposals, adj_box_index, index): | |||
| successions_left.append(adj_box_index) | |||
| if successions_left: | |||
| break | |||
| if not successions_left: | |||
| continue | |||
| succession_index = successions_left[np.argmax(scores[successions_left])] | |||
| box_right = text_proposals[succession_index] | |||
| succession_right = [] | |||
| for right in range(int(box_right[0]) - 1, max(int(box_right[0] - config.max_horizontal_gap), 0) - 1, -1): | |||
| adj_box_indices = bboxes_table[right] | |||
| for adj_box_index in adj_box_indices: | |||
| if meet_v_iou(text_proposals, adj_box_index, index): | |||
| succession_right.append(adj_box_index) | |||
| if succession_right: | |||
| break | |||
| if scores[index] >= np.max(scores[succession_right]): | |||
| graph[index, succession_index] = True | |||
| sub_graph = get_sub_graph(graph) | |||
| return sub_graph | |||
| def get_sub_graph(graph): | |||
| """ | |||
Split the succession graph into connected chains of proposals.
Args:
graph(numpy.array): Boolean succession matrix between proposals.
Returns:
sub_graphs(list): One list of proposal indices per text line.
| """ | |||
| sub_graphs = [] | |||
| for index in range(graph.shape[0]): | |||
| if not graph[:, index].any() and graph[index, :].any(): | |||
| v = index | |||
| sub_graphs.append([v]) | |||
| while graph[v, :].any(): | |||
| v = np.where(graph[v, :])[0][0] | |||
| sub_graphs[-1].append(v) | |||
| return sub_graphs | |||
| def meet_v_iou(text_proposals, index1, index2): | |||
| """ | |||
Check whether two proposals meet the vertical IoU condition.
Args:
text_proposals(numpy.array): Text proposals.
index1(int): First text proposal index.
index2(int): Second text proposal index.
Returns:
bool: True when vertical overlap and height similarity both exceed their thresholds.
| """ | |||
| heights = text_proposals[:, 3] - text_proposals[:, 1] + 1 | |||
| return overlaps_v(text_proposals, index1, index2) >= config.min_v_overlaps and \ | |||
| size_similarity(heights, index1, index2) >= config.min_size_sim | |||
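The chain-walking in `get_sub_graph` above is easiest to see on a tiny hand-built succession matrix; this standalone sketch repeats its logic: a chain starts at any proposal with an outgoing edge but no incoming edge, then follows edges to the end.

```python
import numpy as np

def sub_graphs_from(graph):
    """Walk each chain in a boolean succession matrix, as get_sub_graph does."""
    sub_graphs = []
    for index in range(graph.shape[0]):
        # a chain head has an outgoing edge but no incoming edge
        if not graph[:, index].any() and graph[index, :].any():
            v = index
            sub_graphs.append([v])
            while graph[v, :].any():
                v = np.where(graph[v, :])[0][0]
                sub_graphs[-1].append(v)
    return sub_graphs

# two chains: 0 -> 2 -> 3 and 1 -> 4
g = np.zeros((5, 5), dtype=bool)
g[0, 2] = g[2, 3] = g[1, 4] = True
print(sub_graphs_from(g))  # [[0, 2, 3], [1, 4]]
```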
| @@ -0,0 +1,118 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| def threshold(coords, min_, max_): | |||
| return np.maximum(np.minimum(coords, max_), min_) | |||
| def clip_boxes(boxes, im_shape): | |||
| """ | |||
| Clip boxes to image boundaries. | |||
| Args: | |||
| boxes(numpy.array):bounding box. | |||
| im_shape(numpy.array): image shape. | |||
| Return: | |||
boxes(numpy.array): Bounding box after clipping.
| """ | |||
| boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1) | |||
| boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1) | |||
| return boxes | |||
| def overlaps_v(text_proposals, index1, index2): | |||
| """ | |||
| Calculate vertical overlap ratio. | |||
| Args: | |||
text_proposals(numpy.array): Text proposals.
| index1(int): First text proposal. | |||
| index2(int): Second text proposal. | |||
| Return: | |||
| overlap(float32): vertical overlap. | |||
| """ | |||
| h1 = text_proposals[index1][3] - text_proposals[index1][1] + 1 | |||
| h2 = text_proposals[index2][3] - text_proposals[index2][1] + 1 | |||
| y0 = max(text_proposals[index2][1], text_proposals[index1][1]) | |||
| y1 = min(text_proposals[index2][3], text_proposals[index1][3]) | |||
| return max(0, y1 - y0 + 1) / min(h1, h2) | |||
| def size_similarity(heights, index1, index2): | |||
| """ | |||
| Calculate vertical size similarity ratio. | |||
| Args: | |||
heights(numpy.array): Text proposal heights.
index1(int): First text proposal.
index2(int): Second text proposal.
Return:
similarity(float32): Vertical size similarity.
| """ | |||
| h1 = heights[index1] | |||
| h2 = heights[index2] | |||
| return min(h1, h2) / max(h1, h2) | |||
def fit_y(X, Y, x1, x2):
"""Fit a line through the points (X, Y) and evaluate it at x1 and x2."""
if np.sum(X == X[0]) == len(X):
return Y[0], Y[0]
| p = np.poly1d(np.polyfit(X, Y, 1)) | |||
| return p(x1), p(x2) | |||
| def nms(bboxs, thresh): | |||
| """ | |||
| Args: | |||
| text_proposals(numpy.array): tex proposals | |||
| index1(int): text_proposal index | |||
| tindex2(int): text proposal index | |||
| """ | |||
| x1, y1, x2, y2, scores = np.split(bboxs, 5, axis=1) | |||
| x1 = bboxs[:, 0] | |||
| y1 = bboxs[:, 1] | |||
| x2 = bboxs[:, 2] | |||
| y2 = bboxs[:, 3] | |||
| scores = bboxs[:, 4] | |||
| areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||
| order = scores.argsort()[::-1] | |||
| num_dets = bboxs.shape[0] | |||
| suppressed = np.zeros(num_dets, dtype=np.int32) | |||
| keep = [] | |||
| for _i in range(num_dets): | |||
| i = order[_i] | |||
| if suppressed[i] == 1: | |||
| continue | |||
| keep.append(i) | |||
| x1_i = x1[i] | |||
| y1_i = y1[i] | |||
| x2_i = x2[i] | |||
| y2_i = y2[i] | |||
| area_i = areas[i] | |||
| for _j in range(_i + 1, num_dets): | |||
| j = order[_j] | |||
| if suppressed[j] == 1: | |||
| continue | |||
| x1_j = max(x1_i, x1[j]) | |||
| y1_j = max(y1_i, y1[j]) | |||
| x2_j = min(x2_i, x2[j]) | |||
| y2_j = min(y2_i, y2[j]) | |||
| w = max(0.0, x2_j - x1_j + 1) | |||
| h = max(0.0, y2_j - y1_j + 1) | |||
| inter = w*h | |||
| overlap = inter / (area_i+areas[j]-inter) | |||
| if overlap >= thresh: | |||
| suppressed[j] = 1 | |||
| return keep | |||
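A toy call illustrating the greedy suppression above: two heavily overlapping boxes and one separate box, rows `[x1, y1, x2, y2, score]` (this is a standalone copy of the same logic, named `nms_demo` here, so the snippet runs on its own):

```python
import numpy as np

def nms_demo(bboxs, thresh):
    """Greedy IoU-based non-maximum suppression; returns kept indices."""
    x1, y1, x2, y2 = bboxs[:, 0], bboxs[:, 1], bboxs[:, 2], bboxs[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = bboxs[:, 4].argsort()[::-1]          # highest score first
    suppressed = np.zeros(len(bboxs), dtype=bool)
    keep = []
    for _i, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)
        for j in order[_i + 1:]:
            # intersection-over-union with the kept box
            w = max(0.0, min(x2[i], x2[j]) - max(x1[i], x1[j]) + 1)
            h = max(0.0, min(y2[i], y2[j]) - max(y1[i], y1[j]) + 1)
            inter = w * h
            if inter / (areas[i] + areas[j] - inter) >= thresh:
                suppressed[j] = True
    return keep

dets = np.array([[0, 0, 10, 10, 0.9],
                 [1, 1, 11, 11, 0.8],
                 [50, 50, 60, 60, 0.7]])
kept = nms_demo(dets, 0.5)  # the second box overlaps the first too much
```

Box 1 has IoU ≈ 0.70 with the higher-scored box 0, so it is suppressed, while the distant box 2 survives.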
| @@ -0,0 +1,119 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train CTPN and get checkpoint files.""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import ast | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import context, Tensor | |||
| from mindspore.communication.management import init | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor | |||
| from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.nn import Momentum | |||
| from mindspore.common import set_seed | |||
| from src.ctpn import CTPN | |||
| from src.config import config, pretrain_config, finetune_config | |||
| from src.dataset import create_ctpn_dataset | |||
| from src.lr_schedule import dynamic_lr | |||
| from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell | |||
| set_seed(1) | |||
| parser = argparse.ArgumentParser(description="CTPN training") | |||
| parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") | |||
| parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") | |||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") | |||
| parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") | |||
| parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") | |||
| parser.add_argument("--task_type", type=str, default="Pretraining",\ | |||
| choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining") | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True) | |||
| if __name__ == '__main__': | |||
| if args_opt.run_distribute: | |||
| rank = args_opt.rank_id | |||
| device_num = args_opt.device_num | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| init() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| if args_opt.task_type == "Pretraining": | |||
| print("Start to do pretraining") | |||
| mindrecord_file = config.pretraining_dataset_file | |||
| training_cfg = pretrain_config | |||
| else: | |||
| print("Start to do finetune") | |||
| mindrecord_file = config.finetune_dataset_file | |||
| training_cfg = finetune_config | |||
| print("CHECKING MINDRECORD FILES ...") | |||
| while not os.path.exists(mindrecord_file + ".db"): | |||
| time.sleep(5) | |||
| print("CHECKING MINDRECORD FILES DONE!") | |||
| loss_scale = float(config.loss_scale) | |||
| # When creating the MindDataset, use the first mindrecord file, such as ctpn_pretrain.mindrecord0. | |||
| dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\ | |||
| batch_size=config.batch_size, device_num=device_num, rank_id=rank) | |||
| dataset_size = dataset.get_dataset_size() | |||
| net = CTPN(config=config, is_training=True) | |||
| net = net.set_train() | |||
| load_path = args_opt.pre_trained | |||
| if args_opt.task_type == "Pretraining": | |||
| print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained)) | |||
| param_dict = load_checkpoint(load_path) | |||
| for item in list(param_dict.keys()): | |||
| if not item.startswith('vgg16_feature_extractor'): | |||
| param_dict.pop(item) | |||
| load_param_into_net(net, param_dict) | |||
| else: | |||
| if load_path != "": | |||
| print("load pretrain ckpt {}".format(args_opt.pre_trained)) | |||
| param_dict = load_checkpoint(load_path) | |||
| load_param_into_net(net, param_dict) | |||
| loss = LossNet() | |||
| lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32) | |||
| opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\ | |||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| if args_opt.run_distribute: | |||
| net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, | |||
| mean=True, degree=device_num) | |||
| else: | |||
| net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) | |||
| time_cb = TimeMonitor(data_size=dataset_size) | |||
| loss_cb = LossCallBack(rank_id=rank) | |||
| cb = [time_cb, loss_cb] | |||
| if config.save_checkpoint: | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size, | |||
| keep_checkpoint_max=config.keep_checkpoint_max) | |||
| save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") | |||
| ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig) | |||
| cb += [ckpoint_cb] | |||
| model = Model(net) | |||
| model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True) | |||
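The learning-rate schedule passed to `Momentum` comes from `dynamic_lr` in `src/lr_schedule.py`, which is not part of this excerpt. Purely as an assumption, such schedules commonly combine linear warmup with stepwise decay; the sketch below (with hypothetical names and parameters) shows that shape, producing one rate per training step as the `Tensor(dynamic_lr(...))` call above expects:

```python
def dynamic_lr_sketch(base_lr, warmup_steps, total_steps, decay_step, gamma=0.1):
    """Hypothetical per-step schedule: linear warmup, then step decay."""
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # ramp linearly from base_lr / warmup_steps up to base_lr
            lrs.append(base_lr * (step + 1) / warmup_steps)
        else:
            # multiply by gamma every decay_step steps after warmup
            lrs.append(base_lr * gamma ** ((step - warmup_steps) // decay_step))
    return lrs

lr_list = dynamic_lr_sketch(base_lr=0.01, warmup_steps=4, total_steps=12, decay_step=4)
```

The resulting list would be wrapped as `Tensor(lr_list, mstype.float32)` before being handed to the optimizer, mirroring the call in the script above.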