You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README_CN.md 9.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # 目录
  2. - [目录](#目录)
  3. - [PSENet概述](#psenet概述)
  4. - [PSENet示例](#psenet示例)
  5. - [概述](#概述)
  6. - [数据集](#数据集)
  7. - [环境要求](#环境要求)
  8. - [快速入门](#快速入门)
  9. - [脚本说明](#脚本说明)
  10. - [脚本和样例代码](#脚本和样例代码)
  11. - [脚本参数](#脚本参数)
  12. - [训练过程](#训练过程)
  13. - [分布式训练](#分布式训练)
  14. - [评估过程](#评估过程)
  15. - [运行测试代码](#运行测试代码)
  16. - [ICDAR2015评估脚本](#icdar2015评估脚本)
  17. - [用法](#用法)
  18. - [结果](#结果)
  19. - [模型描述](#模型描述)
  20. - [性能](#性能)
  21. - [评估性能](#评估性能)
  22. - [推理性能](#推理性能)
  23. - [使用方法](#使用方法)
  24. - [推理](#推理)
  25. <!-- /TOC -->
  26. # PSENet概述
  27. 随着卷积神经网络的发展,场景文本检测技术迅速发展,但其算法中存在的两大问题阻碍了这一技术的应用:第一,现有的大多数算法都需要四边形边框来精确定位任意形状的文本;第二,两个相邻文本可能会因错误检测而被覆盖。传统意义上,语义分割可以解决第一个问题,但无法解决第二个问题。而PSENet能够精确地检测出任意形状文本实例,同时解决了两个问题。具体地说,PSENet为每个文本实例生成不同的扩展内核,并逐渐将最小扩展内核扩展为具有完整形状的文本实例。由于最小内核之间的几何差别较大,PSNet可以有效分割封闭的文本实例,更容易地检测任意形状文本实例。通过在CTW1500、全文、ICDAR 2015和ICDAR 2017 MLT中进行多次实验,PSENet的有效性得以验证。
  28. [论文](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Shape_Robust_Text_Detection_With_Progressive_Scale_Expansion_Network_CVPR_2019_paper.html): Wenhai Wang, Enze Xie, Xiang Li, Wenbo Hou, Tong Lu, Gang Yu, Shuai Shao; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2019, pp. 9336-9345
  29. # PSENet示例
  30. ## 概述
  31. 渐进尺度扩展网络(PSENet)是一种能够很好地检测自然场景中任意形状文本的文本检测器。
  32. # 数据集
  33. 使用的数据集:[ICDAR2015](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)
  34. 训练集:包括约4500个可读单词的1000张图像。
  35. 测试集:约2000个可读单词。
  36. # 环境要求
  37. - 硬件:昇腾处理器(Ascend)
  38. - 使用Ascend处理器来搭建硬件环境。
  39. - 框架
  40. - [MindSpore](https://www.mindspore.cn/install)
  41. - 如需查看详情,请参见如下资源:
  42. - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
  43. - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
  44. - 安装Mindspore
  45. - 安装[pyblind11](https://github.com/pybind/pybind11)
  46. - 安装[Opencv3.4](https://docs.opencv.org/3.4.9/)
  47. # 快速入门
  48. 通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
  49. ```python
  50. # 分布式训练运行示例
  51. sh scripts/run_distribute_train.sh pretrained_model.ckpt
  52. # 下载opencv库
  53. download pyblind11, opencv3.4
  54. # 安装pyblind11 opencv3.4
  55. setup pyblind11(install the library by the pip command)
  56. setup opencv3.4(compile source code install the library)
  57. # 输入路径,运行Makefile,找到产品文件
  58. cd ./src/ETSNET/pse/;make
  59. # 运行test.py
  60. python test.py --ckpt=pretrained_model.ckpt
  61. # 单击[此处](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)下载评估方法
  62. # 点击"我的方法"按钮,下载评估脚本
  63. download script.py
  64. # 运行评估示例
  65. sh scripts/run_eval_ascend.sh
  66. ```
  67. ## 脚本说明
  68. ## 脚本和样例代码
  69. ```path
  70. └── PSENet
  71. ├── export.py // mindir转换脚本
  72. ├── mindspore_hub_conf.py // 网络模型
  73. ├── README.md // PSENet相关描述英文版
  74. ├── README_CN.md // PSENet相关描述中文版
  75. ├── scripts
  76. ├── run_distribute_train.sh // 用于分布式训练的shell脚本
  77. └── run_eval_ascend.sh // 用于评估的shell脚本
  78. ├── src
  79. ├── config.py // 参数配置
  80. ├── dataset.py // 创建数据集
  81. ├── ETSNET
  82. ├── base.py // 卷积和BN算子
  83. ├── dice_loss.py // 计算PSENet损耗值
  84. ├── etsnet.py // PSENet中的子网
  85. ├── fpn.py // PSENet中的子网
  86. ├── __init__.py
  87. ├── pse // PSENet中的子网
  88. ├── adaptor.cpp
  89. ├── adaptor.h
  90. ├── __init__.py
  91. ├── Makefile
  92. ├── resnet50 // PSENet中的子网
  93. ├── __init__.py
  94. ├── lr_schedule.py // 学习率
  95. ├── network_define.py // PSENet架构
  96. ├── test.py // 测试脚本
  97. ├── train.py // 训练脚本
  98. ```
  99. ## 脚本参数
  100. ```python
  101. train.py和config.py中主要参数如下:
  102. -- pre_trained:是从零开始训练还是基于预训练模型训练。可选值为True、False。
  103. -- device_id:用于训练或评估数据集的设备ID。当使用train.sh进行分布式训练时,忽略此参数。
  104. -- device_num:使用train.sh进行分布式训练时使用的设备。
  105. ```
  106. ## 训练过程
  107. ### 分布式训练
  108. ```shell
  109. sh scripts/run_distribute_train.sh pretrained_model.ckpt
  110. ```
  111. 上述shell脚本将在后台运行分布训练。可以通过`device[X]/test_*.log`文件查看结果。
  112. 采用以下方式达到损失值:
  113. ```log
  114. # grep "epoch:" device_*/loss.log
  115. device_0/log:epoch: 1, step: 20,loss is 0.80383
  116. device_0/log:epcoh: 2, step: 40,loss is 0.77951
  117. ...
  118. device_1/log:epoch: 1, step: 20,loss is 0.78026
  119. device_1/log:epcoh: 2, step: 40,loss is 0.76629
  120. ```
  121. ## 评估过程
  122. ### 运行测试代码
  123. python test.py --ckpt=./device*/ckpt*/ETSNet-*.ckpt
  124. ### ICDAR2015评估脚本
  125. #### 用法
  126. 第一步:单击[此处](https://rrc.cvc.uab.es/?ch=4&com=tasks#TextLocalization)下载评估方法。
  127. 第二步:单击"我的方法"按钮,下载评估脚本。
  128. 第三步:建议将评估方法根符号链接到$MINDSPORE/model_zoo/psenet/eval_ic15/。如果您的文件夹结构不同,您可能需要更改评估脚本文件中的相应路径。
  129. ```shell
  130. sh ./script/run_eval_ascend.sh.sh
  131. ```
  132. #### 结果
  133. Calculated!{"precision": 0.8147966668299853,"recall":0.8006740491092923,"hmean":0.8076736279747451,"AP":0}
  134. # 模型描述
  135. ## 性能
  136. ### 评估性能
  137. | 参数 | Ascend |
  138. | -------------------------- | ----------------------------------------------------------- |
  139. | 模型版本 | PSENet |
  140. | 资源 | Ascend 910; CPU 2.60GHz,192内核;内存 755G;系统 Euler2.8 |
  141. | 上传日期 | 2020-09-15 |
  142. | MindSpore版本 | 1.0.0 |
  143. | 数据集 | ICDAR2015 |
  144. | 训练参数 | start_lr=0.1; lr_scale=0.1 |
  145. | 优化器 | SGD |
  146. | 损失函数 | LossCallBack |
  147. | 输出 | 概率 |
  148. | 损失 | 0.35 |
  149. | 速度 | 1卡:444毫秒/步;8卡:446毫秒/步
  150. | 总时间 | 1卡:75.48小时;8卡:7.11小时|
  151. | 参数(M) | 27.36 |
  152. | 微调检查点 | 109.44M (.ckpt file) |
  153. | 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/psenet> |
  154. ### 推理性能
  155. | 参数 | Ascend |
  156. | ------------------- | --------------------------- |
  157. | 模型版本 | PSENet |
  158. | 资源 | Ascend 910;系统 Euler2.8 |
  159. | 上传日期 | 2020/09/15 |
  160. | MindSpore版本 | 1.0.0 |
  161. | 数据集| ICDAR2015 |
  162. | 输出 | 概率 |
  163. | 准确性 | 1卡:81%; 8卡:81% |
  164. ## 使用方法
  165. ### 推理
  166. 如果您需要使用已训练模型在GPU、Ascend 910、Ascend 310等多个硬件平台上进行推理,可参考[此处](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html)。操作示例如下:
  167. ```python
  168. # 加载未知数据集进行推理
  169. dataset = dataset.create_dataset(cfg.data_path, 1, False)
  170. # 定义模型
  171. config.INFERENCE = False
  172. net = ETSNet(config)
  173. net = net.set_train()
  174. param_dict = load_checkpoint(args.pre_trained)
  175. load_param_into_net(net, param_dict)
  176. print('Load Pretrained parameters done!')
  177. criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
  178. lrs = lr_generator(start_lr=1e-3, lr_scale=0.1, total_iters=config.TRAIN_TOTAL_ITER)
  179. opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
  180. # 模型变形
  181. net = WithLossCell(net, criterion)
  182. net = TrainOneStepCell(net, opt)
  183. time_cb = TimeMonitor(data_size=step_size)
  184. loss_cb = LossCallBack(per_print_times=20)
  185. # 设置并应用检查点参数
  186. ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
  187. ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
  188. model = Model(net)
  189. model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
  190. # 加载预训练模型
  191. param_dict = load_checkpoint(cfg.checkpoint_path)
  192. load_param_into_net(net, param_dict)
  193. net.set_train(False)
  194. # 对未知数据集进行预测
  195. acc = model.eval(dataset)
  196. print("accuracy: ", acc)
  197. ```