# Contents

- [PSENet Description](#psenet-description)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
        - [Transfer Learning](#transfer-learning)
# [PSENet Description](#contents)

With the development of convolutional neural networks, scene text detection technology has advanced rapidly. However, two problems still hinder its industrial application. On the one hand, most existing algorithms rely on quadrilateral bounding boxes, which cannot accurately locate text of arbitrary shape. On the other hand, two text instances lying close to each other may produce a single detection that covers both. Segmentation-based approaches can traditionally solve the first problem, but usually not the second. To solve both problems, the Progressive Scale Expansion Network (PSENet) is proposed, which can accurately detect text instances of arbitrary shape. More specifically, PSENet generates kernels of different scales for each text instance and gradually expands the minimal-scale kernel to the text instance with its complete shape. Because of the large geometrical margins between the minimal-scale kernels, the method can effectively separate text instances lying close together, making it easier to detect arbitrary-shape text instances. The effectiveness of PSENet has been verified by extensive experiments on CTW1500, Total-Text, ICDAR 2015, and ICDAR 2017 MLT.

[Paper](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345
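The expansion step at the heart of the method can be pictured as a breadth-first flood fill: connected components of the smallest kernel map act as seeds, and their labels grow outward through each larger kernel map, so a pixel contested by two instances goes to whichever seed reaches it first. The repository's actual implementation is the C++ code under `src/ETSNET/pse`; the following is only a minimal NumPy sketch of the idea.

```python
# Minimal NumPy sketch of progressive scale expansion (illustrative only;
# the repository uses the C++ implementation in src/ETSNET/pse).
from collections import deque

import numpy as np

def label_components(mask):
    """Label 4-connected components of a binary mask via BFS."""
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    current = 0
    for y in range(h):
        for x in range(w):
            if mask[y, x] and labels[y, x] == 0:
                current += 1
                labels[y, x] = current
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = current
                            queue.append((ny, nx))
    return labels

def progressive_scale_expansion(kernels):
    """kernels: binary maps ordered smallest -> largest. Returns instance labels."""
    labels = label_components(kernels[0])  # seeds: components of the smallest kernel
    h, w = labels.shape
    for kernel in kernels[1:]:
        # the BFS frontier starts from every pixel labeled so far
        queue = deque(zip(*np.nonzero(labels)))
        while queue:
            cy, cx = queue.popleft()
            for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                # expand only into still-unlabeled pixels of the larger kernel;
                # first-come-first-served resolves contested pixels
                if 0 <= ny < h and 0 <= nx < w and kernel[ny, nx] and labels[ny, nx] == 0:
                    labels[ny, nx] = labels[cy, cx]
                    queue.append((ny, nx))
    return labels

if __name__ == "__main__":
    # toy example: two seed pixels inside one merged blob stay two instances
    small = np.zeros((4, 8), dtype=bool)
    small[1, 1], small[1, 6] = True, True
    large = np.ones((4, 8), dtype=bool)
    print(progressive_scale_expansion([small, large]))
```

On the toy input the two labels meet in the middle instead of merging, which is exactly how PSENet keeps adjacent text instances apart.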
# PSENet Example

## Description

Progressive Scale Expansion Network (PSENet) is a text detector that can accurately detect text of arbitrary shape in natural scenes.
# [Dataset](#contents)

Note that you can run the scripts with the dataset mentioned in the original paper or with one widely used in this domain/network architecture. In the following sections, we will introduce how to run the scripts using the dataset below.

Dataset used: [ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)

- A training set of 1000 images containing about 4500 readable words
- A testing set containing about 2000 readable words
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](http://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
- Install MindSpore
- Install [pybind11](https://github.com/pybind/pybind11)
- Install [OpenCV 3.4](https://docs.opencv.org/3.4.9/)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```shell
# run the distributed training example
sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt

# download pybind11 and OpenCV 3.4, then install them:
# pybind11 can be installed via pip; OpenCV 3.4 is compiled and installed from source

# enter the pse path and run the Makefile to build the post-processing library
cd ./src/ETSNET/pse/; make

# run test.py
python test.py --ckpt=pretrained_model.ckpt

# download the evaluation method from https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization
# (click the "My Methods" button, then download the Evaluation Scripts, including script.py)

# run the evaluation example
sh scripts/run_eval_ascend.sh
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```path
└── PSENet
    ├── README.md                         // descriptions about PSENet
    ├── scripts
    │    ├── run_distribute_train.sh      // shell script for distributed training
    │    └── run_eval_ascend.sh           // shell script for evaluation
    ├── src
    │    ├── __init__.py
    │    ├── ETSNET
    │    │    ├── __init__.py
    │    │    ├── base.py                 // convolution and BN operators
    │    │    ├── dice_loss.py            // calculates the PSENet loss value
    │    │    ├── etsnet.py               // subnet in PSENet
    │    │    ├── fpn.py                  // subnet in PSENet
    │    │    ├── resnet50.py             // subnet in PSENet
    │    │    └── pse                     // progressive scale expansion post-processing
    │    │         ├── __init__.py
    │    │         ├── adaptor.cpp
    │    │         ├── adaptor.h
    │    │         └── Makefile
    │    ├── config.py                    // parameter configuration
    │    ├── dataset.py                   // creating dataset
    │    └── network_define.py            // learning rate generation
    ├── export.py                         // export a MINDIR file
    ├── mindspore_hub_conf.py             // hub config file
    ├── test.py                           // test script
    └── train.py                          // training script
```
## [Script Parameters](#contents)

```python
Major parameters in train.py and config.py are:

--pre_trained: Whether to train from scratch or from a pre-trained model.
               Optional values are True, False.
--device_id: Device ID used to train or evaluate the dataset. Ignore it
             when you use train.sh for distributed training.
--device_num: Number of devices used when you use train.sh for distributed training.
```
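For reference, flags like these are typically wired up with `argparse`. The sketch below only mirrors the parameter list above; the defaults are illustrative assumptions, not the repository's actual values.

```python
# Hypothetical sketch of the argument parsing implied by the list above;
# defaults are assumptions, not the repository's actual values.
import argparse

parser = argparse.ArgumentParser(description="PSENet training")
parser.add_argument("--pre_trained", type=str, default="False", choices=["True", "False"],
                    help="whether to train from a pre-trained model instead of from scratch")
parser.add_argument("--device_id", type=int, default=0,
                    help="device to train/evaluate on; ignored for distributed runs via train.sh")
parser.add_argument("--device_num", type=int, default=1,
                    help="number of devices for distributed training via train.sh")
args = parser.parse_args()
```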
## [Training Process](#contents)

### Distributed Training

```shell
sh scripts/run_distribute_train.sh rank_table_file pretrained_model.ckpt
```

The rank_table_file specified by RANK_TABLE_FILE is needed when you run a distributed task. You can generate it with [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

The above shell script runs distributed training in the background. You can view the results through the file `device[X]/test_*.log`. The loss values will be logged as follows:
```log
# grep "epoch: " device_*/loss.log
device_0/log:epoch: 1, step: 20, loss is 0.80383
device_0/log:epoch: 2, step: 40, loss is 0.77951
...
device_1/log:epoch: 1, step: 20, loss is 0.78026
device_1/log:epoch: 2, step: 40, loss is 0.76629
```
## [Evaluation Process](#contents)

### Run the test code

`python test.py --ckpt=./device*/ckpt*/ETSNet-*.ckpt`

### Eval Script for ICDAR2015

#### Usage

step 1: download the evaluation method from [here](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization).

step 2: click the "My Methods" button, then download the Evaluation Scripts.

step 3: it is recommended to symlink the eval method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. If your folder structure is different, you may need to change the corresponding paths in the eval script files.

```shell
sh ./script/run_eval_ascend.sh
```

#### Result

Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
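The reported `hmean` is simply the harmonic mean (F1 score) of precision and recall, which you can verify from the figures above:

```python
# hmean = harmonic mean (F1) of precision and recall
precision = 0.814796668299853
recall = 0.8006740491092923
hmean = 2 * precision * recall / (precision + recall)
print(round(hmean, 6))  # 0.807674, matching the reported hmean
```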
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters | PSENet |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | ICDAR2015 |
| Training Parameters | start_lr=0.1; lr_scale=0.1 |
| Optimizer | SGD |
| Loss Function | Dice loss |
| Outputs | probability |
| Loss | 0.35 |
| Speed | 1pc: 444 ms/step; 8pcs: 446 ms/step |
| Total time | 1pc: 75.48 h; 8pcs: 10.01 h |
| Parameters (M) | 27.36 |
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/psenet> |
### Inference Performance

| Parameters | PSENet |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/30/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | ICDAR2015 |
| Outputs | probability |
| Accuracy | 1pc: 81%; 8pcs: 81% |
## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Below is a simple example of continuing training from a pretrained checkpoint and then evaluating:
```python
# Example: continue training from a pretrained checkpoint, then evaluate.
# NOTE: the src.* import paths below follow the repository layout shown
# above and are assumptions; adjust them to your setup.
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src import dataset as dataset_module
from src.config import config
from src.ETSNET.dice_loss import DiceLoss
from src.ETSNET.etsnet import ETSNet
from src.network_define import LossCallBack, TrainOneStepCell, WithLossCell, lr_generator

# Load the dataset
ds = dataset_module.create_dataset(config.data_path, 1, False)
step_size = ds.get_dataset_size()

# Define the model in training mode
config.INFERENCE = False
net = ETSNet(config)
net.set_train()

# Continue from the pretrained checkpoint (args.pre_trained comes from the
# command-line flags described in "Script Parameters")
param_dict = load_checkpoint(args.pre_trained)
load_param_into_net(net, param_dict)
print('Load pretrained parameters done!')

# Loss, learning rate schedule, and optimizer
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)

# Wrap the model with the loss and a single-step training cell
net = WithLossCell(net, criterion)
net = TrainOneStepCell(net, opt)
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossCallBack(per_print_times=20)

# Set and apply checkpointing parameters
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
model = Model(net)
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])

# Reload the trained checkpoint and switch to evaluation mode
param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(ds)
print("accuracy:", acc)
```
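For pure inference, by contrast, the network is built with `config.INFERENCE = True` and no loss or optimizer is needed. The following is a hedged sketch reusing the names from the example above; the dataset column name `img` is an assumption for illustration.

```python
# Minimal inference-only sketch, reusing ETSNet/config and the checkpoint
# helpers from the example above; the "img" column name is an assumption.
config.INFERENCE = True                       # build the network in inference mode
net = ETSNet(config)
param_dict = load_checkpoint("pretrained_model.ckpt")
load_param_into_net(net, param_dict)
net.set_train(False)

for data in ds.create_dict_iterator():
    score_maps = net(data["img"])             # per-pixel kernel probability maps
    # binarize the kernel maps and feed them to the progressive scale
    # expansion post-processing (src/ETSNET/pse) to recover text boxes
```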