You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README_CN.md 26 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. # 目录
  2. <!-- TOC -->
  3. - [目录](#目录)
  4. - [掩式序列到序列(MASS)预训练语言生成](#掩式序列到序列mass预训练语言生成)
  5. - [模型架构](#模型架构)
  6. - [数据集](#数据集)
  7. - [特性](#特性)
  8. - [脚本说明](#脚本说明)
  9. - [准备数据集](#准备数据集)
  10. - [标记化](#标记化)
  11. - [字节对编码](#字节对编码)
  12. - [构建词汇表](#构建词汇表)
  13. - [生成数据集](#生成数据集)
  14. - [News Crawl语料库](#news-crawl语料库)
  15. - [Gigaword语料库](#gigaword语料库)
  16. - [Cornell电影对白语料库](#cornell电影对白语料库)
  17. - [配置](#配置)
  18. - [训练&评估过程](#训练评估过程)
  19. - [权重平均值](#权重平均值)
  20. - [学习速率调度器](#学习速率调度器)
  21. - [环境要求](#环境要求)
  22. - [平台](#平台)
  23. - [其他要求](#其他要求)
  24. - [快速上手](#快速上手)
  25. - [预训练](#预训练)
  26. - [微调](#微调)
  27. - [推理](#推理)
  28. - [性能](#性能)
  29. - [结果](#结果)
  30. - [文本摘要微调](#文本摘要微调)
  31. - [会话应答微调](#会话应答微调)
  32. - [训练性能](#训练性能)
  33. - [推理性能](#推理性能)
  34. - [随机情况说明](#随机情况说明)
  35. - [其他](#其他)
  36. - [ModelZoo主页](#modelzoo主页)
  37. <!-- /TOC -->
  38. # 掩式序列到序列(MASS)预训练语言生成
  39. [掩式序列到序列(MASS)预训练语言生成](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf)由微软于2019年6月发布。
  40. BERT(Devlin等人,2018年)采用有屏蔽的语料丰富文本预训练Transformer的编码器部分(Vaswani等人,2017年),已在自然语言理解领域实现了性能最优(SOTA)。不仅如此,GPT(Raddford等人,2018年)也采用了有屏蔽的语料丰富文本对Transformer的解码器部分进行预训练(屏蔽了编码器输入)。两者都通过预训练有屏蔽的语料丰富文本来构建一个健壮的语言模型。
  41. 受BERT、GPT及其他语言模型的启发,微软致力于在此基础上研究[掩式序列到序列(MASS)预训练语言生成](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf)。MASS的参数k很重要,用来控制屏蔽后的分片长度。BERT和GPT属于特例,k等于1或者句长。
  42. [MASS介绍 — 序列对序列语言生成任务中性能优于BERT和GPT的预训练方法](https://www.microsoft.com/en-us/research/blog/introducing-mass-a-pre-training-method-that-outperforms-bert-and-gpt-in-sequence-to-sequence-language-generation-tasks/)
  43. [论文](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf): Song, Kaitao, Xu Tan, Tao Qin, Jianfeng Lu and Tie-Yan Liu.“MASS: Masked Sequence to Sequence Pre-training for Language Generation.”ICML (2019).
  44. # 模型架构
  45. MASS网络由Transformer实现,Transformer包括多个编码器层和多个解码器层。
  46. 预训练中,采用Adam优化器和损失放大来得到预训练后的模型。
  47. 微调时,根据不同的任务,采用不同的数据集对预训练的模型进行微调。
  48. 测试过程中,通过微调后的模型预测结果,并采用波束搜索算法
  49. 获取可能性最高的预测结果。
  50. # 数据集
  51. 本文运用数据集包括:
  52. - News Crawl数据集(WMT,2019年)的英语单语数据,用于预训练
  53. - Gigaword语料库(Graff等人,2003年),用于文本摘要
  54. - Cornell电影对白语料库(DanescuNiculescu-Mizil & Lee,2011年)
  55. 数据集相关信息,参见[MASS:语言生成的隐式序列到序列预训练](https://www.microsoft.com/en-us/research/uploads/prod/2019/06/MASS-paper-updated-002.pdf)。
  56. ## 特性
  57. MASS设计联合预训练编码器和解码器,来完成语言生成任务。
  58. 首先,通过序列到序列的框架,MASS只预测阻塞的标记,迫使编码器理解未屏蔽标记的含义,并鼓励解码器从编码器中提取有用信息。
  59. 其次,通过预测解码器的连续标记,可以建立比仅预测离散标记更好的语言建模能力。
  60. 第三,通过进一步屏蔽编码器中未屏蔽的解码器的输入标记,鼓励解码器从编码器侧提取更有用的信息,而不是使用前一个标记中的丰富信息。
  61. ## 脚本说明
  62. MASS脚本及代码结构如下:
  63. ```text
  64. ├── mass
  65. ├── README.md // MASS模型介绍
  66. ├── config
  67. │ ├──config.py // 配置实例定义
  68. │ ├──config.json // 配置文件
  69. ├──src
  70. │ ├──dataset
  71. │ ├──bi_data_loader.py // 数据集加载器,用于微调或推理
  72. │ ├──mono_data_loader.py // 预训练数据集加载器
  73. │ ├──language_model
  74. │ ├──noise_channel_language_model.p // 数据集生成噪声通道语言模型
  75. │ ├──mass_language_model.py // 基于MASS论文的MASS语言模型
  76. │ ├──loose_masked_language_model.py // 基于MASS发布代码的MASS语言模型
  77. │ ├──masked_language_model.py // 基于MASS论文的MASS语言模型
  78. │ ├──transformer
  79. │ ├──create_attn_mask.py // 生成屏蔽矩阵,除去填充部分
  80. │ ├──transformer.py // Transformer模型架构
  81. │ ├──encoder.py // Transformer编码器组件
  82. │ ├──decoder.py // Transformer解码器组件
  83. │ ├──self_attention.py // 自注意块组件
  84. │ ├──multi_head_attention.py // 多头自注意组件
  85. │ ├──embedding.py // 嵌入组件
  86. │ ├──positional_embedding.py // 位置嵌入组件
  87. │ ├──feed_forward_network.py // 前馈网络
  88. │ ├──residual_conn.py // 残留块
  89. │ ├──beam_search.py // 推理所用的波束搜索解码器
  90. │ ├──transformer_for_infer.py // 使用Transformer进行推理
  91. │ ├──transformer_for_train.py // 使用Transformer进行训练
  92. │ ├──utils
  93. │ ├──byte_pair_encoding.py // 使用subword-nmt应用字节对编码(BPE)
  94. │ ├──dictionary.py // 字典
  95. │ ├──loss_moniter.py // 训练步骤中损失监控回调
  96. │ ├──lr_scheduler.py // 学习速率调度器
  97. │ ├──ppl_score.py // 基于N-gram的困惑度评分
  98. │ ├──rouge_score.py // 计算ROUGE得分
  99. │ ├──load_weights.py // 从检查点或者NPZ文件加载权重
  100. │ ├──initializer.py // 参数初始化器
  101. ├── vocab
  102. │ ├──all.bpe.codes // 字节对编码表(此文件需要用户自行生成)
  103. │ ├──all_en.dict.bin // 已学习到的词汇表(此文件需要用户自行生成)
  104. ├── scripts
  105. │ ├──run_ascend.sh // Ascend处理器上训练&评估模型脚本
  106. │ ├──run_gpu.sh // GPU处理器上训练&评估模型脚本
  107. │ ├──learn_subword.sh // 学习字节对编码
  108. │ ├──stop_training.sh // 停止训练
  109. ├── requirements.txt // 第三方包需求
  110. ├── train.py // 训练API入口
  111. ├── eval.py // 推理API入口
  112. ├── tokenize_corpus.py // 语料标记化
  113. ├── apply_bpe_encoding.py // 应用BPE进行编码
  114. ├── weights_average.py // 将各模型检查点平均转换到NPZ格式
  115. ├── news_crawl.py // 创建预训练所用的News Crawl数据集
  116. ├── gigaword.py // 创建Gigaword语料库
  117. ├── cornell_dialog.py // 创建Cornell电影对白数据集,用于对话应答
  118. ```
  119. ## 准备数据集
  120. 自然语言处理任务的数据准备过程包括数据清洗、标记、编码和生成词汇表几个步骤。
  121. 实验中,使用[字节对编码(BPE)](https://arxiv.org/abs/1508.07909)可以有效减少词汇量,减轻对OOV的影响。
  122. 使用`src/utils/dictionary.py`可以基于BPE学习到的文本词典创建词汇表。
  123. 有关BPE的更多详细信息,参见[Subword-nmt lib](https://www.cnpython.com/pypi/subword-nmt)或[论文](https://arxiv.org/abs/1508.07909)。
  124. 实验中,根据News Crawl数据集的1.9万个句子,学习到的词汇量为45755个单词。
  125. 这里我们简单介绍一下准备数据所需的脚本。
  126. ### 标记化
  127. 使用`tokenize_corpus.py`可以标记`.txt`格式的文本语料。
  128. `tokenize_corpus.py`的主要参数如下:
  129. ```bash
  130. --corpus_folder: Corpus folder path, if multi-folders are provided, use ',' split folders.
  131. --output_folder: Output folder path.
  132. --tokenizer: Tokenizer to be used, nltk or jieba, if nltk is not installed fully, use jieba instead.
  133. --pool_size: Processes pool size.
  134. ```
  135. 示例代码如下:
  136. ```bash
  137. python tokenize_corpus.py --corpus_folder /{path}/corpus --output_folder /{path}/tokenized_corpus --tokenizer {nltk|jieba} --pool_size 16
  138. ```
  139. ### 字节对编码
  140. 标记化后,使用提供的`all.bpe.codes`对标记后的语料进行字节对编码处理。
  141. 应用BPE所需的脚本为`apply_bpe_encoding.py`。
  142. `apply_bpe_encoding.py`的主要参数如下:
  143. ```bash
  144. --codes: BPE codes file.
  145. --src_folder: Corpus folders.
  146. --output_folder: Output files folder.
  147. --prefix: Prefix of text file in `src_folder`.
  148. --vocab_path: Generated vocabulary output path.
  149. --threshold: Filter out words that frequency is lower than threshold.
  150. --processes: Size of process pool (to accelerate).Default: 2.
  151. ```
  152. 示例代码如下:
  153. ```bash
  154. python tokenize_corpus.py --codes /{path}/all.bpe.codes \
  155. --src_folder /{path}/tokenized_corpus \
  156. --output_folder /{path}/tokenized_corpus/bpe \
  157. --prefix tokenized \
  158. --vocab_path /{path}/vocab_en.dict.bin
  159. --processes 32
  160. ```
  161. ### 构建词汇表
  162. 如需创建新词汇表,可任选下列方法之一:
  163. 1. 重新学习字节对编码,从`subword-nmt`的多个词汇表文件创建词汇表。
  164. 2. 基于现有词汇文件创建词汇表,该词汇文件行格式为`word frequency`。
  165. 3. *(可选)* 基于`vocab/all_en.dict.bin`,应用`src/utils/dictionary.py`中的`shink`方法创建一个小词汇表。
  166. 4. 应用`persistence()`方法将词汇表持久化到`vocab`文件夹。
  167. `src/utils/dictionary.py`的主要接口如下:
  168. 1. `shrink(self, threshold=50)`:通过过滤词频低于阈值的单词来缩小词汇量,并返回一个新的词汇表。
  169. 2. `load_from_text(cls, filepaths: List[str])`:加载现有文本词汇表,行格式为`word frequency`。
  170. 3. `load_from_persisted_dict(cls, filepath)`:加载通过调用`persistence()`方法保存的持久化二进制词汇表。
  171. 4. `persistence(self, path)`:将词汇表对象保存为二进制文件。
  172. 示例代码如下:
  173. ```python
  174. from src.utils import Dictionary
  175. vocabulary = Dictionary.load_from_persisted_dict("vocab/all_en.dict.bin")
  176. tokens = [1, 2, 3, 4, 5]
  177. # Convert ids to symbols.
  178. print([vocabulary[t] for t in tokens])
  179. sentence = ["Hello", "world"]
  180. # Convert symbols to ids.
  181. print([vocabulary.index[s] for s in sentence])
  182. ```
  183. 相关信息,参见源文件。
  184. ### 生成数据集
  185. 如前所述,MASS模式下使用了三个语料数据集,相关数据集生成脚本已提供。
  186. #### News Crawl语料库
  187. 数据集生成脚本为`news_crawl.py`。
  188. `news_crawl.py`的主要参数如下:
  189. ```bash
  190. Note that please provide `--existed_vocab` or `--dict_folder` at least one.
  191. A new vocabulary would be created in `output_folder` when pass `--dict_folder`.
  192. --src_folder: Corpus folders.
  193. --existed_vocab: Optional, persisted vocabulary file.
  194. --mask_ratio: Ratio of mask.
  195. --output_folder: Output dataset files folder path.
  196. --max_len: Maximum sentence length.If a sentence longer than `max_len`, then drop it.
  197. --suffix: Optional, suffix of generated dataset files.
  198. --processes: Optional, size of process pool (to accelerate).Default: 2.
  199. ```
  200. 示例代码如下:
  201. ```bash
  202. python news_crawl.py --src_folder /{path}/news_crawl \
  203. --existed_vocab /{path}/mass/vocab/all_en.dict.bin \
  204. --mask_ratio 0.5 \
  205. --output_folder /{path}/news_crawl_dataset \
  206. --max_len 32 \
  207. --processes 32
  208. ```
  209. #### Gigaword语料库
  210. 数据集生成脚本为`gigaword.py`。
  211. `gigaword.py`主要参数如下:
  212. ```bash
  213. --train_src: Train source file path.
  214. --train_ref: Train reference file path.
  215. --test_src: Test source file path.
  216. --test_ref: Test reference file path.
  217. --existed_vocab: Persisted vocabulary file.
  218. --output_folder: Output dataset files folder path.
  219. --noise_prob: Optional, add noise prob.Default: 0.
  220. --max_len: Optional, maximum sentence length.If a sentence longer than `max_len`, then drop it.Default: 64.
  221. --format: Optional, dataset format, "mindrecord" or "tfrecord".Default: "tfrecord".
  222. ```
  223. 示例代码如下:
  224. ```bash
  225. python gigaword.py --train_src /{path}/gigaword/train_src.txt \
  226. --train_ref /{path}/gigaword/train_ref.txt \
  227. --test_src /{path}/gigaword/test_src.txt \
  228. --test_ref /{path}/gigaword/test_ref.txt \
  229. --existed_vocab /{path}/mass/vocab/all_en.dict.bin \
  230. --noise_prob 0.1 \
  231. --output_folder /{path}/gigaword_dataset \
  232. --max_len 64
  233. ```
  234. #### Cornell电影对白语料库
  235. 数据集生成脚本为`cornell_dialog.py`。
  236. `cornell_dialog.py`主要参数如下:
  237. ```bash
  238. --src_folder: Corpus folders.
  239. --existed_vocab: Persisted vocabulary file.
  240. --train_prefix: Train source and target file prefix.Default: train.
  241. --test_prefix: Test source and target file prefix.Default: test.
  242. --output_folder: Output dataset files folder path.
  243. --max_len: Maximum sentence length.If a sentence longer than `max_len`, then drop it.
  244. --valid_prefix: Optional, Valid source and target file prefix.Default: valid.
  245. ```
  246. 示例代码如下:
  247. ```bash
  248. python cornell_dialog.py --src_folder /{path}/cornell_dialog \
  249. --existed_vocab /{path}/mass/vocab/all_en.dict.bin \
  250. --train_prefix train \
  251. --test_prefix test \
  252. --noise_prob 0.1 \
  253. --output_folder /{path}/cornell_dialog_dataset \
  254. --max_len 64
  255. ```
  256. ## 配置
  257. `config/`目录下的JSON文件为模板配置文件,
  258. 便于为大多数选项及参数赋值,包括训练平台、数据集和模型的配置、优化器参数等。还可以通过设置相应选项,获得诸如损失放大和检查点等可选特性。
  259. 有关属性的详细信息,参见`config/config.py`文件。
  260. ## 训练&评估过程
  261. 训练模型时,只需使用shell脚本`run_ascend.sh`或`run_gpu.sh`即可。脚本中设置了环境变量,执行`mass`下的`train.py`训练脚本。
  262. 您可以通过选项赋值来启动单卡或多卡训练任务,在bash中运行如下命令:
  263. Ascend处理器:
  264. ```ascend
  265. sh run_ascend.sh [--options]
  266. ```
  267. GPU处理器:
  268. ```gpu
  269. sh run_gpu.sh [--options]
  270. ```
  271. `run_ascend.sh`的用法如下:
  272. ```text
  273. Usage: run_ascend.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
  274. [-i, --device_id <N>] [-j, --hccl_json <FILE>]
  275. [-c, --config <FILE>] [-o, --output <FILE>]
  276. [-v, --vocab <FILE>]
  277. options:
  278. -h, --help show usage
  279. -t, --task select task: CHAR, 't' for train and 'i' for inference".
  280. -n, --device_num device number used for training: N, default is 1.
  281. -i, --device_id device id used for training with single device: N, 0<=N<=7, default is 0.
  282. -j, --hccl_json rank table file used for training with multiple devices: FILE.
  283. -c, --config configuration file as shown in the path 'mass/config': FILE.
  284. -o, --output assign output file of inference: FILE.
  285. -v, --vocab set the vocabulary.
  286. -m, --metric set the metric.
  287. ```
  288. 说明:运行分布式训练时,确保已配置`hccl_json`文件。
  289. `run_gpu.sh`的用法如下:
  290. ```text
  291. Usage: run_gpu.sh [-h, --help] [-t, --task <CHAR>] [-n, --device_num <N>]
  292. [-i, --device_id <N>] [-c, --config <FILE>]
  293. [-o, --output <FILE>] [-v, --vocab <FILE>]
  294. options:
  295. -h, --help show usage
  296. -t, --task select task: CHAR, 't' for train and 'i' for inference".
  297. -n, --device_num device number used for training: N, default is 1.
  298. -i, --device_id device id used for training with single device: N, 0<=N<=7, default is 0.
  299. -c, --config configuration file as shown in the path 'mass/config': FILE.
  300. -o, --output assign output file of inference: FILE.
  301. -v, --vocab set the vocabulary.
  302. -m, --metric set the metric.
  303. ```
  304. 运行如下命令进行2卡训练。
  305. Ascend处理器:
  306. ```ascend
  307. sh run_ascend.sh --task t --device_num 2 --hccl_json /{path}/rank_table.json --config /{path}/config.json
  308. ```
  309. 注:`run_ascend.sh`暂不支持不连续设备ID,`rank_table.json`中的设备ID必须从0开始。
  310. GPU处理器:
  311. ```gpu
  312. sh run_gpu.sh --task t --device_num 2 --config /{path}/config.json
  313. ```
  314. 运行如下命令进行单卡训练:
  315. Ascend处理器:
  316. ```ascend
  317. sh run_ascend.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
  318. ```
  319. GPU处理器:
  320. ```gpu
  321. sh run_gpu.sh --task t --device_num 1 --device_id 0 --config /{path}/config.json
  322. ```
  323. ## 权重平均值
  324. ```python
  325. python weights_average.py --input_files your_checkpoint_list --output_file model.npz
  326. ```
  327. `input_files`为检查点文件清单。如需使用`model.npz`作为权重文件,请在“existed_ckpt”的`config.json`文件中添加`model.npz`的路径。
  328. ```json
  329. {
  330. ...
  331. "checkpoint_options": {
  332. "existed_ckpt": "/xxx/xxx/model.npz",
  333. "save_ckpt_steps": 1000,
  334. ...
  335. },
  336. ...
  337. }
  338. ```
  339. ## 学习速率调度器
  340. 模型中提供了两个学习速率调度器:
  341. 1. [多项式衰减调度器](https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1)。
  342. 2. [逆平方根调度器](https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/Inverse_square_root/)。
  343. 可以在`config/config.json`文件中配置学习率调度器。
  344. 多项式衰减调度器配置文件示例如下:
  345. ```json
  346. {
  347. ...
  348. "learn_rate_config": {
  349. "optimizer": "adam",
  350. "lr": 1e-4,
  351. "lr_scheduler": "poly",
  352. "poly_lr_scheduler_power": 0.5,
  353. "decay_steps": 10000,
  354. "warmup_steps": 2000,
  355. "min_lr": 1e-6
  356. },
  357. ...
  358. }
  359. ```
  360. 逆平方根调度器配置文件示例如下:
  361. ```json
  362. {
  363. ...
  364. "learn_rate_config": {
  365. "optimizer": "adam",
  366. "lr": 1e-4,
  367. "lr_scheduler": "isr",
  368. "decay_start_step": 12000,
  369. "warmup_steps": 2000,
  370. "min_lr": 1e-6
  371. },
  372. ...
  373. }
  374. ```
  375. 有关学习率调度器的更多详细信息,参见`src/utils/lr_scheduler.py`。
  376. # 环境要求
  377. ## 平台
  378. - 硬件(Ascend或GPU)
  379. - 使用Ascend或GPU处理器准备硬件环境。
  380. - 框架
  381. - [MindSpore](https://www.mindspore.cn/install)
  382. - 更多关于Mindspore的信息,请查看以下资源:
  383. - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
  384. - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
  385. ## 其他要求
  386. ```txt
  387. nltk
  388. numpy
  389. subword-nmt
  390. rouge
  391. ```
  392. <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html>
  393. # 快速上手
  394. MASS通过预测输入序列中被屏蔽的片段来预训练序列到序列模型。之后,选择下游的文本摘要或会话应答任务进行模型微调和推理。
  395. 这里提供了一个练习示例来演示应用MASS,对模型进行预训练、微调的基本用法,以及推理过程。操作步骤如下:
  396. 1. 下载并处理数据集。
  397. 2. 修改`config.json`文件,配置网络。
  398. 3. 运行预训练和微调任务。
  399. 4. 进行推理验证。
  400. ## 预训练
  401. 预训练模型时,首先配置`config.json`中的选项:
  402. - 将`dataset_config`节点下的`pre_train_dataset`配置为数据集路径。
  403. - 选择优化器(可采用'momentum/adam/lamb’)。
  404. - 在`checkpoint_path`下,指定'ckpt_prefix'和'ckpt_path'来保存模型文件。
  405. - 配置其他参数,包括数据集配置和网络配置。
  406. - 如果已经有训练好的模型,请将`existed_ckpt`配置为该检查点文件。
  407. 如使用Ascend芯片,执行`run_ascend.sh`这个shell脚本:
  408. ```ascend
  409. sh run_ascend.sh -t t -n 1 -i 1 -c /mass/config/config.json
  410. ```
  411. 如使用GPU处理器,执行`run_gpu.sh`这个shell脚本:
  412. ```gpu
  413. sh run_gpu.sh -t t -n 1 -i 1 -c /mass/config/config.json
  414. ```
  415. 日志和输出文件可以在`./train_mass_*/`路径下获取,模型文件可以在`config/config.json`配置文件中指定的路径下获取。
  416. ## 微调
  417. 预训练模型时,首先配置`config.json`中的选项:
  418. - 将`dataset_config`节点下的`fine_tune_dataset`配置为数据集路径。
  419. - 将`checkpoint_path`节点下的`existed_ckpt`赋值给预训练生成的已有模型文件。
  420. - 选择优化器(可采用'momentum/adam/lamb’)。
  421. - 在`checkpoint_path`下,指定'ckpt_prefix'和'ckpt_path'来保存模型文件。
  422. - 配置其他参数,包括数据集配置和网络配置。
  423. 如使用Ascend芯片,执行`run_ascend.sh`这个shell脚本:
  424. ```ascend
  425. sh run_ascend.sh -t t -n 1 -i 1 -c config/config.json
  426. ```
  427. 如使用GPU处理器,执行`run_gpu.sh`这个shell脚本:
  428. ```gpu
  429. sh run_gpu.sh -t t -n 1 -i 1 -c config/config.json
  430. ```
  431. 日志和输出文件可以在`./train_mass_*/`路径下获取,模型文件可以在`config/config.json`配置文件中指定的路径下获取。
  432. ## 推理
  433. 如果您需要使用此训练模型在GPU、Ascend 910、Ascend 310等多个硬件平台上进行推理,可参考此[链接](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html)。
  434. 推理时,请先配置`config.json`中的选项:
  435. - 将`dataset_config`节点下的`test_dataset`配置为数据集路径。
  436. - 将`dataset_config`节点下的`test_dataset`配置为数据集路径。
  437. - 选择优化器(可采用'momentum/adam/lamb’)。
  438. - 在`checkpoint_path`下,指定'ckpt_prefix'和'ckpt_path'来保存模型文件。
  439. - 配置其他参数,包括数据集配置和网络配置。
  440. 如使用Ascend芯片,执行`run_ascend.sh`这个shell脚本:
  441. ```bash
  442. sh run_ascend.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
  443. ```
  444. 如使用GPU处理器,执行`run_gpu.sh`这个shell脚本:
  445. ```gpu
  446. sh run_gpu.sh -t i -n 1 -i 1 -c config/config.json -o {outputfile}
  447. ```
  448. # 性能
  449. ## 结果
  450. ### 文本摘要微调
  451. 下表展示了,相较于其他两种预训练方法,MASS在文本摘要任务中的ROUGE得分情况。
  452. 训练数据大小为3.8M。
  453. | 方法| RG-1(F) | RG-2(F) | RG-L(F) |
  454. |:---------------|:--------------|:-------------|:-------------|
  455. | MASS | 进行中 | 进行中 | 进行中 |
  456. ### 会话应答微调
  457. 下表展示了,相较于其他两种基线方法,MASS在Cornell电影对白语料库中困惑度(PPL)的得分情况。
  458. | 方法 | 数据 = 10K | 数据 = 110K |
  459. |--------------------|------------------|-----------------|
  460. | MASS | 进行中 | 进行中 |
  461. ### 训练性能
  462. | 参数 | 掩式序列到序列预训练语言生成 |
  463. |:---------------------------|:--------------------------------------------------------------------------|
  464. | 模型版本 | v1 |
  465. | 资源 | Ascend 910;CPU 2.60GHz,192核;内存 755GB;系统 Euler2.8 |
  466. | 上传日期 | 2020-05-24 |
  467. | MindSpore版本 | 0.2.0 |
  468. | 数据集 | News Crawl 2007-2017英语单语语料库、Gigaword语料库、Cornell电影对白语料库 |
  469. | 训练参数 | Epoch=50, steps=XXX, batch_size=192, lr=1e-4 |
  470. | 优化器 | Adam |
  471. | 损失函数 | 标签平滑交叉熵准则 |
  472. | 输出 | 句子及概率 |
  473. | 损失 | 小于2 |
  474. | 准确性 | 会话应答PPL=23.52,文本摘要RG-1=29.79|
  475. | 速度 | 611.45句子/秒 |
  476. | 总时长 | |
  477. | 参数(M) | 44.6M |
  478. ### 推理性能
  479. | 参数 | 掩式序列到序列预训练语言生成 |
  480. |:---------------------------|:-----------------------------------------------------------|
  481. |模型版本| V1 |
  482. | 资源 | Ascend 910;系统 Euler2.8 |
  483. | 上传日期 | 2020-05-24 |
  484. | MindSpore版本 | 0.2.0 |
  485. | 数据集 | Gigaword语料库、Cornell电影对白语料库 |
  486. | batch_size | --- |
  487. | 输出 | 句子及概率 |
  488. | 准确度 | 会话应答PPL=23.52,文本摘要RG-1=29.79|
  489. | 速度 | ----句子/秒 |
  490. | 总时长 | --/-- |
  491. # 随机情况说明
  492. MASS模型涉及随机失活(dropout)操作,如需禁用此功能,请在`config/config.json`中将dropout_rate设置为0。
  493. # 其他
  494. 该模型已在Ascend环境下与GPU环境下得到验证,尚未在CPU环境下验证。
  495. # ModelZoo主页
  496. [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)