You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README_CN.md 48 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. # 目录
  2. <!-- TOC -->
  3. - [目录](#目录)
  4. - [GPT-2模型](#GPT-2模型)
  5. - [模型架构](#模型架构)
  6. - [下游任务](#下游任务)
  7. - [脚本说明](#脚本说明)
  8. - [模型转换](#模型转换)
  9. - [准备数据集](#准备数据集)
  10. - [Language Modeling 语言建模任务](#Language Modeling语言建模任务)
  11. - [Children's Book Test 任务](#Children's Book Test任务)
  12. - [LAMBADA 任务](#LAMBADA任务)
  13. - [Reading Comprehension 任务](#Reading Comprehension任务)
  14. - [Summarization 任务](#Summarization任务)
  15. - [Translation 任务](#Translation任务)
  16. - [配置](#配置)
  17. - [微调&评估过程](#微调&训练评估过程)
  18. - [Language Modeling 任务](#Language Modeling任务)
  19. - 微调
  20. - 评估
  21. - [Children's Book Test 任务](#Children's Book Test任务)
  22. - 评估
  23. - [LAMBADA 任务](#LAMBADA任务)
  24. - 评估
  25. - [Reading Comprehension 任务](#Reading Comprehension任务)
  26. - 评估
  27. - [Summarization 任务](#Summarization任务)
  28. - 评估
  29. - [Translation 任务](#Translation任务)
  30. - 评估
  31. - [环境要求](#环境要求)
  32. - [平台](#平台)
  33. - [其他要求](#其他要求)
  34. - [性能](#性能)
  35. - [推理性能](#推理性能)
  36. - [Language Modeling 任务](#Language Modeling任务)
  37. - [Children's Book Test 任务](#Children's Book Test任务)
  38. - [LAMBADA 任务](#LAMBADA任务)
  39. - [Reading Comprehension 任务](#Reading Comprehension任务)
  40. - [Summarization 任务](#Summarization任务)
  41. - [Translation 任务](#Translation任务)
  42. - [训练性能](#训练性能)
  43. - [推理性能](#推理性能)
  44. - [其他](#其他)
  45. - [ModelZoo主页](#modelzoo主页)
  46. <!-- /TOC -->
  47. # GPT-2模型
  48. [GPT-2介绍](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) 由Open于2019年发布。GPT-2模型是继承于GPT模型,GPT-2是一个非常庞大的语言模型,它主要是用于预测下一个单词。按照参数量的大小,GPT-2模型可分为small(117M)、medium(345M)、large(762M)、xlarge(1542M)。
  49. [GPT-2介绍](https://openai.com/blog/better-language-models/)
  50. [GPT-2论文](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf): Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
  51. # 模型架构
  52. GPT-2模型由Transformer的解码器实现,Transformer包括多个编码器层和多个解码器层,但在GPT-2模型中仅使用了Transformer的解码器部分。
  53. 微调时,根据不同的任务,采用不同的数据集对预训练的模型进行微调。
  54. 测试过程中,通过微调后的模型预测结果,对于某些任务可以直接进行zero-shot评估即可。
  55. # 下游任务
  56. 本文主要涉及6个下游任务,包括:
  57. - Language Modeling 任务
  58. - Children‘s Book Test 任务
  59. - LAMBADA任务
  60. - Reading Comprehension任务
  61. - Summarization任务
  62. - Translation任务
  63. 数据集相关信息,参见[https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)。
  64. ## 脚本说明
  65. GPT-2脚本及代码结构如下:
  66. ```text
  67. ├── GPT-2
  68. ├── README.md // MASS模型介绍
  69. ├── scripts
  70. │ ├──run_cbt.sh // CBT任务的微调&评估脚本
  71. │ ├──run_lambada.sh // LAMBADA任务的微调&评估脚本
  72. │ ├──run_language_model.sh // 语言建模任务的微调&评估脚本
  73. │ ├──run_read_comprehension.sh // 阅读理解任务的微调&评估脚本
  74. │ ├──run_summarization.sh // 摘要生成任务的微调&评估脚本
  75. │ ├──run_translation.sh // 翻译任务的微调&评估脚本
  76. ├──src
  77. │ ├──clip_grad_utils.py // 用于梯度裁剪
  78. | ├──dataset.py // 数据集加载用于微调或推理
  79. │ ├──finetune_eval_config.py // 微调和推理配置文件
  80. │ ├──gpt2_for_finetune.py // 用于梯度裁剪
  81. | ├──GPT2_generation.py // 生成模块
  82. │ ├──GPT2_model.py // GPT2模型脚本
  83. │ ├──GPT2ForCBT.py // CBT任务的模型脚本
  84. │ ├──GPT2ForLanguageModel.py // 语言建模任务的模型脚本
  85. │ ├──GPT2ForReadComprehension.py // 阅读理解任务的模型脚本
  86. │ ├──GPT2ForSummarization.py // 摘要生成任务的模型脚本
  87. │ ├──GPT2ForTranslation.py // 翻译任务的模型脚本
  88. │ ├──weight_init.py // 初始化权重
  89. │ ├──utils
  90. │ ├──bleu_score.py // 用于计算BLEU分数
  91. │ ├──rouge_score.py // 用于计算ROUGE分数
  92. │ ├──CrossEntropy.py // 交叉熵损失
  93. │ ├──data_preprocess.py // 数据集预处理脚本
  94. │ ├──generation_utils.py // 用于帮助生成模型,包含采样等方法
  95. │ ├──get_config_setting.py // 获取配置信息
  96. │ ├──task_utils.py // 辅助下游任务的功能脚本
  97. │ ├──lr_schedule.py // 学习率策略脚本
  98. │ ├──metric_method.py // 下游任务的评价指标
  99. │ ├──tensor_manipulations.py // 涉及张量操作
  100. │ ├──tokenization.py // 标记化,包含BPE编码和解码
  101. │ ├──pretrain-data
  102. │ ├──stopwords.txt // 用于LAMBADA任务的stopword filter
  103. ├──create_cbt_data.py // 用于CBT任务创建mindrecord
  104. ├──create_lambada_data.py // 用于lambada任务创建mindrecord
  105. ├──create_lambada_data.py // 用于其他任务创建mindrecord
  106. ├──create_summary_data.py // 用于summarization任务创建mindrecord
  107. ├──download_cnn_dailymail.py // 下载CNN & Dailymail数据集
  108. ├──cnn_dataset_sampler.py // CNN & Dailymail训练集采样器
  109. ├──eval_rc_addition_answer.py // 使用addition_answer评估阅读理解任务
  110. ├──run_CBT_task.py // CBT任务微调&推理API入口
  111. ├──run_lambada.py // LAMBADA任务微调&推理API入口
  112. ├──run_language_mdoel.py // 语言建模任务微调&推理API入口
  113. ├──run_ReadComprehension.py // 阅读理解任务微调&推理API入口
  114. ├──run_summarization.py // 摘要生成任务微调&推理API入口
  115. ├──run_translation.py // 翻译任务微调&推理API入口
  116. ├──task_dataset_preprocess.py // 各个任务的数据集处理入口
  117. ├──convert_tf_ckpt
  118. │ ├──read_weight_tf.py // 读取tensorflow下的预训练模型
  119. │ ├──trans_dict.py // 模型参数名称字典
  120. │ ├──save_weight_ms.py // 生成mindspore ckpt
  121. ├──third_party
  122. │ ├──gpt2-merges.txt
  123. │ ├──gpt2-vocab.json // GPT-2预训练词表
  124. │ ├──bleu.py // 辅助bleu值计算的第三方代码
  125. ```
  126. ## 模型转换
  127. - 下载GPT-2的预训练模型 [GPT-2预训练模型下载](https://github.com/openai/gpt-2/blob/master/download_model.py)
  128. - 在tensorflow的环境下,运行`read_weight_tf.py`,示例代码如下:
  129. `python read_weight_tf.py --ckpt_file_path=/{path}/model.ckpt`
  130. - 在mindspore的环境下,运行`save_weight_ms.py`,示例代码如下:
  131. `python save_weight_ms.py --output_file_name="mindspore_gpt2_small.ckpt"`
  132. ## 准备数据集
  133. ### Language Modeling语言建模任务
  134. #### WikiText2 、WikiText103、PTB、1BW 数据集
  135. - [WikiText2数据集下载](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip) 解压后使用`wikitext-2 /wiki.test.tokens`作为测试集
  136. - [WikiText103数据集下载](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) 解压后使用`wikitext-103 /wiki.test.tokens`作为测试集
  137. - [PTB数据集下载](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz) 解压后使用 `/simple-examples/data/ptb.test.txt` 测试集,使用 `/simple-examples/data/ptb.test.txt` 作为训练集
  138. - [1BW数据集下载](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz) 解压后使用`1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050`作为测试集,使用`1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00001-of-00100`作为原始训练集,进行随机采样后得到30000条训练集样本
  139. 使用`task_dataset_preprocess.py`可以对以上数据集进行清洗。
  140. `task_dataset_preprocess.py`的主要参数如下:
  141. ```bash
  142. --task: The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada, Summarization, ReadingComprehension].
  143. --input_file: The raw dataset path.
  144. --dataset: The name of dataset which should be processed, only for LanguageModeling task.
  145. --output_file: The output dataset path after preprocessing.
  146. --condition: Process train or test dataset, including [train, test], only for 1BW and CNN & DailyMail dataset.
  147. ```
  148. 示例代码如下:
  149. 清洗PTB训练集和测试集
  150. ```bash
  151. python task_dataset_preprocess.py --task "LanguageModeling" --input_file /{path}/ptb.test.txt --dataset "ptb" --output_file /{path}/ptb_clean_test.txt --condition "test"
  152. ```
  153. 使用`create_lm_data.py`可以将以上数据集格式转换为mindrecord
  154. `create_lm_data.py`的主要参数如下:
  155. ```bash
  156. --input_file: Input raw text file.
  157. --output_file: Output MindRecord file.
  158. --num_splits: The MindRecord file will be split into the number of partition.
  159. --max_seq_length: Maximum sequence length.
  160. --vocab_file: url of gpt2-vocab.json.
  161. --merge_file: url of gpt2-merges.txt
  162. ```
  163. 示例代码如下:
  164. ```bash
  165. python create_lm_data.py --input_file /{path}/ptb.test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
  166. ```
  167. ### Children's Book Test任务
  168. #### CBT-CN / CBT-NE 数据集
  169. - [CBT数据集下载](http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz) 使用在`/data`目录下使用`cbtest_CN_valid_2000ex.txt、cbtest_NE_valid_2000ex.txt`作为该任务的评估集,清洗该数据集,示例代码如下:
  170. ```bash
  171. python task_dataset_preprocess.py --task "CBT" --input_file /{path}/cbtest_CN_valid_2000ex.txt --dataset "cbt" --output_file /{path}/cbt_cn_valid.txt
  172. ```
  173. 使用`create_cbt_data.py`可以将以上数据集格式转换为mindrecord
  174. `create_cbt_data.py`的主要参数如下:
  175. ```bash
  176. --input_file: Input raw text file.
  177. --output_file: Output MindRecord file.
  178. --num_splits: The MindRecord file will be split into the number of partition.
  179. --max_seq_length: Maximum sequence length.
  180. --num_choice: Number of choices.
  181. --vocab_file: url of gpt2-vocab.json.
  182. --merge_file: url of gpt2-merges.txt
  183. ```
  184. 示例代码如下:
  185. ```bash
  186. python create_cbt_data.py --input_file /{path}/ptb.test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_length 1024 --num_choice 10 --vocab_file={path} --merge_file={path}
  187. ```
  188. ### LAMBADA任务
  189. #### LAMBADA 数据集
  190. - [LAMBADA数据集下载](https://zenodo.org/record/2630551#.X-yCSTTithH) 使用`lambada_test_plain_text.txt`作为该任务的评估集,清洗该数据集,示例代码如下:
  191. ```bash
  192. python task_dataset_preprocess.py --task "LAMBADA" --input_file /{path}/lambada_test_plain_text.txt --dataset "LAMBADA" --output_file /{path}/lambada_test_clean.txt
  193. ```
  194. 使用`create_lambada_data.py`可以将以上数据集格式转换为mindrecord
  195. `create_lambada_data.py`的主要参数如下:
  196. ```bash
  197. --input_file: Input raw text file.
  198. --output_file: Output MindRecord file.
  199. --num_splits: The MindRecord file will be split into the number of partition.
  200. --max_seq_length: Maximum sequence length.
  201. --vocab_file: url of gpt2-vocab.json.
  202. --merge_file: url of gpt2-merges.txt
  203. ```
  204. 示例代码如下:
  205. ```bash
  206. python create_lambada_data.py --input_file /{path}/lambada_test_clean.txt --output_file /{path}/lambada-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
  207. ```
  208. ### Reading Comprehension 任务
  209. #### CoQA数据集
  210. - [CoQA数据集下载](http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json) 使用`coqa-dev-v1.0.json`作为该任务的评估集,清洗该数据集,示例代码如下:
  211. ```bash
  212. python task_dataset_preprocess.py --task "ReadingComprehension" --input_file /{path}/coqa-dev-v1.0.json --dataset "coqa" --output_file /{path}/coqa_dev.txt
  213. ```
  214. 使用`create_lm_data.py`可以将以上数据集格式转换为mindrecord
  215. 示例代码如下:
  216. ```bash
  217. python create_lm_data.py --input_file /{path}/coqa_dev.txt --output_file /{path}/coqa-dev-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
  218. ```
  219. ### Summarization 任务
  220. #### CNN & Dailymail数据集
  221. - 下载该数据集,使用`download_cnn_dailymail.py`脚本进行下载,示例代码如下:
  222. ```bash
  223. 下载测试集
  224. python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split test
  225. 下载训练集
  226. python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split train
  227. ```
  228. 从训练集中随机采用10000条样本作为最终的微调的训练集,使用`cnn_dataset_sampler.py`脚本进行训练的采样操作,生成新的训练集,示例代码如下:
  229. ```bash
  230. GPT-2 small和GPT-2 medium模型的训练集中seq_length=1024, 因此该脚本中设置max_length=1022
  231. python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt"
  232. --output_path="/{path}/cnn_train_hint_small.txt"
  233. --replace_hint="true"
  234. --sample="true"
  235. --max_length=1022
  236. --prob=0.25
  237. --max_items=10000
  238. --hint="TL;DR:"
  239. GPT-2 large模型的训练集中seq_length=768,因此该脚本中设置max_length=766
  240. python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt"
  241. --output_path="/{path}/cnn_train_hint_large.txt"
  242. --replace_hint="true"
  243. --sample="true"
  244. --max_length=766
  245. --prob=0.25
  246. --max_items=10000
  247. --hint="TL;DR:"
  248. ```
  249. 使用`create_summary_data.py`可以将以上数据集格式转换为mindrecord
  250. 示例代码如下:
  251. ```bash
  252. python create_summary_data.py --input_file /{path}/cnn_dailymail_test.txt --output_file /{path}/cnn_dailymail-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path} --mode 'cnn_dailymail'
  253. ```
  254. ### Translation 任务
  255. #### WMT14 En-Fr数据集
  256. - [WMT14 En-Fr数据集下载](http://statmt.org/wmt14/test-full.tgz) 使用`newstest2014-fren-ref.en.sgm`和`newstest2014-fren-ref.fr.sgm`作为该任务的评估集,合并且清洗该数据集,示例代码如下:
  257. ```bash
  258. python task_dataset_preprocess.py --task "Translation" --input_file /{path}/test-full --dataset "wmt14" --output_file /{path}/wmt14
  259. ```
  260. 在`output_file`路径下会生成两个文件`wmt14.en_fr.txt`和`wmt14.fr_en.txt`,分别用于评估`En-Fr`和`Fr-En`。
  261. 使用`create_lm_data.py`可以将以上数据集格式转换为mindrecord
  262. 示例代码如下:
  263. ```bash
  264. python create_lm_data.py --input_file /{path}/wmt14.en_fr.txt --output_file /{path}/en-fr-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
  265. python create_lm_data.py --input_file /{path}/wmt14.fr_en.txt --output_file /{path}/fr-en-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
  266. ```
  267. ## 配置
  268. `src/finetune_eval_config.py`为GPT-2模型训练和推理的配置文件,便于为大多数选项及参数赋值,包括GPT-2 模型规模、模型的配置、优化器参数等。
  269. 有关属性的详细信息,参见`src/finetune_eval_config.py`文件。
  270. ## 微调&评估过程
  271. ### Language Modeling 语言建模任务
  272. #### 微调
  273. - PTB数据集
  274. GPT-2 small / GPT-2 medium / GPT-2 large模型需要在PTB训练集上进行微调。微调模型时,只需要使用shell脚本`scripts/run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`scripts/run_language_model.sh`脚本。
  275. 微调模型时,首先配置`src/finetune_eval_config.py`中的选项:
  276. 将`cfg`下的`gpt2_network`设置为相应的GPT-2模型大小`[small/medium/large]`。
  277. 将`cfg`下的`optimizer`设置为`Lamb`,进行优化器的选择(可采用'momentum/adam/lamb’)。
  278. 选定了GPT-2模型后需要设置模型的参数,包括`batch_size`和`seq_length`。
  279. 而后执行`scripts/run_language_model.sh`这个shell脚本:
  280. ```bash
  281. sh scripts/run_language_model.sh --device_target="Ascend"
  282. --do_train="true"
  283. --do_eval="false"
  284. --epoch_num=1
  285. --train_data_shuffle="true"
  286. --eval_data_shuffle="false"
  287. --save_finetune_ckpt_path={save_finetune_ckpt_path}
  288. --load_pretrain_ckpt_path={load_pretrain_ckpt_path}
  289. --train_data_file_path={train_data_file_path}
  290. ```
  291. 日志和输出文件可以在`./ms_log/`路径下获取。
  292. ```bash
  293. sh scripts/run_language_model.sh [--options]
  294. ```
  295. `run_language_model.sh`的用法如下:
  296. ```text
  297. usage: run_language_model.sh [--device_target DEVICE_TARGET] [--device_id N]
  298. [--metric_method METRIC_METHOD]
  299. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  300. [--eval_type EVAL_TYPE] [--epoch_num N]
  301. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  302. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  303. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  304. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  305. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  306. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  307. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  308. options:
  309. --device_target Device type. Default: "Ascend"
  310. --device_id ID of target device
  311. --metric_method The eval method including [PPL]. Default: "PPL"
  312. --do_train Enable train. Default: "false"
  313. --do_eval Enable evaluation. Default: "true"
  314. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
  315. --epoch_num Epoch number. Default: 1
  316. --train_data_shuffle Enable train data shuffle. Default: "true"
  317. --eval_data_shuffle Enable eval data shuffle. Default: "false"
  318. --save_finetune_ckpt_path Save the finetuned checkpoint path
  319. --load_pretrain_ckpt_path Load the checkpoint file path for train
  320. --load_finetune_ckpt_path Load the checkpoint file path for evaluation
  321. --train_data_file_path Data path, it is better to use absolute path
  322. --eval_data_file_path Data path, it is better to use absolute path
  323. ```
  324. - 1BW数据集
  325. GPT-2 large模型需要在1BW训练集上进行微调。微调模型时,只需要使用shell脚本`run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_language_model.py`脚本。该微调方法与PTB数据集的一致。
  326. #### 评估
  327. GPT-2模型可以在`WikiText2/WikiText103/PTB/1BW`测试集上进行对应的评估,针对以上数据集的评估,其评估方法采用PPL,即设置`--metric_method="PPL"`。
  328. 评估模型时,只需要使用shell脚本`run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_language_model.py`脚本。
  329. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_language_model.sh`这个shell脚本,若该模型在某个数据集上被微调了,则使用该模型进行对应测试集的评估时需要设置`--eval_type="finetuned"`,否则设置`eval_type="zero-shot"`,除此之外`--load_finetune_ckpt_path`是微调好后的checkpoint文件位置
  330. ```bash
  331. sh scripts/run_language_model.sh --device_target="Ascend"
  332. --metric_method="PPL"
  333. --do_train="false"
  334. --do_eval="true"
  335. --eval_type="finetuned"
  336. --train_data_shuffle="true"
  337. --eval_data_shuffle="false"
  338. --load_finetune_ckpt_path={load_eval_ckpt_path}
  339. --eval_data_file_path={eval_data_file_path}
  340. ```
  341. 日志和输出文件可以在`./ms_log/`路径下获取。
  342. ### Children's Book Test任务
  343. #### 评估
  344. GPT-2模型可以在`CBT-CN/CBT-NE`验证集上进行对应的评估,针对以上数据集的评估,其评估方法采用Accuracy,即设置`--metric_method="Accuracy"`。
  345. 评估模型时,只需要使用shell脚本`run_cbt.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_CBT_task.py`脚本。
  346. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_cbt.sh`这个shell脚本,且设置`eval_type="zero-shot"`,除此之外`--load_finetune_ckpt_path`是只需加载预训练好的checkpoint文件
  347. ```bash
  348. sh scripts/run_cbt.sh --device_target="Ascend"
  349. --num_choice=10
  350. --metric_method="Accuarcy"
  351. --do_train="false"
  352. --do_eval="true"
  353. --eval_type="zero-shot"
  354. --train_data_shuffle="true"
  355. --eval_data_shuffle="false"
  356. --load_finetune_ckpt_path={load_eval_ckpt_path}
  357. --eval_data_file_path={eval_data_file_path}
  358. ```
  359. 日志和输出文件可以在`./ms_log/`路径下获取。
  360. ```bash
  361. sh scripts/run_cbt.sh [--options]
  362. ```
  363. `run_cbt.sh`的用法如下:
  364. ```text
  365. usage: run_CBT_task.sh [--device_target DEVICE_TARGET] [--device_id N][--num_choice N]
  366. [--metric_method METRIC_METHOD]
  367. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  368. [--eval_type EVAL_TYPE] [--epoch_num N]
  369. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  370. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  371. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  372. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  373. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  374. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  375. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  376. options:
  377. --device_target Device type. Default: "Ascend"
  378. --device_id ID of target device
  379. --num_choice The number of choice in CBT task
  380. --metric_method The eval method including [Accuracy]. Default: "Accuracy"
  381. --do_train Enable train. Default: "false"
  382. --do_eval Enable evaluation. Default: "true"
  383. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
  384. --epoch_num Epoch number. Default: 1
  385. --train_data_shuffle Enable train data shuffle. Default: "true"
  386. --eval_data_shuffle Enable eval data shuffle. Default: "false"
  387. --save_finetune_ckpt_path Save the finetuned checkpoint path
  388. --load_pretrain_ckpt_path Load the checkpoint file path for train
  389. --load_finetune_ckpt_path Load the checkpoint file path for evaluation
  390. --train_data_file_path Data path, it is better to use absolute path
  391. --eval_data_file_path Data path, it is better to use absolute path
  392. ```
  393. ### LAMBADA任务
  394. #### 评估
  395. GPT-2模型可以在`LAMBADA`测试集上进行对应的评估,针对以上数据集的评估,其评估方法采用Accuracy和PPL,即设置`--metric_method="Accuracy"` 或者`--metric_method="PPL"`。
  396. 评估模型时,只需要使用shell脚本`run_lambada.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_lambada.py`脚本。
  397. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_lambada.sh`这个shell脚本,且设置`eval_type="zero-shot"`,除此之外`--load_finetune_ckpt_path`是只需加载预训练好的checkpoint文件
  398. 评估Accuracy
  399. ```bash
  400. sh scripts/run_lambada.sh --device_target="Ascend"
  401. --metric_method="Accuarcy"
  402. --do_train="false"
  403. --do_eval="true"
  404. --eval_type="zero-shot"
  405. --train_data_shuffle="true"
  406. --eval_data_shuffle="false"
  407. --generate_length_dynamically="true"
  408. --load_finetune_ckpt_path={load_eval_ckpt_path}
  409. --eval_data_file_path={eval_data_file_path}
  410. --tokenizer_file_path={tokenizer_file_path}
  411. --stop_word_file_path={stop_word_file_path}
  412. ```
  413. 评估PPL
  414. ```bash
  415. sh scripts/run_lambada.sh --device_target="Ascend"
  416. --metric_method="PPL"
  417. --do_train="false"
  418. --do_eval="true"
  419. --eval_type="zero-shot"
  420. --train_data_shuffle="true"
  421. --eval_data_shuffle="false"
  422. --load_finetune_ckpt_path={load_eval_ckpt_path}
  423. --eval_data_file_path={eval_data_file_path}
  424. ```
  425. 日志和输出文件可以在`./ms_log/`路径下获取。
  426. ```bash
  427. sh scripts/run_lambada.sh [--options]
  428. ```
  429. ```text
  430. usage: run_lambada.sh [--device_target DEVICE_TARGET] [--device_id N]
  431. [--metric_method METRIC_METHOD]
  432. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  433. [--eval_type EVAL_TYPE] [--epoch_num N]
  434. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  435. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  436. [--generate_length_dynamically GENERATE_LENGTH_DYNAMICALLY]
  437. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  438. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  439. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  440. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  441. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  442. [--tokenizer_file_path TOKENIZER_FILE_PATH]
  443. [--stop_word_file_path STOP_WORD_FILE_PATH]
  444. options:
  445. --device_target Device type. Default: "Ascend"
  446. --device_id ID of target device
  447. --metric_method The eval method including [Accuracy, PPL]. Default: "Accuracy"
  448. --do_train Enable train. Default: "false"
  449. --do_eval Enable evaluation. Default: "true"
  450. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
  451. --epoch_num Epoch number. Default: 1
  452. --train_data_shuffle Enable train data shuffle. Default: "true"
  453. --eval_data_shuffle Enable eval data shuffle. Default: "false"
  454. --generate_length_dynamically Enable generate_length_Dynamically. Default: "true"
  455. --save_finetune_ckpt_path Save the checkpoint path
  456. --load_pretrain_ckpt_path Load the checkpoint file path
  457. --load_finetune_ckpt_path Load the checkpoint file path
  458. --train_data_file_path Data path, it is better to use absolute path
  459. --eval_data_file_path Data path, it is better to use absolute path
  460. --tokenizer_file_path pretrained vocab and merge file path
  461. --stop_word_file_path The stop word file path
  462. ```
  463. ### Reading Comprehension任务
  464. #### 评估
  465. GPT-2模型可以在`CoQA`开发集上进行对应的评估,针对以上数据集的评估,其评估方法采用F1,即设置`--metric_method="F1"` 。
  466. 评估模型时,只需要使用shell脚本`run_read_comprehension.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_read_comprehension.py`脚本。
  467. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_read_comprehension.sh`这个shell脚本,且设置`eval_type="zero-shot"`,除此之外`--load_finetune_ckpt_path`是只需加载预训练好的checkpoint文件
  468. ```bash
  469. sh scripts/run_read_comprehension.sh --device_target="Ascend"
  470. --metric_method="F1"
  471. --do_train="false"
  472. --do_eval="true"
  473. --eval_type="zero-shot"
  474. --train_data_shuffle="true"
  475. --eval_data_shuffle="false"
  476. --load_finetune_ckpt_path={load_eval_ckpt_path}
  477. --eval_data_file_path={eval_data_file_path}
  478. --tokenizer_file_path={tokenizer_file_path}
  479. --generate_length=55
  480. --top_k=1
  481. --top_p="1.0"
  482. --temperature="1.0"
  483. ```
  484. 日志和输出文件可以在`./ms_log/`路径下获取。而后将得到的日志文件作为`eval_rc_addition_answer.py`脚本的`input_file`,同时将原CoQA开发集`coqa-dev-v1.0.json`作为`addition_file`。
  485. 执行`python eval_rc_addition_answer.py --input_file={path} --addition_file={path}`得到最终的F1值。
  486. ```bash
  487. sh scripts/run_read_comprehension.sh [--options]
  488. ```
  489. ```text
  490. usage: run_read_comprehension.sh [--device_target DEVICE_TARGET] [--device_id N]
  491. [--metric_method METRIC_METHOD]
  492. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  493. [--eval_type EVAL_TYPE] [--epoch_num N]
  494. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  495. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  496. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  497. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  498. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  499. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  500. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  501. [--tokenizer_file_path TOKENIZER_FILE_PATH]
  502. [--generate_length N] [--top_k N] [--top_p TOP_P]
  503. [--temperature TEMPERATURE]
  504. options:
  505. --device_target Device type. Default: "Ascend"
  506. --device_id ID of target device
  507. --metric_method The eval method including [F1]. Default: "F1"
  508. --do_train Enable train. Default: "false"
  509. --do_eval Enable evaluation. Default: "false"
  510. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
  511. --epoch_num Epoch number. Default: 1
  512. --train_data_shuffle Enable train data shuffle. Default: "true"
  513. --eval_data_shuffle Enable eval data shuffle. Default: "false"
  514. --save_finetune_ckpt_path Save the checkpoint path
  515. --load_pretrain_ckpt_path Load the checkpoint file path
  516. --load_finetune_ckpt_path Load the checkpoint file path
  517. --train_data_file_path Data path, it is better to use absolute path
  518. --eval_data_file_path Data path, it is better to use absolute path
  519. --tokenizer_file_path pretrained vocab and merge file path
  520. --generate_length The generation length of answer sentence
  521. --top_k Parameter for Top-K sampling
  522. --top_p Parameter for Top-P sampling
  523. --temperature Parameter for generation, greater if generation more diverse
  524. ```
  525. ### Summarization任务
  526. #### 评估
  527. GPT-2模型可以在`CNN_Dailymail`开发集上进行对应的评估,针对以上数据集的评估,其评估方法采用F1,即设置`--metric_method="ROUGE"` 。
  528. 评估模型时,只需要使用shell脚本`run_summarization.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_summarization.py`脚本。
  529. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_summarization.sh`这个shell脚本,且对于`hint`的情况设置`eval_type="finetuned"`,`--load_finetune_ckpt_path`是需要加载微调好的checkpoint文件;而对于`no hint`的情况设置`eval_type="zero-shot"`除此之外`--load_finetune_ckpt_path`是只需加载预训练好的checkpoint文件
  530. ```bash
  531. sh scripts/run_summarization.sh --device_target="Ascend"
  532. --do_train="false"
  533. --do_eval="true"
  534. --metric_method="Rouge"
  535. --train_data_shuffle="true"
  536. --eval_data_shuffle="false"
  537. --generate_length=100
  538. --top_k=2
  539. --top_p="1.0"
  540. --temperature="1.0"
  541. --eval_type="finetuned"
  542. --load_finetune_ckpt_path={load_eval_ckpt_path}
  543. --eval_data_file_path={eval_data_file_path}
  544. --tokenizer_file_path={tokenizer_file_path}
  545. ```
  546. 日志和输出文件可以在`./ms_log/`路径下获取。
  547. ```bash
  548. sh scripts/run_summarization.sh [--options]
  549. ```
  550. `run_summarization.sh`的用法如下:
  551. ```text
  552. usage: run_summarization.sh [--device_target DEVICE_TARGET] [--device_id N][--num_choice N]
  553. [--metric_method METRIC_METHOD]
  554. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  555. [--eval_type EVAL_TYPE] [--epoch_num N]
  556. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  557. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  558. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  559. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  560. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  561. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  562. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  563. options:
  564. --device_target Device type. Default: "Ascend"
  565. --device_id ID of target device
  566. --do_train Enable train. Default: false.
  567. --do_eval Enable evaluation. Default: false.
  568. --metric_method The eval method including [Rouge(Rouge1,Rouge2,RougeL,Rouge Avg)]. Default: Rouge. Default: "false"
  569. --epoch_num Epoch number. Default: 2.
  570. --train_data_shuffle Enable train data shuffle. Default: true.
  571. --eval_data_shuffle Enable eval data shuffle. Default: false.
  572. --save_finetune_ckpt_path Save the checkpoint path.
  573. --load_pretrain_ckpt_path Load the checkpoint file path.
  574. --load_finetune_ckpt_path Load the checkpoint file path.
  575. --train_data_file_path Data path, it is better to use absolute path.
  576. --eval_data_file_path Data path, it is better to use absolute path.
  577. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.
  578. --top_k Top k tokens chosen for sampling.
  579. --top_p Top p accumulated probability threshold for logit to be counted.
  580. --generate_length The number of generated tokens.
  581. --temperature Temperature on logits for sampling.
  582. --tokenizer_file_path Vocab & merge file path.
  583. ```
  584. ### Translation任务
  585. #### 评估
  586. GPT-2模型可以在`WMT14 En-Fr`和`WMT14 Fr-En`测试集上进行对应的评估,针对以上数据集的评估,其评估方法采用BLEU,即设置`--metric_method="BLEU"` 。
  587. 注:读者需要自行下载`bleu.py`脚本[脚本链接](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py), 而后将该脚本放置于`src/utils/`目录下
  588. 评估模型时,只需要使用shell脚本`run_translation.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_translation.py`脚本。
  589. 评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_translation.sh`这个shell脚本,且设置`eval_type="zero-shot"`,除此之外`--load_finetune_ckpt_path`是只需加载预训练好的checkpoint文件
  590. ```bash
  591. sh scripts/run_translation.sh --device_target="Ascend"
  592. --metric_method="BLEU"
  593. --do_train="false"
  594. --do_eval="true"
  595. --eval_type="zero-shot"
  596. --train_data_shuffle="true"
  597. --eval_data_shuffle="false"
  598. --load_finetune_ckpt_path={load_eval_ckpt_path}
  599. --eval_data_file_path={eval_data_file_path}
  600. --tokenizer_file_path={tokenizer_file_path}
  601. --generate_length=100
  602. --top_k=1
  603. --top_p="1.0"
  604. --temperature="1.0"
  605. ```
  606. ```bash
  607. sh scripts/run_translation.sh [--options]
  608. ```
  609. ```text
  610. usage: run_translation.sh [--device_target DEVICE_TARGET] [--device_id N]
  611. [--metric_method METRIC_METHOD]
  612. [--do_train DO_TRAIN] [--do_eval DO_EVAL]
  613. [--eval_type EVAL_TYPE] [--epoch_num N]
  614. [--train_data_shuffle TRAIN_DATA_SHUFFLE]
  615. [--eval_data_shuffle EVAL_DATA_SHUFFLE]
  616. [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
  617. [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
  618. [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
  619. [--train_data_file_path TRAIN_DATA_FILE_PATH]
  620. [--eval_data_file_path EVAL_DATA_FILE_PATH]
  621. [--tokenizer_file_path TOKENIZER_FILE_PATH]
  622. [--generate_length N] [--top_k N] [--top_p TOP_P]
  623. [--temperature TEMPERATURE]
  624. options:
  625. --device_target Device type. Default: "Ascend"
  626. --device_id ID of target device
  627. --metric_method The eval method including [BLEU]. Default: "BLEU"
  628. --do_train Enable train. Default: "false"
  629. --do_eval Enable evaluation. Default: "true"
  630. --eval_type The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
  631. --epoch_num Epoch number. Default: 1
  632. --train_data_shuffle Enable train data shuffle. Default: "true"
  633. --eval_data_shuffle Enable eval data shuffle. Default: "false"
  634. --save_finetune_ckpt_path Save the checkpoint path
  635. --load_pretrain_ckpt_path Load the checkpoint file path
  636. --load_finetune_ckpt_path Load the checkpoint file path
  637. --train_data_file_path Data path, it is better to use absolute path
  638. --eval_data_file_path Data path, it is better to use absolute path
  639. --tokenizer_file_path pretrained vocab and merge file path
  640. --generate_length The generation length of translation sentence
  641. --top_k Parameter for Top-K sampling
  642. --top_p Parameter for Top-P sampling
  643. --temperature Parameter for generation, greater if generation more diverse
  644. ```
  645. # 环境要求
  646. ## 平台
  647. - 硬件(Ascend)
  648. - 使用Ascend处理器准备硬件环境。
  649. - 框架
  650. - [MindSpore](https://www.mindspore.cn/install)
  651. - 更多关于Mindspore的信息,请查看以下资源:
  652. - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
  653. - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
  654. ## 其他要求
  655. ```text
  656. math
  657. numpy
  658. copy
  659. collections
  660. re
  661. rouge 1.0.0
  662. datasets >=0.4.0
  663. json
  664. tensorflow
  665. ```
  666. # 性能
  667. ## 推理性能
  668. ### Language Modeling任务
  669. 下表展示了GPT-2 small、medium、large三种规模的模型在Language Modeling任务中的PPL得分情况。
  670. | 模型 | dataset | device | eval_type | PPL | OpenAI |
  671. | :--- | :------ | :------ | :------ | :------ | :------ |
  672. | GPT-2 small | WikiText2 | Ascend | zero-shot | 24.5 | 29.41 |
  673. | GPT-2 medium | WikiText2 | Ascend | zero-shot | 19.41 | 22.76 |
  674. | GPT-2 large | WikiText2 | Ascend | zero-shot | 17.08 | 19.93 |
  675. | GPT-2 small | WikiText103 | Ascend | zero-shot | 26.89 | 37.5 |
  676. | GPT-2 medium | WikiText103 | Ascend | zero-shot | 20.23 | 26.37 |
  677. | GPT-2 large | WikiText103 | Ascend | zero-shot | 17.48 | 22.05 |
  678. | GPT-2 small | PTB | Ascend | finetune | 23.91 | 65.85 |
  679. | GPT-2 medium | PTB | Ascend | finetune | 20.06 | 47.33 |
  680. | GPT-2 large | PTB | Ascend | finetune | 18.84 | 40.31 |
  681. | GPT-2 small | 1BW | Ascend | zero-shot | 63.13 | 75.2 |
  682. | GPT-2 medium | 1BW | Ascend | zero-shot | 50.98 | 55.72 |
  683. | GPT-2 large | 1BW | Ascend | finetune | 29.28 | 44.575 |
  684. ### Children's Book Test 任务
  685. 下表展示了GPT-2 small、medium、large三种规模的模型在Children's Book Test 任务中的Accuracy得分情况。
  686. | 模型 | dataset | device | eval_type | ACC | OpenAI |
  687. | :--- | :------ | :------ | :------ | :------ | :------ |
  688. | GPT-2 small | CBT-CN valid | Ascend | zero-shot | 87.85 | 87.65 |
  689. | GPT-2 medium | CBT-CN valid | Ascend | zero-shot | 92.1 | 92.35 |
  690. | GPT-2 large | CBT-CN valid | Ascend | zero-shot | 93.7 | 93.45 |
  691. | GPT-2 small | CBT-NE valid | Ascend | zero-shot | 85.1 | 83.4 |
  692. | GPT-2 medium | CBT-NE valid | Ascend | zero-shot | 87.55 | 87.1 |
  693. | GPT-2 large | CBT-NE valid | Ascend | zero-shot | 89.1 | 88 |
  694. ### LAMBADA 任务
  695. 下表展示了GPT-2 small、medium、large三种规模的模型在LAMBADA 任务中的Accuracy和PPL得分情况。
  696. | 模型 | dataset | device | eval_type | ACC | OpenAI |
  697. | :--- | :------ | :------ | :------ | :------ | :------ |
  698. | GPT-2 small | Lambada-test | Ascend | zero-shot | 45.99 | 45.99 |
  699. | GPT-2 medium | Lambada-test | Ascend | zero-shot | 58.59 | 55.48 |
  700. | GPT-2 large | Lambada-test | Ascend | zero-shot | 62.74 | 60.12 |
  701. | 模型 | dataset | device | eval_type | PPL | OpenAI |
  702. | :--- | :------ | :------ | :------ | :------ | :------ |
  703. | GPT-2 small | Lambada-test | Ascend | zero-shot | 22.95 | 35.13 |
  704. | GPT-2 medium | Lambada-test | Ascend | zero-shot | 10.69 | 15.6 |
  705. | GPT-2 large | Lambada-test | Ascend | zero-shot | 8.64 | 10.87 |
  706. ### Reading Comprehension 任务
  707. 下表展示了GPT-2 small、medium、large三种规模的模型在Reading Comprehension任务中的F1得分情况。
  708. | 模型 | dataset | device | eval_type | F1 | OpenAI |
  709. | :--- | :------ | :------ | :------ | :------ | :------ |
  710. | GPT-2 small | CoQA | Ascend | zero-shot | 25.94 | 25~26 |
  711. | GPT-2 medium | CoQA | Ascend | zero-shot | 43.69 | 42~43 |
  712. | GPT-2 large | CoQA | Ascend | zero-shot | 49.39 | 49~51 |
  713. ### Summarization 任务
  714. 下表展示了GPT-2 small、medium、large三种规模的模型在Summarization任务中的ROUGE得分情况。
  715. | 模型 | dataset | device | eval_type | ROUGE | OpenAI |
  716. | :--- | :------ | :------ | :------ | :------ | :------ |
  717. | GPT-2 small | CNN_Dailymail(TL;DR) | Ascend | finetune | 21.4 | 16.8~17 |
  718. | GPT-2 medium | CNN_Dailymail(TL;DR) | Ascend | finetune | 25.94 | 20.6~20.9 |
  719. | GPT-2 large | CNN_Dailymail(TL;DR) | Ascend | finetune | 26.73 | 21.5~21.6 |
  720. | 模型 | dataset | device | eval_type | ROUGE | OpenAI |
  721. | :--- | :------ | :------ | :------ | :------ | :------ |
  722. | GPT-2 small | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.08 | 15.03(xlarge) |
  723. | GPT-2 medium | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.16 | 15.03(xlarge) |
  724. | GPT-2 large | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.29 | 15.03(xlarge) |
  725. ### Translation 任务
  726. 下表展示了GPT-2 small、medium、large三种规模的模型在Translation任务中的BLEU得分情况。
  727. | 模型 | dataset | device | eval_type | BLEU | OpenAI |
  728. | :--- | :------ | :------ | :------ | :------ | :------ |
  729. | GPT-2 small | WMT-14 Fr-En | Ascend | zero-shot | 4.49 | 0.7~0.8 |
  730. | GPT-2 medium | WMT-14 Fr-En | Ascend | zero-shot | 7.09 | 2.0~3.0 |
  731. | GPT-2 large | WMT-14 Fr-En | Ascend | zero-shot | 7.97 | 6.5~7.0 |
  732. | GPT-2 small | WMT-14 En-Fr | Ascend | zero-shot | 2.81 | 5(xlarge) |
  733. | GPT-2 medium | WMT-14 En-Fr | Ascend | zero-shot | 3.2 | 5(xlarge) |
  734. | GPT-2 large | WMT-14 En-Fr | Ascend | zero-shot | 3.06 | 5(xlarge) |
  735. # 其他
  736. 该模型已在Ascend环境下环境下得到验证。
  737. # ModelZoo主页
  738. [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)