You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README_CN.md 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. # 目录
  2. <!-- TOC -->
  3. - [目录](#目录)
  4. - [CNN+CTC描述](#cnnctc描述)
  5. - [模型架构](#模型架构)
  6. - [数据集](#数据集)
  7. - [特性](#特性)
  8. - [混合精度](#混合精度)
  9. - [环境要求](#环境要求)
  10. - [快速入门](#快速入门)
  11. - [脚本说明](#脚本说明)
  12. - [脚本及样例代码](#脚本及样例代码)
  13. - [脚本参数](#脚本参数)
  14. - [训练过程](#训练过程)
  15. - [训练](#训练)
  16. - [训练结果](#训练结果)
  17. - [评估过程](#评估过程)
  18. - [评估](#评估)
  19. - [模型描述](#模型描述)
  20. - [性能](#性能)
  21. - [训练性能](#训练性能)
  22. - [评估性能](#评估性能)
  23. - [用法](#用法)
  24. - [推理](#推理)
  25. - [在预训练模型上继续训练](#在预训练模型上继续训练)
  26. - [ModelZoo主页](#modelzoo主页)
  27. <!-- /TOC -->
  28. # CNN+CTC描述
  29. 本文描述了对场景文本识别(STR)的三个主要贡献。
  30. 首先检查训练和评估数据集不一致的内容,以及导致的性能差距。
  31. 再引入一个统一的四阶段STR框架,目前大多数STR模型都能够适应这个框架。
  32. 使用这个框架可以广泛评估以前提出的STR模块,并发现以前未开发的模块组合。
  33. 第三,分析在一致的训练和评估数据集下,模块对性能的贡献,包括准确率、速度和内存需求。
  34. 这些分析清除了当前比较的障碍,有助于了解现有模块的性能增益。
  35. [论文](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
  36. # 模型架构
  37. 示例:在MindSpore上使用MJSynth和SynthText数据集训练CNN+CTC模型进行文本识别。
  38. # 数据集
  39. [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/)和[SynthText](https://github.com/ankush-me/SynthText)数据集用于模型训练。[The IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset)数据集用于评估。
  40. - 步骤1:
  41. 所有数据集均经过预处理,以.lmdb格式存储,点击[**此处**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt)可下载。
  42. - 步骤2:
  43. 解压下载的文件,重命名MJSynth数据集为MJ,SynthText数据集为ST,IIIT数据集为IIIT。
  44. - 步骤3:
  45. 将上述三个数据集移至`cnctc_data`文件夹中,结构如下:
  46. ```python
  47. |--- CNNCTC/
  48. |--- cnnctc_data/
  49. |--- ST/
  50. data.mdb
  51. lock.mdb
  52. |--- MJ/
  53. data.mdb
  54. lock.mdb
  55. |--- IIIT/
  56. data.mdb
  57. lock.mdb
  58. ......
  59. ```
  60. - 步骤4:
  61. 预处理数据集:
  62. ```shell
  63. python src/preprocess_dataset.py
  64. ```
  65. 这大约需要75分钟。
  66. # 特性
  67. ## 混合精度
  68. 采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
  69. 以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
  70. # 环境要求
  71. - 硬件(Ascend)
  72. - 准备Ascend或GPU处理器搭建硬件环境。
  73. - 框架
  74. - [MindSpore](https://www.mindspore.cn/install)
  75. - 如需查看详情,请参见如下资源:
  76. - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
  77. - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
  78. # 快速入门
  79. - 安装依赖:
  80. ```python
  81. pip install lmdb
  82. pip install Pillow
  83. pip install tqdm
  84. pip install six
  85. ```
  86. - 单机训练:
  87. ```shell
  88. bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
  89. ```
  90. - 分布式训练:
  91. ```shell
  92. bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
  93. ```
  94. - 评估:
  95. ```shell
  96. bash scripts/run_eval_ascend.sh $TRAINED_CKPT
  97. ```
  98. # 脚本说明
  99. ## 脚本及样例代码
  100. 完整代码结构如下:
  101. ```python
  102. |--- CNNCTC/
  103. |---README.md // CNN+CTC相关描述
  104. |---train.py // 训练脚本
  105. |---eval.py // 评估脚本
  106. |---scripts
  107. |---run_standalone_train_ascend.sh // Ascend单机shell脚本
  108. |---run_distribute_train_ascend.sh // Ascend分布式shell脚本
  109. |---run_eval_ascend.sh // Ascend评估shell脚本
  110. |---src
  111. |---__init__.py // init文件
  112. |---cnn_ctc.py // cnn_ctc网络
  113. |---config.py // 总配置
  114. |---callback.py // 损失回调文件
  115. |---dataset.py // 处理数据集
  116. |---util.py // 常规操作
  117. |---generate_hccn_file.py // 生成分布式json文件
  118. |---preprocess_dataset.py // 预处理数据集
  119. ```
  120. ## 脚本参数
  121. 在`config.py`中可以同时配置训练参数和评估参数。
  122. 参数:
  123. - `--CHARACTER`:字符标签。
  124. - `--NUM_CLASS`:类别数,包含所有字符标签和CTCLoss的<blank>标签。
  125. - `--HIDDEN_SIZE`:模型隐藏大小。
  126. - `--FINAL_FEATURE_WIDTH`:特性的数量。
  127. - `--IMG_H`:输入图像高度。
  128. - `--IMG_W`:输入图像宽度。
  129. - `--TRAIN_DATASET_PATH`:训练数据集的路径。
  130. - `--TRAIN_DATASET_INDEX_PATH`:决定顺序的训练数据集索引文件的路径。
  131. - `--TRAIN_BATCH_SIZE`:训练批次大小。在批次大小和索引文件中,必须确保输入数据是固定的形状。
  132. - `--TRAIN_DATASET_SIZE`:训练数据集大小。
  133. - `--TEST_DATASET_PATH`:测试数据集的路径。
  134. - `--TEST_BATCH_SIZE`:测试批次大小。
  135. - `--TRAIN_EPOCHS`:总训练轮次。
  136. - `--CKPT_PATH`:模型检查点文件路径,可用于恢复训练和评估。
  137. - `--SAVE_PATH`:模型检查点文件保存路径。
  138. - `--LR`:单机训练学习率。
  139. - `--LR_PARA`:分布式训练学习率。
  140. - `--Momentum`:动量。
  141. - `--LOSS_SCALE`:损失放大,避免梯度下溢。
  142. - `--SAVE_CKPT_PER_N_STEP`:每N步保存模型检查点文件。
  143. - `--KEEP_CKPT_MAX_NUM`:模型检查点文件保存数量上限。
  144. ## 训练过程
  145. ### 训练
  146. - 单机训练:
  147. ```shell
  148. bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
  149. ```
  150. 结果和检查点被写入`./train`文件夹。日志可以在`./train/log`中找到,损失值记录在`./train/loss.log`中。
  151. `$PRETRAINED_CKPT`为模型检查点的路径,**可选**。如果值为none,模型将从头开始训练。
  152. - 分布式训练:
  153. ```shell
  154. bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
  155. ```
  156. 结果和检查点分别写入设备`i`的`./train_parallel_{i}`文件夹。
  157. 日志可以在`./train_parallel_{i}/log_{i}.log`中找到,损失值记录在`./train_parallel_{i}/loss.log`中。
  158. 在Ascend上运行分布式任务时需要`$RANK_TABLE_FILE`。
  159. `$PATH_TO_CHECKPOINT`为模型检查点的路径,**可选**。如果值为none,模型将从头开始训练。
  160. ### 训练结果
  161. 训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
  162. ```python
  163. # 分布式训练结果(8P)
  164. epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
  165. epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
  166. epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
  167. epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
  168. epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
  169. epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
  170. ...
  171. epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
  172. epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
  173. epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
  174. epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
  175. epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
  176. epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
  177. epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
  178. epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
  179. epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
  180. epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
  181. ```
  182. ## 评估过程
  183. ### 评估
  184. - 评估:
  185. ```shell
  186. bash scripts/run_eval_ascend.sh $TRAINED_CKPT
  187. ```
  188. 在IIIT数据集上评估模型,并打印样本结果和总准确率。
  189. # 模型描述
  190. ## 性能
  191. ### 训练性能
  192. | 参数 | CNNCTC |
  193. | -------------------------- | ----------------------------------------------------------- |
  194. | 模型版本 | V1 |
  195. | 资源 | Ascend 910;CPU 2.60GHz,192核;内存:755G |
  196. | 上传日期 | 2020-09-28 |
  197. | MindSpore版本 | 1.0.0 |
  198. | 数据集 | MJSynth、SynthText |
  199. | 训练参数 | epoch=3, batch_size=192 |
  200. | 优化器 | RMSProp |
  201. | 损失函数 | CTCLoss |
  202. | 速度 | 1卡:300毫秒/步;8卡:310毫秒/步 |
  203. | 总时间 | 1卡:18小时;8卡:2.3小时 |
  204. | 参数(M) | 177 |
  205. | 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc> |
  206. ### 评估性能
  207. | 参数 | CNNCTC |
  208. | ------------------- | --------------------------- |
  209. | 模型版本 | V1 |
  210. | 资源 | Ascend 910 |
  211. | 上传日期 | 2020-09-28 |
  212. | MindSpore版本 | 1.0.0 |
  213. | 数据集 | IIIT5K |
  214. | batch_size | 192 |
  215. | 输出 |准确率 |
  216. | 准确率 | 85% |
  217. | 推理模型 | 675M(.ckpt文件) |
  218. ## 用法
  219. ### 推理
  220. 如果您需要在GPU、Ascend 910、Ascend 310等多个硬件平台上使用训练好的模型进行推理,请参考此[链接](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html)。以下为简单示例:
  221. - Ascend处理器环境运行
  222. ```python
  223. # 设置上下文
  224. context.set_context(mode=context.GRAPH_HOME, device_target=cfg.device_target)
  225. context.set_context(device_id=cfg.device_id)
  226. # 加载未知数据集进行推理
  227. dataset = dataset.create_dataset(cfg.data_path, 1, False)
  228. # 定义模型
  229. net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
  230. opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
  231. cfg.momentum, weight_decay=cfg.weight_decay)
  232. loss = P.CTCLoss(preprocess_collapse_repeated=False,
  233. ctc_merge_repeated=True,
  234. ignore_longer_outputs_than_inputs=False)
  235. model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
  236. # 加载预训练模型
  237. param_dict = load_checkpoint(cfg.checkpoint_path)
  238. load_param_into_net(net, param_dict)
  239. net.set_train(False)
  240. # Make predictions on the unseen dataset
  241. acc = model.eval(dataset)
  242. print("accuracy: ", acc)
  243. ```
  244. ### 在预训练模型上继续训练
  245. - Ascend处理器环境运行
  246. ```python
  247. # 加载数据集
  248. dataset = create_dataset(cfg.data_path, 1)
  249. batch_num = dataset.get_dataset_size()
  250. # 定义模型
  251. net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
  252. # 如果pre_trained为True,则继续训练
  253. if cfg.pre_trained:
  254. param_dict = load_checkpoint(cfg.checkpoint_path)
  255. load_param_into_net(net, param_dict)
  256. lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
  257. steps_per_epoch=batch_num)
  258. opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
  259. Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
  260. loss = P.CTCLoss(preprocess_collapse_repeated=False,
  261. ctc_merge_repeated=True,
  262. ignore_longer_outputs_than_inputs=False)
  263. model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
  264. amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
  265. # 设置回调
  266. config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
  267. keep_checkpoint_max=cfg.keep_checkpoint_max)
  268. time_cb = TimeMonitor(data_size=batch_num)
  269. ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
  270. config=config_ck)
  271. loss_cb = LossMonitor()
  272. # 开始训练
  273. model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
  274. print("train success")
  275. ```
  276. # ModelZoo主页
  277. 请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。