@@ -0,0 +1,931 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [GPT-2 Model](#gpt-2-model)
- [Model Architecture](#model-architecture)
- [Downstream Tasks](#downstream-tasks)
    - [Script Description](#script-description)
    - [Model Conversion](#model-conversion)
    - [Dataset Preparation](#dataset-preparation)
        - [Language Modeling Task](#language-modeling-task)
        - [Children's Book Test Task](#childrens-book-test-task)
        - [LAMBADA Task](#lambada-task)
        - [Reading Comprehension Task](#reading-comprehension-task)
        - [Summarization Task](#summarization-task)
        - [Translation Task](#translation-task)
    - [Configuration](#configuration)
    - [Fine-Tuning & Evaluation](#fine-tuning--evaluation)
        - [Language Modeling Task](#language-modeling-task-1)
            - Fine-tuning
            - Evaluation
        - [Children's Book Test Task](#childrens-book-test-task-1)
            - Evaluation
        - [LAMBADA Task](#lambada-task-1)
            - Evaluation
        - [Reading Comprehension Task](#reading-comprehension-task-1)
            - Evaluation
        - [Summarization Task](#summarization-task-1)
            - Evaluation
        - [Translation Task](#translation-task-1)
            - Evaluation
- [Environment Requirements](#environment-requirements)
    - [Platform](#platform)
    - [Other Requirements](#other-requirements)
- [Performance](#performance)
    - [Inference Performance](#inference-performance)
        - [Language Modeling Task](#language-modeling-task-2)
        - [Children's Book Test Task](#childrens-book-test-task-2)
        - [LAMBADA Task](#lambada-task-2)
        - [Reading Comprehension Task](#reading-comprehension-task-2)
        - [Summarization Task](#summarization-task-2)
        - [Translation Task](#translation-task-2)
- [Others](#others)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# GPT-2 Model
The [GPT-2 model](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) was released by OpenAI in 2019. GPT-2 is the successor of the GPT model and is a very large language model whose core task is predicting the next word. By parameter count, GPT-2 comes in four sizes: small (117M), medium (345M), large (762M), and xlarge (1542M).
[GPT-2 introduction](https://openai.com/blog/better-language-models/)
[GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf): Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
# Model Architecture
GPT-2 is implemented with the Transformer decoder. A full Transformer contains stacks of encoder layers and decoder layers, but GPT-2 uses only the decoder part.
For fine-tuning, the pretrained model is adapted to each task with the corresponding dataset.
At test time, predictions are produced with the fine-tuned model; some tasks can be evaluated directly in a zero-shot setting.
# Downstream Tasks
This document covers six downstream tasks:
- Language Modeling
- Children's Book Test
- LAMBADA
- Reading Comprehension
- Summarization
- Translation
For details about the datasets, see the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
## Script Description
The GPT-2 scripts and code structure are as follows:
```text
├── GPT-2
  ├── README.md                         // GPT-2 model description
  ├── scripts
  │   ├──run_cbt.sh                     // fine-tuning & evaluation script for the CBT task
  │   ├──run_lambada.sh                 // fine-tuning & evaluation script for the LAMBADA task
  │   ├──run_language_model.sh          // fine-tuning & evaluation script for the language modeling task
  │   ├──run_read_comprehension.sh      // fine-tuning & evaluation script for the reading comprehension task
  │   ├──run_summarization.sh           // fine-tuning & evaluation script for the summarization task
  │   ├──run_translation.sh             // fine-tuning & evaluation script for the translation task
  ├──src
  │   ├──clip_grad_utils.py             // gradient clipping
  │   ├──dataset.py                     // dataset loading for fine-tuning and inference
  │   ├──finetune_eval_config.py        // configuration for fine-tuning and inference
  │   ├──gpt2_for_finetune.py           // GPT-2 networks wrapped for fine-tuning
  │   ├──GPT2_generation.py             // generation module
  │   ├──GPT2_model.py                  // GPT-2 model script
  │   ├──GPT2ForCBT.py                  // model script for the CBT task
  │   ├──GPT2ForLanguageModel.py        // model script for the language modeling task
  │   ├──GPT2ForReadComprehension.py    // model script for the reading comprehension task
  │   ├──GPT2ForSummarization.py        // model script for the summarization task
  │   ├──GPT2ForTranslation.py          // model script for the translation task
  │   ├──weight_init.py                 // weight initialization
  │   ├──utils
  │       ├──bleu_score.py              // BLEU score computation
  │       ├──rouge_score.py             // ROUGE score computation
  │       ├──CrossEntropy.py            // cross-entropy loss
  │       ├──data_preprocess.py         // dataset preprocessing script
  │       ├──generation_utils.py        // generation helpers, including sampling methods
  │       ├──get_config_setting.py      // configuration retrieval
  │       ├──task_utils.py              // helper functions for downstream tasks
  │       ├──lr_schedule.py             // learning-rate schedule script
  │       ├──metric_method.py           // evaluation metrics for downstream tasks
  │       ├──tensor_manipulations.py    // tensor operations
  │       ├──tokenization.py            // tokenization, including BPE encoding and decoding
  │       ├──pretrain-data
  │           ├──stopwords.txt          // stop-word filter for the LAMBADA task
  ├──create_cbt_data.py                 // create MindRecord files for the CBT task
  ├──create_lambada_data.py             // create MindRecord files for the LAMBADA task
  ├──create_lm_data.py                  // create MindRecord files for the other tasks
  ├──create_summary_data.py             // create MindRecord files for the summarization task
  ├──download_cnn_dailymail.py          // download the CNN & DailyMail dataset
  ├──cnn_dataset_sampler.py             // sampler for the CNN & DailyMail training set
  ├──eval_rc_addition_answer.py         // evaluate reading comprehension with addition_answer
  ├──run_CBT_task.py                    // fine-tuning & inference entry point for the CBT task
  ├──run_lambada.py                     // fine-tuning & inference entry point for the LAMBADA task
  ├──run_language_model.py              // fine-tuning & inference entry point for the language modeling task
  ├──run_ReadComprehension.py           // fine-tuning & inference entry point for the reading comprehension task
  ├──run_summarization.py               // fine-tuning & inference entry point for the summarization task
  ├──run_translation.py                 // fine-tuning & inference entry point for the translation task
  ├──task_dataset_preprocess.py         // dataset preprocessing entry point for all tasks
  ├──convert_tf_ckpt
  │   ├──read_weight_tf.py              // read the pretrained model under TensorFlow
  │   ├──trans_dict.py                  // dictionary of model parameter names
  │   ├──save_weight_ms.py              // generate the MindSpore checkpoint
  ├──third_party
  │   ├──gpt2-merges.txt
  │   ├──gpt2-vocab.json                // GPT-2 pretrained vocabulary
  │   ├──bleu.py                        // third-party code assisting BLEU computation
```
## Model Conversion
- Download the GPT-2 pretrained model: [GPT-2 pretrained model download](https://github.com/openai/gpt-2/blob/master/download_model.py)
- In a TensorFlow environment, run `read_weight_tf.py`, for example:
    `python read_weight_tf.py --ckpt_file_path=/{path}/model.ckpt`
- In a MindSpore environment, run `save_weight_ms.py`, for example:
    `python save_weight_ms.py --output_file_name="mindspore_gpt2_small.ckpt"`
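As a quick sanity check, the converted checkpoint can be inspected with MindSpore's standard serialization API. This is a minimal sketch, assuming the output file name from the example above; the parameter names should match the MindSpore-side names defined in `convert_tf_ckpt/trans_dict.py`:
```python
# Minimal sanity check of the converted checkpoint (file name taken from the example above).
from mindspore.train.serialization import load_checkpoint

param_dict = load_checkpoint("mindspore_gpt2_small.ckpt")
print("number of parameters:", len(param_dict))
# Names should look like the MindSpore-side entries of trans_dict.py, e.g.
# 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight'
for name in list(param_dict)[:5]:
    print(name, param_dict[name].data.shape)
```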
## Dataset Preparation
### Language Modeling Task
#### WikiText2, WikiText103, PTB, and 1BW Datasets
- [WikiText2 download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip). After extraction, use `wikitext-2/wiki.test.tokens` as the test set.
- [WikiText103 download](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip). After extraction, use `wikitext-103/wiki.test.tokens` as the test set.
- [PTB download](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz). After extraction, use `/simple-examples/data/ptb.test.txt` as the test set and `/simple-examples/data/ptb.train.txt` as the training set.
- [1BW download](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz). After extraction, use `1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050` as the test set and `1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00001-of-00100` as the raw training set, from which 30000 training samples are obtained by random sampling.

Use `task_dataset_preprocess.py` to clean the datasets above.
The main arguments of `task_dataset_preprocess.py` are:
```bash
--task: The GPT-2 downstream task, one of [LanguageModeling, CBT, Translation, Lambada, Summarization, ReadingComprehension].
--input_file: The raw dataset path.
--dataset: The name of the dataset to be processed, only for the LanguageModeling task.
--output_file: The output dataset path after preprocessing.
--condition: Whether to process the train or the test dataset, one of [train, test], only for the 1BW and CNN & DailyMail datasets.
```
For example, to clean the PTB test set (use `--condition "train"` with `ptb.train.txt` for the training set):
```bash
python task_dataset_preprocess.py --task "LanguageModeling" --input_file /{path}/ptb.test.txt --dataset "ptb" --output_file /{path}/ptb_clean_test.txt --condition "test"
```
Use `create_lm_data.py` to convert the datasets above into MindRecord format.
The main arguments of `create_lm_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_length: Maximum sequence length.
--vocab_file: Path to gpt2-vocab.json.
--merge_file: Path to gpt2-merges.txt.
```
For example:
```bash
python create_lm_data.py --input_file /{path}/ptb.test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
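The generated MindRecord file can be inspected with MindSpore's `FileReader` before it is fed to the fine-tuning or evaluation scripts. A minimal sketch; the exact field names are determined by `create_lm_data.py`:
```python
# Peek at the first record of the generated MindRecord file.
from mindspore.mindrecord import FileReader

reader = FileReader("/{path}/ptb-test-mindrecord")
for item in reader.get_next():
    # field names (e.g. input ids / masks) are defined by create_lm_data.py
    print({k: getattr(v, "shape", v) for k, v in item.items()})
    break
reader.close()
```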
### Children's Book Test Task
#### CBT-CN / CBT-NE Datasets
- [CBT download](http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz). Use `cbtest_CN_valid_2000ex.txt` and `cbtest_NE_valid_2000ex.txt` under the `/data` directory as the evaluation sets for this task, and clean them as follows:
```bash
python task_dataset_preprocess.py --task "CBT" --input_file /{path}/cbtest_CN_valid_2000ex.txt --dataset "cbt" --output_file /{path}/cbt_cn_valid.txt
```
Use `create_cbt_data.py` to convert the datasets above into MindRecord format.
The main arguments of `create_cbt_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_length: Maximum sequence length.
--num_choice: Number of choices.
--vocab_file: Path to gpt2-vocab.json.
--merge_file: Path to gpt2-merges.txt.
```
For example:
```bash
python create_cbt_data.py --input_file /{path}/cbt_cn_valid.txt --output_file /{path}/cbt-cn-valid-mindrecord --num_splits 1 --max_length 1024 --num_choice 10 --vocab_file={path} --merge_file={path}
```
### LAMBADA Task
#### LAMBADA Dataset
- [LAMBADA download](https://zenodo.org/record/2630551#.X-yCSTTithH). Use `lambada_test_plain_text.txt` as the evaluation set for this task, and clean it as follows:
```bash
python task_dataset_preprocess.py --task "LAMBADA" --input_file /{path}/lambada_test_plain_text.txt --dataset "LAMBADA" --output_file /{path}/lambada_test_clean.txt
```
Use `create_lambada_data.py` to convert the dataset above into MindRecord format.
The main arguments of `create_lambada_data.py` are:
```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_length: Maximum sequence length.
--vocab_file: Path to gpt2-vocab.json.
--merge_file: Path to gpt2-merges.txt.
```
For example:
```bash
python create_lambada_data.py --input_file /{path}/lambada_test_clean.txt --output_file /{path}/lambada-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
### Reading Comprehension Task
#### CoQA Dataset
- [CoQA download](http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json). Use `coqa-dev-v1.0.json` as the evaluation set for this task, and clean it as follows:
```bash
python task_dataset_preprocess.py --task "ReadingComprehension" --input_file /{path}/coqa-dev-v1.0.json --dataset "coqa" --output_file /{path}/coqa_dev.txt
```
Use `create_lm_data.py` to convert the dataset above into MindRecord format, for example:
```bash
python create_lm_data.py --input_file /{path}/coqa_dev.txt --output_file /{path}/coqa-dev-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
### Summarization Task
#### CNN & DailyMail Dataset
- Download the dataset with the `download_cnn_dailymail.py` script, for example:
```bash
# download the test set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split test
# download the training set
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split train
```
Randomly sample 10000 examples from the training set as the final fine-tuning training set. Use the `cnn_dataset_sampler.py` script to perform the sampling and generate the new training set, for example:
```bash
# the GPT-2 small and GPT-2 medium training sets use seq_length=1024, so set max_length=1022 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
    --output_path="/{path}/cnn_train_hint_small.txt" \
    --replace_hint="true" \
    --sample="true" \
    --max_length=1022 \
    --prob=0.25 \
    --max_items=10000 \
    --hint="TL;DR:"

# the GPT-2 large training set uses seq_length=768, so set max_length=766 here
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
    --output_path="/{path}/cnn_train_hint_large.txt" \
    --replace_hint="true" \
    --sample="true" \
    --max_length=766 \
    --prob=0.25 \
    --max_items=10000 \
    --hint="TL;DR:"
```
Use `create_summary_data.py` to convert the dataset above into MindRecord format, for example:
```bash
python create_summary_data.py --input_file /{path}/cnn_dailymail_test.txt --output_file /{path}/cnn_dailymail-test-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path} --mode 'cnn_dailymail'
```
### Translation Task
#### WMT14 En-Fr Dataset
- [WMT14 En-Fr download](http://statmt.org/wmt14/test-full.tgz). Use `newstest2014-fren-ref.en.sgm` and `newstest2014-fren-ref.fr.sgm` as the evaluation sets for this task; merge and clean them as follows:
```bash
python task_dataset_preprocess.py --task "Translation" --input_file /{path}/test-full --dataset "wmt14" --output_file /{path}/wmt14
```
Two files, `wmt14.en_fr.txt` and `wmt14.fr_en.txt`, are generated under the `output_file` path for evaluating `En-Fr` and `Fr-En`, respectively.
Use `create_lm_data.py` to convert the datasets above into MindRecord format, for example:
```bash
python create_lm_data.py --input_file /{path}/wmt14.en_fr.txt --output_file /{path}/en-fr-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
python create_lm_data.py --input_file /{path}/wmt14.fr_en.txt --output_file /{path}/fr-en-mindrecord --num_splits 1 --max_length 1024 --vocab_file={path} --merge_file={path}
```
## Configuration
`src/finetune_eval_config.py` is the configuration file for GPT-2 training and inference. It exposes most options and parameters, including the GPT-2 model size, the model configuration, and the optimizer parameters.
For details on each attribute, see the `src/finetune_eval_config.py` file.
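For illustration, adjusting the options described in the fine-tuning sections below amounts to edits of this kind. This is only a hypothetical sketch; the actual attribute and export names are defined in `src/finetune_eval_config.py` and may differ:
```python
# Hypothetical sketch -- check src/finetune_eval_config.py for the real names.
from src.finetune_eval_config import cfg, gpt2_net_cfg  # assumed export names

cfg.gpt2_network = "large"      # GPT-2 model size: small / medium / large
cfg.optimizer = "Lamb"          # optimizer: momentum / adam / lamb
gpt2_net_cfg.batch_size = 4     # per-model settings such as batch_size
gpt2_net_cfg.seq_length = 1024  # and seq_length, mentioned below
```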
## Fine-Tuning & Evaluation
### Language Modeling Task
#### Fine-Tuning
- PTB dataset

The GPT-2 small / medium / large models need to be fine-tuned on the PTB training set. Fine-tuning only requires the shell script `scripts/run_language_model.sh`, which sets the environment variables and invokes the `run_language_model.py` script under `GPT-2`.
Before fine-tuning, configure the options in `src/finetune_eval_config.py`:
- Set `gpt2_network` under `cfg` to the desired GPT-2 model size (`[small/medium/large]`).
- Set `optimizer` under `cfg` to `Lamb` to select the optimizer (one of `momentum/adam/lamb`).
- Once the GPT-2 model is selected, set its model parameters, including `batch_size` and `seq_length`.

Then run the `scripts/run_language_model.sh` shell script:
```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
    --do_train="true" \
    --do_eval="false" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path={save_finetune_ckpt_path} \
    --load_pretrain_ckpt_path={load_pretrain_ckpt_path} \
    --train_data_file_path={train_data_file_path}
```
Logs and output files are available under `./ms_log/`.
```bash
sh scripts/run_language_model.sh [--options]
```
The usage of `run_language_model.sh` is as follows:
```text
usage: run_language_model.sh   [--device_target DEVICE_TARGET] [--device_id N]
                               [--metric_method METRIC_METHOD]
                               [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                               [--eval_type EVAL_TYPE] [--epoch_num N]
                               [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                               [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                               [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                               [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                               [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                               [--train_data_file_path TRAIN_DATA_FILE_PATH]
                               [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of the target device
    --metric_method            The evaluation method, one of [PPL]. Default: "PPL"
    --do_train                 Enable training. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Number of epochs. Default: 1
    --train_data_shuffle       Enable train data shuffling. Default: "true"
    --eval_data_shuffle        Enable eval data shuffling. Default: "false"
    --save_finetune_ckpt_path  Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path  Checkpoint file path to load for training
    --load_finetune_ckpt_path  Checkpoint file path to load for evaluation
    --train_data_file_path     Training data path; an absolute path is preferred
    --eval_data_file_path      Evaluation data path; an absolute path is preferred
```
- 1BW dataset

The GPT-2 large model needs to be fine-tuned on the 1BW training set. Fine-tuning only requires the shell script `run_language_model.sh`, which sets the environment variables and invokes the `run_language_model.py` script under `GPT-2`. The fine-tuning procedure is the same as for the PTB dataset.
#### Evaluation
GPT-2 can be evaluated on the `WikiText2/WikiText103/PTB/1BW` test sets. For these datasets the evaluation metric is PPL, i.e. set `--metric_method="PPL"`.
Evaluation only requires the shell script `run_language_model.sh`, which sets the environment variables and invokes the `run_language_model.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_language_model.sh` shell script. If the model was fine-tuned on a dataset, set `--eval_type="finetuned"` when evaluating on the corresponding test set; otherwise set `--eval_type="zero-shot"`. In addition, `--load_finetune_ckpt_path` is the location of the fine-tuned checkpoint file.
```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="finetuned" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```
Logs and output files are available under `./ms_log/`.
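For reference, the PPL reported here is standard perplexity: the exponentiated average negative log-likelihood of the evaluation tokens under the model, where $N$ is the number of evaluated tokens and lower is better:
```latex
\mathrm{PPL} = \exp\Big(-\frac{1}{N}\sum_{i=1}^{N}\log p\,(x_i \mid x_{<i})\Big)
```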
### Children's Book Test Task
#### Evaluation
GPT-2 can be evaluated on the `CBT-CN/CBT-NE` validation sets. For these datasets the evaluation metric is Accuracy, i.e. set `--metric_method="Accuracy"`.
Evaluation only requires the shell script `run_cbt.sh`, which sets the environment variables and invokes the `run_CBT_task.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_cbt.sh` shell script with `eval_type="zero-shot"`; in this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.
```bash
sh scripts/run_cbt.sh --device_target="Ascend" \
    --num_choice=10 \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```
Logs and output files are available under `./ms_log/`.
```bash
sh scripts/run_cbt.sh [--options]
```
The usage of `run_cbt.sh` is as follows:
```text
usage: run_cbt.sh   [--device_target DEVICE_TARGET] [--device_id N] [--num_choice N]
                    [--metric_method METRIC_METHOD]
                    [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                    [--eval_type EVAL_TYPE] [--epoch_num N]
                    [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                    [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                    [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                    [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                    [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                    [--train_data_file_path TRAIN_DATA_FILE_PATH]
                    [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of the target device
    --num_choice               Number of choices in the CBT task
    --metric_method            The evaluation method, one of [Accuracy]. Default: "Accuracy"
    --do_train                 Enable training. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Number of epochs. Default: 1
    --train_data_shuffle       Enable train data shuffling. Default: "true"
    --eval_data_shuffle        Enable eval data shuffling. Default: "false"
    --save_finetune_ckpt_path  Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path  Checkpoint file path to load for training
    --load_finetune_ckpt_path  Checkpoint file path to load for evaluation
    --train_data_file_path     Training data path; an absolute path is preferred
    --eval_data_file_path      Evaluation data path; an absolute path is preferred
```
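Conceptually, CBT is a multiple-choice cloze test: each of the `num_choice` candidate words fills the blank, and the candidate the model scores as most likely is chosen. A minimal sketch of this selection rule only, not the repository implementation (which lives in `GPT2ForCBT.py` and `src/utils/metric_method.py`):
```python
import numpy as np

def choose_answer(choice_log_probs):
    """choice_log_probs: one total log-probability per candidate
    (num_choice=10 in the command above); returns the best index."""
    return int(np.argmax(choice_log_probs))

def accuracy(predictions, labels):
    """Fraction of examples where the chosen candidate equals the label."""
    return float(np.mean(np.array(predictions) == np.array(labels)))
```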
### LAMBADA Task
#### Evaluation
GPT-2 can be evaluated on the `LAMBADA` test set. For this dataset the evaluation metrics are Accuracy and PPL, i.e. set `--metric_method="Accuracy"` or `--metric_method="PPL"`.
Evaluation only requires the shell script `run_lambada.sh`, which sets the environment variables and invokes the `run_lambada.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_lambada.sh` shell script with `eval_type="zero-shot"`; in this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.
To evaluate Accuracy:
```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length_dynamically="true" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --stop_word_file_path={stop_word_file_path}
```
To evaluate PPL:
```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```
Logs and output files are available under `./ms_log/`.
```bash
sh scripts/run_lambada.sh [--options]
```
The usage of `run_lambada.sh` is as follows:
```text
usage: run_lambada.sh   [--device_target DEVICE_TARGET] [--device_id N]
                        [--metric_method METRIC_METHOD]
                        [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                        [--eval_type EVAL_TYPE] [--epoch_num N]
                        [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                        [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                        [--generate_length_dynamically GENERATE_LENGTH_DYNAMICALLY]
                        [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                        [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                        [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                        [--train_data_file_path TRAIN_DATA_FILE_PATH]
                        [--eval_data_file_path EVAL_DATA_FILE_PATH]
                        [--tokenizer_file_path TOKENIZER_FILE_PATH]
                        [--stop_word_file_path STOP_WORD_FILE_PATH]
options:
    --device_target                Device type. Default: "Ascend"
    --device_id                    ID of the target device
    --metric_method                The evaluation method, one of [Accuracy, PPL]. Default: "Accuracy"
    --do_train                     Enable training. Default: "false"
    --do_eval                      Enable evaluation. Default: "true"
    --eval_type                    The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                    Number of epochs. Default: 1
    --train_data_shuffle           Enable train data shuffling. Default: "true"
    --eval_data_shuffle            Enable eval data shuffling. Default: "false"
    --generate_length_dynamically  Enable dynamic generation length. Default: "true"
    --save_finetune_ckpt_path      Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path      Checkpoint file path to load for training
    --load_finetune_ckpt_path      Checkpoint file path to load for evaluation
    --train_data_file_path         Training data path; an absolute path is preferred
    --eval_data_file_path          Evaluation data path; an absolute path is preferred
    --tokenizer_file_path          Path to the pretrained vocabulary and merge files
    --stop_word_file_path          Path to the stop-word file
```
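For the Accuracy metric, LAMBADA asks the model to produce the final word of each passage, and `stopwords.txt` is used as a filter over the generated continuation. A rough sketch of the matching criterion, assuming this filtering scheme (the repository logic in `run_lambada.py` may differ in detail):
```python
def last_word_correct(generated: str, target_word: str, stopwords: set) -> bool:
    """Take the first non-stop-word of the generated continuation and
    compare it against the passage's true final word."""
    candidates = [w.strip('.,!?;:"\'') for w in generated.split()]
    candidates = [w for w in candidates if w and w.lower() not in stopwords]
    return bool(candidates) and candidates[0] == target_word
```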
### Reading Comprehension Task
#### Evaluation
GPT-2 can be evaluated on the `CoQA` development set. For this dataset the evaluation metric is F1, i.e. set `--metric_method="F1"`.
Evaluation only requires the shell script `run_read_comprehension.sh`, which sets the environment variables and invokes the `run_ReadComprehension.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_read_comprehension.sh` shell script with `eval_type="zero-shot"`; in this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.
```bash
sh scripts/run_read_comprehension.sh --device_target="Ascend" \
    --metric_method="F1" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --generate_length=55 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0"
```
Logs and output files are available under `./ms_log/`. Then pass the resulting log file to the `eval_rc_addition_answer.py` script as its `input_file`, and the original CoQA development set `coqa-dev-v1.0.json` as its `addition_file`.
Run `python eval_rc_addition_answer.py --input_file={path} --addition_file={path}` to obtain the final F1 score.
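The F1 here is CoQA-style token-overlap F1 (the official scorer additionally normalizes case, punctuation, and articles, and takes the maximum over reference answers). A minimal sketch of the core computation:
```python
from collections import Counter

def token_f1(prediction: str, reference: str) -> float:
    """Token-overlap F1 between a predicted answer and a reference answer."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```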
```bash
sh scripts/run_read_comprehension.sh [--options]
```
The usage of `run_read_comprehension.sh` is as follows:
```text
usage: run_read_comprehension.sh   [--device_target DEVICE_TARGET] [--device_id N]
                                   [--metric_method METRIC_METHOD]
                                   [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                                   [--eval_type EVAL_TYPE] [--epoch_num N]
                                   [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                                   [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                                   [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                                   [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                                   [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                                   [--train_data_file_path TRAIN_DATA_FILE_PATH]
                                   [--eval_data_file_path EVAL_DATA_FILE_PATH]
                                   [--tokenizer_file_path TOKENIZER_FILE_PATH]
                                   [--generate_length N] [--top_k N] [--top_p TOP_P]
                                   [--temperature TEMPERATURE]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of the target device
    --metric_method            The evaluation method, one of [F1]. Default: "F1"
    --do_train                 Enable training. Default: "false"
    --do_eval                  Enable evaluation. Default: "false"
    --eval_type                The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Number of epochs. Default: 1
    --train_data_shuffle       Enable train data shuffling. Default: "true"
    --eval_data_shuffle        Enable eval data shuffling. Default: "false"
    --save_finetune_ckpt_path  Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path  Checkpoint file path to load for training
    --load_finetune_ckpt_path  Checkpoint file path to load for evaluation
    --train_data_file_path     Training data path; an absolute path is preferred
    --eval_data_file_path      Evaluation data path; an absolute path is preferred
    --tokenizer_file_path      Path to the pretrained vocabulary and merge files
    --generate_length          Generation length for the answer sentence
    --top_k                    Parameter for top-k sampling
    --top_p                    Parameter for top-p (nucleus) sampling
    --temperature              Sampling temperature; higher values give more diverse generations
```
### Summarization Task
#### Evaluation
GPT-2 can be evaluated on the `CNN_Dailymail` test set. For this dataset the evaluation metric is ROUGE, i.e. set `--metric_method="ROUGE"`.
Evaluation only requires the shell script `run_summarization.sh`, which sets the environment variables and invokes the `run_summarization.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_summarization.sh` shell script. For the `hint` setting, set `eval_type="finetuned"` and point `--load_finetune_ckpt_path` to the fine-tuned checkpoint file; for the `no hint` setting, set `eval_type="zero-shot"` and point `--load_finetune_ckpt_path` to the pretrained checkpoint file.
```bash
sh scripts/run_summarization.sh --device_target="Ascend" \
    --do_train="false" \
    --do_eval="true" \
    --metric_method="Rouge" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length=100 \
    --top_k=2 \
    --top_p="1.0" \
    --temperature="1.0" \
    --eval_type="finetuned" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path}
```
Logs and output files are available under `./ms_log/`.
```bash
sh scripts/run_summarization.sh [--options]
```
The usage of `run_summarization.sh` is as follows:
```text
usage: run_summarization.sh   [--device_target DEVICE_TARGET] [--device_id N]
                              [--metric_method METRIC_METHOD]
                              [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                              [--eval_type EVAL_TYPE] [--epoch_num N]
                              [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                              [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                              [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                              [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                              [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                              [--train_data_file_path TRAIN_DATA_FILE_PATH]
                              [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of the target device
    --do_train                 Enable training. Default: "false"
    --do_eval                  Enable evaluation. Default: "false"
    --metric_method            The evaluation method, one of [Rouge (Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: "Rouge"
    --epoch_num                Number of epochs. Default: 2
    --train_data_shuffle       Enable train data shuffling. Default: "true"
    --eval_data_shuffle        Enable eval data shuffling. Default: "false"
    --save_finetune_ckpt_path  Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path  Checkpoint file path to load for training
    --load_finetune_ckpt_path  Checkpoint file path to load for evaluation
    --train_data_file_path     Training data path; an absolute path is preferred
    --eval_data_file_path      Evaluation data path; an absolute path is preferred
    --eval_type                The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --top_k                    Number of top tokens kept for sampling
    --top_p                    Cumulative probability threshold for nucleus sampling
    --generate_length          Number of generated tokens
    --temperature              Temperature applied to the logits for sampling
    --tokenizer_file_path      Path to the vocabulary & merge files
```
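The reported ROUGE score can be cross-checked offline with the `rouge` package listed under [Other Requirements](#other-requirements); the repository's own scorer (`src/utils/rouge_score.py`) may apply different normalization, so treat this only as a rough reference:
```python
# Rough offline cross-check with the `rouge` pip package (toy strings).
from rouge import Rouge

hypothesis = "police kill the gunman"
reference = "the gunman was shot dead by police"
scores = Rouge().get_scores(hypothesis, reference, avg=True)
print(scores["rouge-1"]["f"], scores["rouge-2"]["f"], scores["rouge-l"]["f"])
```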
### Translation Task
#### Evaluation
GPT-2 can be evaluated on the `WMT14 En-Fr` and `WMT14 Fr-En` test sets. For these datasets the evaluation metric is BLEU, i.e. set `--metric_method="BLEU"`.
Note: download the `bleu.py` script yourself ([script link](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py)) and place it under the `src/utils/` directory.
Evaluation only requires the shell script `run_translation.sh`, which sets the environment variables and invokes the `run_translation.py` script under `GPT-2`.
First configure `src/finetune_eval_config.py`, then run the `scripts/run_translation.sh` shell script with `eval_type="zero-shot"`; in this case `--load_finetune_ckpt_path` only needs to point to the pretrained checkpoint file.
```bash
sh scripts/run_translation.sh --device_target="Ascend" \
    --metric_method="BLEU" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --generate_length=100 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0"
```
```bash
sh scripts/run_translation.sh [--options]
```
The usage of `run_translation.sh` is as follows:
```text
usage: run_translation.sh   [--device_target DEVICE_TARGET] [--device_id N]
                            [--metric_method METRIC_METHOD]
                            [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                            [--eval_type EVAL_TYPE] [--epoch_num N]
                            [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                            [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                            [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                            [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                            [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                            [--train_data_file_path TRAIN_DATA_FILE_PATH]
                            [--eval_data_file_path EVAL_DATA_FILE_PATH]
                            [--tokenizer_file_path TOKENIZER_FILE_PATH]
                            [--generate_length N] [--top_k N] [--top_p TOP_P]
                            [--temperature TEMPERATURE]
options:
    --device_target            Device type. Default: "Ascend"
    --device_id                ID of the target device
    --metric_method            The evaluation method, one of [BLEU]. Default: "BLEU"
    --do_train                 Enable training. Default: "false"
    --do_eval                  Enable evaluation. Default: "true"
    --eval_type                The evaluation type, one of [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                Number of epochs. Default: 1
    --train_data_shuffle       Enable train data shuffling. Default: "true"
    --eval_data_shuffle        Enable eval data shuffling. Default: "false"
    --save_finetune_ckpt_path  Path to save the fine-tuned checkpoint
    --load_pretrain_ckpt_path  Checkpoint file path to load for training
    --load_finetune_ckpt_path  Checkpoint file path to load for evaluation
    --train_data_file_path     Training data path; an absolute path is preferred
    --eval_data_file_path      Evaluation data path; an absolute path is preferred
    --tokenizer_file_path      Path to the pretrained vocabulary and merge files
    --generate_length          Generation length for the translated sentence
    --top_k                    Parameter for top-k sampling
    --top_p                    Parameter for top-p (nucleus) sampling
    --temperature              Sampling temperature; higher values give more diverse generations
```
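The downloaded `bleu.py` (from tensorflow/nmt) defines a `compute_bleu(reference_corpus, translation_corpus, ...)` function over tokenized corpora. A minimal usage sketch on toy data, assuming the script is importable from the working directory; each hypothesis is paired with a list of references:
```python
# Toy BLEU computation with the third-party tensorflow/nmt bleu.py script.
from bleu import compute_bleu

references = [[["the", "cat", "is", "on", "the", "mat"]]]  # per sentence: a list of references
hypotheses = [["the", "cat", "sat", "on", "the", "mat"]]
bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(references, hypotheses)
print("BLEU = {:.2f}".format(bleu * 100))
```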
# Environment Requirements
## Platform
- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors. To try out the Ascend processors, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com; once the application is approved, you will be granted the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information about MindSpore, see the following resources:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
## Other Requirements
```text
math
numpy
copy
collections
re
rouge 1.0.0
datasets >= 0.4.0
json
tensorflow
```
# Performance
## Inference Performance
### Language Modeling Task
The table below shows the PPL scores (lower is better) of the GPT-2 small, medium, and large models on the language modeling task, alongside the scores reported by OpenAI.
| Model | Dataset | Device | Eval type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WikiText2 | Ascend | zero-shot | 24.5 | 29.41 |
| GPT-2 medium | WikiText2 | Ascend | zero-shot | 19.41 | 22.76 |
| GPT-2 large | WikiText2 | Ascend | zero-shot | 17.08 | 19.93 |
| GPT-2 small | WikiText103 | Ascend | zero-shot | 26.89 | 37.5 |
| GPT-2 medium | WikiText103 | Ascend | zero-shot | 20.23 | 26.37 |
| GPT-2 large | WikiText103 | Ascend | zero-shot | 17.48 | 22.05 |
| GPT-2 small | PTB | Ascend | finetune | 23.91 | 65.85 |
| GPT-2 medium | PTB | Ascend | finetune | 20.06 | 47.33 |
| GPT-2 large | PTB | Ascend | finetune | 18.84 | 40.31 |
| GPT-2 small | 1BW | Ascend | zero-shot | 63.13 | 75.2 |
| GPT-2 medium | 1BW | Ascend | zero-shot | 50.98 | 55.72 |
| GPT-2 large | 1BW | Ascend | finetune | 29.28 | 44.575 |
### Children's Book Test Task
The table below shows the Accuracy scores of the GPT-2 small, medium, and large models on the Children's Book Test task.
| Model | Dataset | Device | Eval type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CBT-CN valid | Ascend | zero-shot | 87.85 | 87.65 |
| GPT-2 medium | CBT-CN valid | Ascend | zero-shot | 92.1 | 92.35 |
| GPT-2 large | CBT-CN valid | Ascend | zero-shot | 93.7 | 93.45 |
| GPT-2 small | CBT-NE valid | Ascend | zero-shot | 85.1 | 83.4 |
| GPT-2 medium | CBT-NE valid | Ascend | zero-shot | 87.55 | 87.1 |
| GPT-2 large | CBT-NE valid | Ascend | zero-shot | 89.1 | 88 |
### LAMBADA Task
The tables below show the Accuracy and PPL scores of the GPT-2 small, medium, and large models on the LAMBADA task.
| Model | Dataset | Device | Eval type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 45.99 | 45.99 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 58.59 | 55.48 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 62.74 | 60.12 |

| Model | Dataset | Device | Eval type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 22.95 | 35.13 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 10.69 | 15.6 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 8.64 | 10.87 |
### Reading Comprehension Task
The table below shows the F1 scores of the GPT-2 small, medium, and large models on the reading comprehension task.
| Model | Dataset | Device | Eval type | F1 | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CoQA | Ascend | zero-shot | 25.94 | 25~26 |
| GPT-2 medium | CoQA | Ascend | zero-shot | 43.69 | 42~43 |
| GPT-2 large | CoQA | Ascend | zero-shot | 49.39 | 49~51 |
### Summarization Task
The tables below show the ROUGE scores of the GPT-2 small, medium, and large models on the summarization task.
| Model | Dataset | Device | Eval type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(TL;DR) | Ascend | finetune | 21.4 | 16.8~17 |
| GPT-2 medium | CNN_Dailymail(TL;DR) | Ascend | finetune | 25.94 | 20.6~20.9 |
| GPT-2 large | CNN_Dailymail(TL;DR) | Ascend | finetune | 26.73 | 21.5~21.6 |

| Model | Dataset | Device | Eval type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.08 | 15.03(xlarge) |
| GPT-2 medium | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.16 | 15.03(xlarge) |
| GPT-2 large | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.29 | 15.03(xlarge) |
### Translation Task
The table below shows the BLEU scores of the GPT-2 small, medium, and large models on the translation task.
| Model | Dataset | Device | Eval type | BLEU | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WMT-14 Fr-En | Ascend | zero-shot | 4.49 | 0.7~0.8 |
| GPT-2 medium | WMT-14 Fr-En | Ascend | zero-shot | 7.09 | 2.0~3.0 |
| GPT-2 large | WMT-14 Fr-En | Ascend | zero-shot | 7.97 | 6.5~7.0 |
| GPT-2 small | WMT-14 En-Fr | Ascend | zero-shot | 2.81 | 5(xlarge) |
| GPT-2 medium | WMT-14 En-Fr | Ascend | zero-shot | 3.2 | 5(xlarge) |
| GPT-2 large | WMT-14 En-Fr | Ascend | zero-shot | 3.06 | 5(xlarge) |
# Others
This model has been validated on the Ascend environment.
# ModelZoo Homepage
[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)
@@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CNN & DailyMail train dataset sampler
"""
import os
import sys
import shutil
import argparse
from random import random

from src.utils.tokenization import Tokenizer


def replace_split_word(read_path, output_path, tldr_str="TL;DR:", original_split='\t'):
    """
    replace the original split word (e.g. a tab) between article and summary
    with the TL;DR hint string
    """
    with open(read_path, "r") as r, open(output_path, "a") as w:
        line = r.readline()
        while line:
            article = line[:line.find(original_split)] + ' ' + tldr_str + ' '
            ref = line[line.rfind(original_split) + 1:]
            w.write(article + ref)
            line = r.readline()


def sample(read_path, out_path, threshold=1.0, max_items=0xFFFFFFF):
    """
    keep each line with probability `threshold`, up to `max_items` lines
    """
    cnt = 0
    total_cnt = 0
    with open(read_path, "r") as r, open(out_path, "a") as w:
        line = r.readline()
        while line:
            total_cnt += 1
            if cnt >= max_items:
                break
            if random() > threshold:
                line = r.readline()
                continue
            w.write(line)
            if (cnt + 1) % 3000 == 0:
                print("Now Processed Samples: {}, total: {}".format(cnt, total_cnt))
            cnt += 1
            line = r.readline()


def clip_article(input_path, out_path, hint, max_length):
    """
    clip the article so that each sample (article + summary) does not exceed max_length
    """
    tokenizer = Tokenizer()
    cnt = 0
    with open(input_path, "r") as r, open(out_path, "a+") as a:
        line = r.readline()
        while line:
            pos = line.rfind(hint)
            article = line[:pos]
            summary = line[pos:]
            if len(tokenizer.encode(line)) > max_length:
                # truncate the article to leave room for the summary tokens
                l_article = tokenizer.encode(article)[:max_length - len(tokenizer.encode(summary))]
                article = tokenizer.decode(l_article) + " "
            if cnt % 1000 == 0:
                print(article + summary)
                print("==============================")
            cnt += 1
            a.write(article + summary)
            line = r.readline()


def sampler_dataset():
    """
    run CNN & DailyMail train dataset sampler
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, default="",
                        help="input file path")
    parser.add_argument("--output_path", type=str, default="",
                        help="output file path")
    parser.add_argument("--replace_hint", type=str, default="true")
    parser.add_argument("--sample", type=str, default="true",
                        help="do sample? true or false")
    parser.add_argument("--max_length", type=int, default=1022,
                        help="max seq_length of input_raw_dataset")
    parser.add_argument("--prob", type=float, default=0.25,
                        help="sample rate")
    parser.add_argument("--max_items", type=int, default=10000,
                        help="max number of documents")
    parser.add_argument("--hint", type=str, default="TL;DR:",
                        help="hint text")
    args = parser.parse_args()

    # temp files: one stores the input of each stage, the other the intermediate results
    temp_file_input = sys.path[0] + '/temp_file1_by_sampler_py.txt'
    temp_file_proc = sys.path[0] + '/temp_file2_by_sampler_py.txt'

    read_path = args.input_path
    output_path = args.output_path
    prob = args.prob
    max_items = args.max_items
    hint = args.hint
    max_length = args.max_length
    split_str = '\t'

    shutil.copyfile(read_path, temp_file_input)
    # clip before replacing the split word, so the tab is still the separator here
    clip_article(temp_file_input, temp_file_proc, hint=split_str, max_length=max_length)
    shutil.copyfile(temp_file_proc, temp_file_input)
    os.remove(temp_file_proc)

    if args.replace_hint.lower() == "true":
        replace_split_word(temp_file_input, temp_file_proc, hint, split_str)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    if args.sample.lower() == "true":
        sample(temp_file_input, temp_file_proc, prob, max_items)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    shutil.copyfile(temp_file_input, output_path)
    os.remove(temp_file_input)


if __name__ == "__main__":
    sampler_dataset()
@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Read weight using tensorflow
to read the parameters of the gpt-2 pretrained model from a tensorflow checkpoint
and save them into npy files for mindspore to load.
*this script is based on the gpt-2 model downloaded from openai.*
"""
import argparse

import numpy as np
import tensorflow as tf

from trans_dict import trans_dict_tf


def read_weight(ckpt_path):
    """
    read weight

    Args:
        ckpt_path: the path of the tensorflow checkpoint
    """
    # list the (name, shape) pairs of all variables in the checkpoint
    init_vars = tf.train.list_variables(ckpt_path)
    save_param_num = 0

    for name, _ in init_vars:
        array = tf.train.load_variable(ckpt_path, name)
        # skip the leading 'model/' and replace '/' with '.' to match the keys of trans_dict
        name = name[6:].replace(r"/", ".")
        if name not in trans_dict_tf.keys():
            print(name + " is not in this model")
        else:
            # save each parameter as an npy file named after its mindspore counterpart
            np.save(trans_dict_tf[name] + ".npy", array)
            save_param_num = save_param_num + 1

    print("finished!")
    print("save {num} parameters.".format(num=save_param_num))


def main():
    parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
    parser.add_argument("--ckpt_file_path", type=str, default="",
                        help="The tensorflow GPT-2 model checkpoint file path")
    args_opt = parser.parse_args()
    ckpt_path = args_opt.ckpt_file_path
    read_weight(ckpt_path=ckpt_path)


if __name__ == "__main__":
    main()
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Save weight using mindspore, to load the parameters of the gpt-2 model from npy files.
The npy files should be in the same path as this script; otherwise change the path in the script.
"""
import os
import argparse

import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

from trans_dict import trans_dict_tf


def trans_model_parameter(ckpt_name):
    """
    transform model parameters

    Args:
        ckpt_name (str): the name of the transformed checkpoint.
    """
    # find all file names with the suffix '.npy' in the current path
    file_names = [name for name in os.listdir() if name.endswith(".npy")]
    new_params_list = []
    for file_name in file_names:
        var_name = file_name[:-4]
        param_dict = {"name": var_name, "data": Tensor(np.load(file_name))}
        if var_name in trans_dict_tf.values():
            new_params_list.append(param_dict)
            print(var_name + " has been saved")
    # load the parameters from the npy files and save them as a mindspore checkpoint
    save_checkpoint(new_params_list, ckpt_name)
    print("Finished: the parameters have been saved into a mindspore checkpoint.")


def main():
    parser = argparse.ArgumentParser(description="Save GPT-2 model checkpoint weight")
    parser.add_argument("--output_file_name", type=str, default="",
                        help="The name of the output checkpoint file")
    args_opt = parser.parse_args()
    ckpt_name = args_opt.output_file_name
    trans_model_parameter(ckpt_name=ckpt_name)


if __name__ == "__main__":
    main()
@@ -0,0 +1,892 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""transform dictionary"""
trans_dict_tf = {
    'h0.attn.c_attn.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h0.attn.c_attn.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h0.attn.c_proj.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h0.attn.c_proj.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h0.ln_1.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h0.ln_1.g': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h0.ln_2.b': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
    'h0.ln_2.g': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
    'h0.mlp.c_fc.b': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
    'h0.mlp.c_fc.w': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
    'h0.mlp.c_proj.b': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
    'h0.mlp.c_proj.w': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
    'h1.attn.c_attn.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h1.attn.c_attn.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h1.attn.c_proj.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h1.attn.c_proj.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h1.ln_1.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h1.ln_1.g': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h1.ln_2.b': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
    'h1.ln_2.g': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
    'h1.mlp.c_fc.b': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
    'h1.mlp.c_fc.w': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
    'h1.mlp.c_proj.b': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
    'h1.mlp.c_proj.w': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
    'h2.attn.c_attn.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h2.attn.c_attn.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h2.attn.c_proj.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h2.attn.c_proj.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h2.ln_1.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h2.ln_1.g': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h2.ln_2.b': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
    'h2.ln_2.g': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
    'h2.mlp.c_fc.b': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
    'h2.mlp.c_fc.w': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
    'h2.mlp.c_proj.b': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
    'h2.mlp.c_proj.w': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
    'h3.attn.c_attn.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h3.attn.c_attn.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h3.attn.c_proj.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h3.attn.c_proj.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h3.ln_1.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h3.ln_1.g': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h3.ln_2.b': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
    'h3.ln_2.g': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
    'h3.mlp.c_fc.b': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
    'h3.mlp.c_fc.w': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
    'h3.mlp.c_proj.b': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
    'h3.mlp.c_proj.w': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
    'h4.attn.c_attn.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h4.attn.c_attn.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h4.attn.c_proj.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h4.attn.c_proj.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h4.ln_1.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h4.ln_1.g': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h4.ln_2.b': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
    'h4.ln_2.g': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
    'h4.mlp.c_fc.b': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
    'h4.mlp.c_fc.w': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
    'h4.mlp.c_proj.b': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
    'h4.mlp.c_proj.w': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
    'h5.attn.c_attn.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h5.attn.c_attn.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h5.attn.c_proj.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h5.attn.c_proj.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h5.ln_1.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h5.ln_1.g': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h5.ln_2.b': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
    'h5.ln_2.g': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
    'h5.mlp.c_fc.b': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
    'h5.mlp.c_fc.w': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
    'h5.mlp.c_proj.b': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
    'h5.mlp.c_proj.w': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
    'h6.attn.c_attn.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h6.attn.c_attn.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h6.attn.c_proj.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h6.attn.c_proj.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h6.ln_1.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h6.ln_1.g': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h6.ln_2.b': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
    'h6.ln_2.g': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
    'h6.mlp.c_fc.b': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
    'h6.mlp.c_fc.w': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
    'h6.mlp.c_proj.b': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
    'h6.mlp.c_proj.w': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
    'h7.attn.c_attn.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h7.attn.c_attn.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h7.attn.c_proj.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h7.attn.c_proj.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h7.ln_1.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h7.ln_1.g': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h7.ln_2.b': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
    'h7.ln_2.g': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
    'h7.mlp.c_fc.b': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
    'h7.mlp.c_fc.w': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
    'h7.mlp.c_proj.b': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
    'h7.mlp.c_proj.w': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
    'h8.attn.c_attn.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h8.attn.c_attn.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h8.attn.c_proj.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h8.attn.c_proj.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h8.ln_1.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h8.ln_1.g': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h8.ln_2.b': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
    'h8.ln_2.g': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
    'h8.mlp.c_fc.b': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
    'h8.mlp.c_fc.w': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
    'h8.mlp.c_proj.b': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
    'h8.mlp.c_proj.w': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
    'h9.attn.c_attn.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h9.attn.c_attn.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h9.attn.c_proj.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h9.attn.c_proj.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h9.ln_1.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h9.ln_1.g': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h9.ln_2.b': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
| 'h9.ln_2.g': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma', | |||
| 'h9.mlp.c_fc.b': 'gpt2_decoder.layers.9.feedforward.c_fc.bias', | |||
| 'h9.mlp.c_fc.w': 'gpt2_decoder.layers.9.feedforward.c_fc.weight', | |||
| 'h9.mlp.c_proj.b': 'gpt2_decoder.layers.9.feedforward.c_proj.bias', | |||
| 'h9.mlp.c_proj.w': 'gpt2_decoder.layers.9.feedforward.c_proj.weight', | |||
| 'h10.attn.c_attn.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h10.attn.c_attn.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h10.attn.c_proj.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h10.attn.c_proj.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h10.ln_1.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h10.ln_1.g': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h10.ln_2.b': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta', | |||
| 'h10.ln_2.g': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma', | |||
| 'h10.mlp.c_fc.b': 'gpt2_decoder.layers.10.feedforward.c_fc.bias', | |||
| 'h10.mlp.c_fc.w': 'gpt2_decoder.layers.10.feedforward.c_fc.weight', | |||
| 'h10.mlp.c_proj.b': 'gpt2_decoder.layers.10.feedforward.c_proj.bias', | |||
| 'h10.mlp.c_proj.w': 'gpt2_decoder.layers.10.feedforward.c_proj.weight', | |||
| 'h11.attn.c_attn.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h11.attn.c_attn.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h11.attn.c_proj.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h11.attn.c_proj.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h11.ln_1.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h11.ln_1.g': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h11.ln_2.b': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta', | |||
| 'h11.ln_2.g': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma', | |||
| 'h11.mlp.c_fc.b': 'gpt2_decoder.layers.11.feedforward.c_fc.bias', | |||
| 'h11.mlp.c_fc.w': 'gpt2_decoder.layers.11.feedforward.c_fc.weight', | |||
| 'h11.mlp.c_proj.b': 'gpt2_decoder.layers.11.feedforward.c_proj.bias', | |||
| 'h11.mlp.c_proj.w': 'gpt2_decoder.layers.11.feedforward.c_proj.weight', | |||
| 'h12.attn.c_attn.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h12.attn.c_attn.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h12.attn.c_proj.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h12.attn.c_proj.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h12.ln_1.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h12.ln_1.g': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h12.ln_2.b': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta', | |||
| 'h12.ln_2.g': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma', | |||
| 'h12.mlp.c_fc.b': 'gpt2_decoder.layers.12.feedforward.c_fc.bias', | |||
| 'h12.mlp.c_fc.w': 'gpt2_decoder.layers.12.feedforward.c_fc.weight', | |||
| 'h12.mlp.c_proj.b': 'gpt2_decoder.layers.12.feedforward.c_proj.bias', | |||
| 'h12.mlp.c_proj.w': 'gpt2_decoder.layers.12.feedforward.c_proj.weight', | |||
| 'h13.attn.c_attn.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h13.attn.c_attn.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h13.attn.c_proj.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h13.attn.c_proj.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h13.ln_1.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h13.ln_1.g': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h13.ln_2.b': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta', | |||
| 'h13.ln_2.g': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma', | |||
| 'h13.mlp.c_fc.b': 'gpt2_decoder.layers.13.feedforward.c_fc.bias', | |||
| 'h13.mlp.c_fc.w': 'gpt2_decoder.layers.13.feedforward.c_fc.weight', | |||
| 'h13.mlp.c_proj.b': 'gpt2_decoder.layers.13.feedforward.c_proj.bias', | |||
| 'h13.mlp.c_proj.w': 'gpt2_decoder.layers.13.feedforward.c_proj.weight', | |||
| 'h14.attn.c_attn.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h14.attn.c_attn.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h14.attn.c_proj.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h14.attn.c_proj.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h14.ln_1.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h14.ln_1.g': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h14.ln_2.b': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta', | |||
| 'h14.ln_2.g': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma', | |||
| 'h14.mlp.c_fc.b': 'gpt2_decoder.layers.14.feedforward.c_fc.bias', | |||
| 'h14.mlp.c_fc.w': 'gpt2_decoder.layers.14.feedforward.c_fc.weight', | |||
| 'h14.mlp.c_proj.b': 'gpt2_decoder.layers.14.feedforward.c_proj.bias', | |||
| 'h14.mlp.c_proj.w': 'gpt2_decoder.layers.14.feedforward.c_proj.weight', | |||
| 'h15.attn.c_attn.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h15.attn.c_attn.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h15.attn.c_proj.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h15.attn.c_proj.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h15.ln_1.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h15.ln_1.g': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h15.ln_2.b': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta', | |||
| 'h15.ln_2.g': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma', | |||
| 'h15.mlp.c_fc.b': 'gpt2_decoder.layers.15.feedforward.c_fc.bias', | |||
| 'h15.mlp.c_fc.w': 'gpt2_decoder.layers.15.feedforward.c_fc.weight', | |||
| 'h15.mlp.c_proj.b': 'gpt2_decoder.layers.15.feedforward.c_proj.bias', | |||
| 'h15.mlp.c_proj.w': 'gpt2_decoder.layers.15.feedforward.c_proj.weight', | |||
| 'h16.attn.c_attn.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h16.attn.c_attn.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h16.attn.c_proj.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h16.attn.c_proj.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h16.ln_1.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h16.ln_1.g': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h16.ln_2.b': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta', | |||
| 'h16.ln_2.g': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma', | |||
| 'h16.mlp.c_fc.b': 'gpt2_decoder.layers.16.feedforward.c_fc.bias', | |||
| 'h16.mlp.c_fc.w': 'gpt2_decoder.layers.16.feedforward.c_fc.weight', | |||
| 'h16.mlp.c_proj.b': 'gpt2_decoder.layers.16.feedforward.c_proj.bias', | |||
| 'h16.mlp.c_proj.w': 'gpt2_decoder.layers.16.feedforward.c_proj.weight', | |||
| 'h17.attn.c_attn.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h17.attn.c_attn.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h17.attn.c_proj.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h17.attn.c_proj.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h17.ln_1.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h17.ln_1.g': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h17.ln_2.b': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta', | |||
| 'h17.ln_2.g': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma', | |||
| 'h17.mlp.c_fc.b': 'gpt2_decoder.layers.17.feedforward.c_fc.bias', | |||
| 'h17.mlp.c_fc.w': 'gpt2_decoder.layers.17.feedforward.c_fc.weight', | |||
| 'h17.mlp.c_proj.b': 'gpt2_decoder.layers.17.feedforward.c_proj.bias', | |||
| 'h17.mlp.c_proj.w': 'gpt2_decoder.layers.17.feedforward.c_proj.weight', | |||
| 'h18.attn.c_attn.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h18.attn.c_attn.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h18.attn.c_proj.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h18.attn.c_proj.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h18.ln_1.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h18.ln_1.g': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h18.ln_2.b': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta', | |||
| 'h18.ln_2.g': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma', | |||
| 'h18.mlp.c_fc.b': 'gpt2_decoder.layers.18.feedforward.c_fc.bias', | |||
| 'h18.mlp.c_fc.w': 'gpt2_decoder.layers.18.feedforward.c_fc.weight', | |||
| 'h18.mlp.c_proj.b': 'gpt2_decoder.layers.18.feedforward.c_proj.bias', | |||
| 'h18.mlp.c_proj.w': 'gpt2_decoder.layers.18.feedforward.c_proj.weight', | |||
| 'h19.attn.c_attn.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h19.attn.c_attn.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h19.attn.c_proj.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h19.attn.c_proj.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h19.ln_1.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h19.ln_1.g': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h19.ln_2.b': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta', | |||
| 'h19.ln_2.g': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma', | |||
| 'h19.mlp.c_fc.b': 'gpt2_decoder.layers.19.feedforward.c_fc.bias', | |||
| 'h19.mlp.c_fc.w': 'gpt2_decoder.layers.19.feedforward.c_fc.weight', | |||
| 'h19.mlp.c_proj.b': 'gpt2_decoder.layers.19.feedforward.c_proj.bias', | |||
| 'h19.mlp.c_proj.w': 'gpt2_decoder.layers.19.feedforward.c_proj.weight', | |||
| 'h20.attn.c_attn.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h20.attn.c_attn.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h20.attn.c_proj.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h20.attn.c_proj.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h20.ln_1.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h20.ln_1.g': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h20.ln_2.b': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta', | |||
| 'h20.ln_2.g': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma', | |||
| 'h20.mlp.c_fc.b': 'gpt2_decoder.layers.20.feedforward.c_fc.bias', | |||
| 'h20.mlp.c_fc.w': 'gpt2_decoder.layers.20.feedforward.c_fc.weight', | |||
| 'h20.mlp.c_proj.b': 'gpt2_decoder.layers.20.feedforward.c_proj.bias', | |||
| 'h20.mlp.c_proj.w': 'gpt2_decoder.layers.20.feedforward.c_proj.weight', | |||
| 'h21.attn.c_attn.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h21.attn.c_attn.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h21.attn.c_proj.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h21.attn.c_proj.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h21.ln_1.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h21.ln_1.g': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h21.ln_2.b': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta', | |||
| 'h21.ln_2.g': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma', | |||
| 'h21.mlp.c_fc.b': 'gpt2_decoder.layers.21.feedforward.c_fc.bias', | |||
| 'h21.mlp.c_fc.w': 'gpt2_decoder.layers.21.feedforward.c_fc.weight', | |||
| 'h21.mlp.c_proj.b': 'gpt2_decoder.layers.21.feedforward.c_proj.bias', | |||
| 'h21.mlp.c_proj.w': 'gpt2_decoder.layers.21.feedforward.c_proj.weight', | |||
| 'h22.attn.c_attn.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h22.attn.c_attn.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h22.attn.c_proj.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h22.attn.c_proj.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h22.ln_1.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h22.ln_1.g': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h22.ln_2.b': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta', | |||
| 'h22.ln_2.g': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma', | |||
| 'h22.mlp.c_fc.b': 'gpt2_decoder.layers.22.feedforward.c_fc.bias', | |||
| 'h22.mlp.c_fc.w': 'gpt2_decoder.layers.22.feedforward.c_fc.weight', | |||
| 'h22.mlp.c_proj.b': 'gpt2_decoder.layers.22.feedforward.c_proj.bias', | |||
| 'h22.mlp.c_proj.w': 'gpt2_decoder.layers.22.feedforward.c_proj.weight', | |||
| 'h23.attn.c_attn.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h23.attn.c_attn.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h23.attn.c_proj.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h23.attn.c_proj.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h23.ln_1.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h23.ln_1.g': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h23.ln_2.b': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta', | |||
| 'h23.ln_2.g': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma', | |||
| 'h23.mlp.c_fc.b': 'gpt2_decoder.layers.23.feedforward.c_fc.bias', | |||
| 'h23.mlp.c_fc.w': 'gpt2_decoder.layers.23.feedforward.c_fc.weight', | |||
| 'h23.mlp.c_proj.b': 'gpt2_decoder.layers.23.feedforward.c_proj.bias', | |||
| 'h23.mlp.c_proj.w': 'gpt2_decoder.layers.23.feedforward.c_proj.weight', | |||
| 'h24.attn.c_attn.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h24.attn.c_attn.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h24.attn.c_proj.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h24.attn.c_proj.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h24.ln_1.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h24.ln_1.g': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h24.ln_2.b': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta', | |||
| 'h24.ln_2.g': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma', | |||
| 'h24.mlp.c_fc.b': 'gpt2_decoder.layers.24.feedforward.c_fc.bias', | |||
| 'h24.mlp.c_fc.w': 'gpt2_decoder.layers.24.feedforward.c_fc.weight', | |||
| 'h24.mlp.c_proj.b': 'gpt2_decoder.layers.24.feedforward.c_proj.bias', | |||
| 'h24.mlp.c_proj.w': 'gpt2_decoder.layers.24.feedforward.c_proj.weight', | |||
| 'h25.attn.c_attn.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h25.attn.c_attn.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h25.attn.c_proj.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h25.attn.c_proj.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h25.ln_1.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h25.ln_1.g': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h25.ln_2.b': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta', | |||
| 'h25.ln_2.g': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma', | |||
| 'h25.mlp.c_fc.b': 'gpt2_decoder.layers.25.feedforward.c_fc.bias', | |||
| 'h25.mlp.c_fc.w': 'gpt2_decoder.layers.25.feedforward.c_fc.weight', | |||
| 'h25.mlp.c_proj.b': 'gpt2_decoder.layers.25.feedforward.c_proj.bias', | |||
| 'h25.mlp.c_proj.w': 'gpt2_decoder.layers.25.feedforward.c_proj.weight', | |||
| 'h26.attn.c_attn.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h26.attn.c_attn.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h26.attn.c_proj.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h26.attn.c_proj.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h26.ln_1.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h26.ln_1.g': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h26.ln_2.b': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta', | |||
| 'h26.ln_2.g': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma', | |||
| 'h26.mlp.c_fc.b': 'gpt2_decoder.layers.26.feedforward.c_fc.bias', | |||
| 'h26.mlp.c_fc.w': 'gpt2_decoder.layers.26.feedforward.c_fc.weight', | |||
| 'h26.mlp.c_proj.b': 'gpt2_decoder.layers.26.feedforward.c_proj.bias', | |||
| 'h26.mlp.c_proj.w': 'gpt2_decoder.layers.26.feedforward.c_proj.weight', | |||
| 'h27.attn.c_attn.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h27.attn.c_attn.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h27.attn.c_proj.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h27.attn.c_proj.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h27.ln_1.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h27.ln_1.g': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h27.ln_2.b': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta', | |||
| 'h27.ln_2.g': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma', | |||
| 'h27.mlp.c_fc.b': 'gpt2_decoder.layers.27.feedforward.c_fc.bias', | |||
| 'h27.mlp.c_fc.w': 'gpt2_decoder.layers.27.feedforward.c_fc.weight', | |||
| 'h27.mlp.c_proj.b': 'gpt2_decoder.layers.27.feedforward.c_proj.bias', | |||
| 'h27.mlp.c_proj.w': 'gpt2_decoder.layers.27.feedforward.c_proj.weight', | |||
| 'h28.attn.c_attn.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h28.attn.c_attn.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h28.attn.c_proj.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h28.attn.c_proj.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h28.ln_1.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h28.ln_1.g': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h28.ln_2.b': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta', | |||
| 'h28.ln_2.g': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma', | |||
| 'h28.mlp.c_fc.b': 'gpt2_decoder.layers.28.feedforward.c_fc.bias', | |||
| 'h28.mlp.c_fc.w': 'gpt2_decoder.layers.28.feedforward.c_fc.weight', | |||
| 'h28.mlp.c_proj.b': 'gpt2_decoder.layers.28.feedforward.c_proj.bias', | |||
| 'h28.mlp.c_proj.w': 'gpt2_decoder.layers.28.feedforward.c_proj.weight', | |||
| 'h29.attn.c_attn.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h29.attn.c_attn.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h29.attn.c_proj.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h29.attn.c_proj.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h29.ln_1.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h29.ln_1.g': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h29.ln_2.b': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta', | |||
| 'h29.ln_2.g': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma', | |||
| 'h29.mlp.c_fc.b': 'gpt2_decoder.layers.29.feedforward.c_fc.bias', | |||
| 'h29.mlp.c_fc.w': 'gpt2_decoder.layers.29.feedforward.c_fc.weight', | |||
| 'h29.mlp.c_proj.b': 'gpt2_decoder.layers.29.feedforward.c_proj.bias', | |||
| 'h29.mlp.c_proj.w': 'gpt2_decoder.layers.29.feedforward.c_proj.weight', | |||
| 'h30.attn.c_attn.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h30.attn.c_attn.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h30.attn.c_proj.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h30.attn.c_proj.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h30.ln_1.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h30.ln_1.g': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h30.ln_2.b': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta', | |||
| 'h30.ln_2.g': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma', | |||
| 'h30.mlp.c_fc.b': 'gpt2_decoder.layers.30.feedforward.c_fc.bias', | |||
| 'h30.mlp.c_fc.w': 'gpt2_decoder.layers.30.feedforward.c_fc.weight', | |||
| 'h30.mlp.c_proj.b': 'gpt2_decoder.layers.30.feedforward.c_proj.bias', | |||
| 'h30.mlp.c_proj.w': 'gpt2_decoder.layers.30.feedforward.c_proj.weight', | |||
| 'h31.attn.c_attn.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h31.attn.c_attn.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h31.attn.c_proj.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h31.attn.c_proj.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h31.ln_1.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h31.ln_1.g': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h31.ln_2.b': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta', | |||
| 'h31.ln_2.g': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma', | |||
| 'h31.mlp.c_fc.b': 'gpt2_decoder.layers.31.feedforward.c_fc.bias', | |||
| 'h31.mlp.c_fc.w': 'gpt2_decoder.layers.31.feedforward.c_fc.weight', | |||
| 'h31.mlp.c_proj.b': 'gpt2_decoder.layers.31.feedforward.c_proj.bias', | |||
| 'h31.mlp.c_proj.w': 'gpt2_decoder.layers.31.feedforward.c_proj.weight', | |||
| 'h32.attn.c_attn.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h32.attn.c_attn.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h32.attn.c_proj.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h32.attn.c_proj.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h32.ln_1.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h32.ln_1.g': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h32.ln_2.b': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta', | |||
| 'h32.ln_2.g': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma', | |||
| 'h32.mlp.c_fc.b': 'gpt2_decoder.layers.32.feedforward.c_fc.bias', | |||
| 'h32.mlp.c_fc.w': 'gpt2_decoder.layers.32.feedforward.c_fc.weight', | |||
| 'h32.mlp.c_proj.b': 'gpt2_decoder.layers.32.feedforward.c_proj.bias', | |||
| 'h32.mlp.c_proj.w': 'gpt2_decoder.layers.32.feedforward.c_proj.weight', | |||
| 'h33.attn.c_attn.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h33.attn.c_attn.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h33.attn.c_proj.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h33.attn.c_proj.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h33.ln_1.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h33.ln_1.g': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h33.ln_2.b': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta', | |||
| 'h33.ln_2.g': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma', | |||
| 'h33.mlp.c_fc.b': 'gpt2_decoder.layers.33.feedforward.c_fc.bias', | |||
| 'h33.mlp.c_fc.w': 'gpt2_decoder.layers.33.feedforward.c_fc.weight', | |||
| 'h33.mlp.c_proj.b': 'gpt2_decoder.layers.33.feedforward.c_proj.bias', | |||
| 'h33.mlp.c_proj.w': 'gpt2_decoder.layers.33.feedforward.c_proj.weight', | |||
| 'h34.attn.c_attn.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h34.attn.c_attn.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h34.attn.c_proj.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h34.attn.c_proj.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h34.ln_1.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h34.ln_1.g': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h34.ln_2.b': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta', | |||
| 'h34.ln_2.g': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma', | |||
| 'h34.mlp.c_fc.b': 'gpt2_decoder.layers.34.feedforward.c_fc.bias', | |||
| 'h34.mlp.c_fc.w': 'gpt2_decoder.layers.34.feedforward.c_fc.weight', | |||
| 'h34.mlp.c_proj.b': 'gpt2_decoder.layers.34.feedforward.c_proj.bias', | |||
| 'h34.mlp.c_proj.w': 'gpt2_decoder.layers.34.feedforward.c_proj.weight', | |||
| 'h35.attn.c_attn.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h35.attn.c_attn.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h35.attn.c_proj.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h35.attn.c_proj.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h35.ln_1.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h35.ln_1.g': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h35.ln_2.b': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta', | |||
| 'h35.ln_2.g': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma', | |||
| 'h35.mlp.c_fc.b': 'gpt2_decoder.layers.35.feedforward.c_fc.bias', | |||
| 'h35.mlp.c_fc.w': 'gpt2_decoder.layers.35.feedforward.c_fc.weight', | |||
| 'h35.mlp.c_proj.b': 'gpt2_decoder.layers.35.feedforward.c_proj.bias', | |||
| 'h35.mlp.c_proj.w': 'gpt2_decoder.layers.35.feedforward.c_proj.weight', | |||
| 'ln_f.b': 'layer_norm.layer_norm.beta', | |||
| 'ln_f.g': 'layer_norm.layer_norm.gamma', | |||
| 'wpe': 'gpt2_embedding_postprocess.position_embedding_table', | |||
| 'wte': 'gpt2_embedding_lookup.embedding_table' | |||
| } # transfer dictionary: OpenAI TensorFlow-checkpoint-style names (e.g. 'h3.attn.c_attn.w') -> MindSpore parameter names | |||
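| # trans_dict_py covers the PyTorch (Hugging Face) state_dict naming convention, which puts a dot | |||
| # before the layer index and uses full '.weight'/'.bias' suffixes ('h.0.attn.c_attn.weight'), | |||
| # unlike the '.w'/'.b' style keys of the dictionary above ('h0.attn.c_attn.w'). | |||
| # Minimal usage sketch (hypothetical, assuming a loaded `torch_state_dict`; not part of the original script): | |||
| #   ms_params = {trans_dict_py[k]: v for k, v in torch_state_dict.items() if k in trans_dict_py} | |||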
| trans_dict_py = { | |||
| 'h.0.attn.c_attn.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.0.attn.c_attn.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.0.attn.c_proj.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.0.attn.c_proj.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.0.ln_1.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.0.ln_1.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.0.ln_2.bias': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta', | |||
| 'h.0.ln_2.weight': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.0.mlp.c_fc.bias': 'gpt2_decoder.layers.0.feedforward.c_fc.bias', | |||
| 'h.0.mlp.c_fc.weight': 'gpt2_decoder.layers.0.feedforward.c_fc.weight', | |||
| 'h.0.mlp.c_proj.bias': 'gpt2_decoder.layers.0.feedforward.c_proj.bias', | |||
| 'h.0.mlp.c_proj.weight': 'gpt2_decoder.layers.0.feedforward.c_proj.weight', | |||
| 'h.1.attn.c_attn.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.1.attn.c_attn.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.1.attn.c_proj.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.1.attn.c_proj.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.1.ln_1.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.1.ln_1.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.1.ln_2.bias': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta', | |||
| 'h.1.ln_2.weight': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.1.mlp.c_fc.bias': 'gpt2_decoder.layers.1.feedforward.c_fc.bias', | |||
| 'h.1.mlp.c_fc.weight': 'gpt2_decoder.layers.1.feedforward.c_fc.weight', | |||
| 'h.1.mlp.c_proj.bias': 'gpt2_decoder.layers.1.feedforward.c_proj.bias', | |||
| 'h.1.mlp.c_proj.weight': 'gpt2_decoder.layers.1.feedforward.c_proj.weight', | |||
| 'h.2.attn.c_attn.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.2.attn.c_attn.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.2.attn.c_proj.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.2.attn.c_proj.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.2.ln_1.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.2.ln_1.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.2.ln_2.bias': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta', | |||
| 'h.2.ln_2.weight': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.2.mlp.c_fc.bias': 'gpt2_decoder.layers.2.feedforward.c_fc.bias', | |||
| 'h.2.mlp.c_fc.weight': 'gpt2_decoder.layers.2.feedforward.c_fc.weight', | |||
| 'h.2.mlp.c_proj.bias': 'gpt2_decoder.layers.2.feedforward.c_proj.bias', | |||
| 'h.2.mlp.c_proj.weight': 'gpt2_decoder.layers.2.feedforward.c_proj.weight', | |||
| 'h.3.attn.c_attn.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.3.attn.c_attn.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.3.attn.c_proj.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.3.attn.c_proj.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.3.ln_1.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.3.ln_1.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.3.ln_2.bias': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta', | |||
| 'h.3.ln_2.weight': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.3.mlp.c_fc.bias': 'gpt2_decoder.layers.3.feedforward.c_fc.bias', | |||
| 'h.3.mlp.c_fc.weight': 'gpt2_decoder.layers.3.feedforward.c_fc.weight', | |||
| 'h.3.mlp.c_proj.bias': 'gpt2_decoder.layers.3.feedforward.c_proj.bias', | |||
| 'h.3.mlp.c_proj.weight': 'gpt2_decoder.layers.3.feedforward.c_proj.weight', | |||
| 'h.4.attn.c_attn.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.4.attn.c_attn.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.4.attn.c_proj.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.4.attn.c_proj.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.4.ln_1.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.4.ln_1.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.4.ln_2.bias': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta', | |||
| 'h.4.ln_2.weight': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.4.mlp.c_fc.bias': 'gpt2_decoder.layers.4.feedforward.c_fc.bias', | |||
| 'h.4.mlp.c_fc.weight': 'gpt2_decoder.layers.4.feedforward.c_fc.weight', | |||
| 'h.4.mlp.c_proj.bias': 'gpt2_decoder.layers.4.feedforward.c_proj.bias', | |||
| 'h.4.mlp.c_proj.weight': 'gpt2_decoder.layers.4.feedforward.c_proj.weight', | |||
| 'h.5.attn.c_attn.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.5.attn.c_attn.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.5.attn.c_proj.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.5.attn.c_proj.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.5.ln_1.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.5.ln_1.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.5.ln_2.bias': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta', | |||
| 'h.5.ln_2.weight': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.5.mlp.c_fc.bias': 'gpt2_decoder.layers.5.feedforward.c_fc.bias', | |||
| 'h.5.mlp.c_fc.weight': 'gpt2_decoder.layers.5.feedforward.c_fc.weight', | |||
| 'h.5.mlp.c_proj.bias': 'gpt2_decoder.layers.5.feedforward.c_proj.bias', | |||
| 'h.5.mlp.c_proj.weight': 'gpt2_decoder.layers.5.feedforward.c_proj.weight', | |||
| 'h.6.attn.c_attn.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.6.attn.c_attn.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.6.attn.c_proj.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.6.attn.c_proj.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.6.ln_1.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.6.ln_1.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.6.ln_2.bias': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta', | |||
| 'h.6.ln_2.weight': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.6.mlp.c_fc.bias': 'gpt2_decoder.layers.6.feedforward.c_fc.bias', | |||
| 'h.6.mlp.c_fc.weight': 'gpt2_decoder.layers.6.feedforward.c_fc.weight', | |||
| 'h.6.mlp.c_proj.bias': 'gpt2_decoder.layers.6.feedforward.c_proj.bias', | |||
| 'h.6.mlp.c_proj.weight': 'gpt2_decoder.layers.6.feedforward.c_proj.weight', | |||
| 'h.7.attn.c_attn.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.7.attn.c_attn.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.7.attn.c_proj.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.7.attn.c_proj.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.7.ln_1.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.7.ln_1.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.7.ln_2.bias': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta', | |||
| 'h.7.ln_2.weight': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.7.mlp.c_fc.bias': 'gpt2_decoder.layers.7.feedforward.c_fc.bias', | |||
| 'h.7.mlp.c_fc.weight': 'gpt2_decoder.layers.7.feedforward.c_fc.weight', | |||
| 'h.7.mlp.c_proj.bias': 'gpt2_decoder.layers.7.feedforward.c_proj.bias', | |||
| 'h.7.mlp.c_proj.weight': 'gpt2_decoder.layers.7.feedforward.c_proj.weight', | |||
| 'h.8.attn.c_attn.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.8.attn.c_attn.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.8.attn.c_proj.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.8.attn.c_proj.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.8.ln_1.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.8.ln_1.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.8.ln_2.bias': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta', | |||
| 'h.8.ln_2.weight': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.8.mlp.c_fc.bias': 'gpt2_decoder.layers.8.feedforward.c_fc.bias', | |||
| 'h.8.mlp.c_fc.weight': 'gpt2_decoder.layers.8.feedforward.c_fc.weight', | |||
| 'h.8.mlp.c_proj.bias': 'gpt2_decoder.layers.8.feedforward.c_proj.bias', | |||
| 'h.8.mlp.c_proj.weight': 'gpt2_decoder.layers.8.feedforward.c_proj.weight', | |||
| 'h.9.attn.c_attn.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.9.attn.c_attn.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.9.attn.c_proj.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.9.attn.c_proj.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.9.ln_1.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.9.ln_1.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.9.ln_2.bias': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta', | |||
| 'h.9.ln_2.weight': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.9.mlp.c_fc.bias': 'gpt2_decoder.layers.9.feedforward.c_fc.bias', | |||
| 'h.9.mlp.c_fc.weight': 'gpt2_decoder.layers.9.feedforward.c_fc.weight', | |||
| 'h.9.mlp.c_proj.bias': 'gpt2_decoder.layers.9.feedforward.c_proj.bias', | |||
| 'h.9.mlp.c_proj.weight': 'gpt2_decoder.layers.9.feedforward.c_proj.weight', | |||
| 'h.10.attn.c_attn.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.10.attn.c_attn.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.10.attn.c_proj.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.10.attn.c_proj.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.10.ln_1.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.10.ln_1.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.10.ln_2.bias': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta', | |||
| 'h.10.ln_2.weight': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.10.mlp.c_fc.bias': 'gpt2_decoder.layers.10.feedforward.c_fc.bias', | |||
| 'h.10.mlp.c_fc.weight': 'gpt2_decoder.layers.10.feedforward.c_fc.weight', | |||
| 'h.10.mlp.c_proj.bias': 'gpt2_decoder.layers.10.feedforward.c_proj.bias', | |||
| 'h.10.mlp.c_proj.weight': 'gpt2_decoder.layers.10.feedforward.c_proj.weight', | |||
| 'h.11.attn.c_attn.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.11.attn.c_attn.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.11.attn.c_proj.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.11.attn.c_proj.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.11.ln_1.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.11.ln_1.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.11.ln_2.bias': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta', | |||
| 'h.11.ln_2.weight': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.11.mlp.c_fc.bias': 'gpt2_decoder.layers.11.feedforward.c_fc.bias', | |||
| 'h.11.mlp.c_fc.weight': 'gpt2_decoder.layers.11.feedforward.c_fc.weight', | |||
| 'h.11.mlp.c_proj.bias': 'gpt2_decoder.layers.11.feedforward.c_proj.bias', | |||
| 'h.11.mlp.c_proj.weight': 'gpt2_decoder.layers.11.feedforward.c_proj.weight', | |||
| 'h.12.attn.c_attn.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.12.attn.c_attn.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.12.attn.c_proj.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.12.attn.c_proj.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.12.ln_1.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.12.ln_1.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.12.ln_2.bias': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta', | |||
| 'h.12.ln_2.weight': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.12.mlp.c_fc.bias': 'gpt2_decoder.layers.12.feedforward.c_fc.bias', | |||
| 'h.12.mlp.c_fc.weight': 'gpt2_decoder.layers.12.feedforward.c_fc.weight', | |||
| 'h.12.mlp.c_proj.bias': 'gpt2_decoder.layers.12.feedforward.c_proj.bias', | |||
| 'h.12.mlp.c_proj.weight': 'gpt2_decoder.layers.12.feedforward.c_proj.weight', | |||
| 'h.13.attn.c_attn.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.13.attn.c_attn.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.13.attn.c_proj.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.13.attn.c_proj.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.13.ln_1.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.13.ln_1.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.13.ln_2.bias': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta', | |||
| 'h.13.ln_2.weight': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.13.mlp.c_fc.bias': 'gpt2_decoder.layers.13.feedforward.c_fc.bias', | |||
| 'h.13.mlp.c_fc.weight': 'gpt2_decoder.layers.13.feedforward.c_fc.weight', | |||
| 'h.13.mlp.c_proj.bias': 'gpt2_decoder.layers.13.feedforward.c_proj.bias', | |||
| 'h.13.mlp.c_proj.weight': 'gpt2_decoder.layers.13.feedforward.c_proj.weight', | |||
| 'h.14.attn.c_attn.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.14.attn.c_attn.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.14.attn.c_proj.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.14.attn.c_proj.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.14.ln_1.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.14.ln_1.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.14.ln_2.bias': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta', | |||
| 'h.14.ln_2.weight': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.14.mlp.c_fc.bias': 'gpt2_decoder.layers.14.feedforward.c_fc.bias', | |||
| 'h.14.mlp.c_fc.weight': 'gpt2_decoder.layers.14.feedforward.c_fc.weight', | |||
| 'h.14.mlp.c_proj.bias': 'gpt2_decoder.layers.14.feedforward.c_proj.bias', | |||
| 'h.14.mlp.c_proj.weight': 'gpt2_decoder.layers.14.feedforward.c_proj.weight', | |||
| 'h.15.attn.c_attn.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.15.attn.c_attn.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.15.attn.c_proj.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.15.attn.c_proj.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.15.ln_1.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.15.ln_1.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.15.ln_2.bias': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta', | |||
| 'h.15.ln_2.weight': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.15.mlp.c_fc.bias': 'gpt2_decoder.layers.15.feedforward.c_fc.bias', | |||
| 'h.15.mlp.c_fc.weight': 'gpt2_decoder.layers.15.feedforward.c_fc.weight', | |||
| 'h.15.mlp.c_proj.bias': 'gpt2_decoder.layers.15.feedforward.c_proj.bias', | |||
| 'h.15.mlp.c_proj.weight': 'gpt2_decoder.layers.15.feedforward.c_proj.weight', | |||
| 'h.16.attn.c_attn.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.16.attn.c_attn.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.16.attn.c_proj.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.16.attn.c_proj.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.16.ln_1.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.16.ln_1.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.16.ln_2.bias': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta', | |||
| 'h.16.ln_2.weight': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.16.mlp.c_fc.bias': 'gpt2_decoder.layers.16.feedforward.c_fc.bias', | |||
| 'h.16.mlp.c_fc.weight': 'gpt2_decoder.layers.16.feedforward.c_fc.weight', | |||
| 'h.16.mlp.c_proj.bias': 'gpt2_decoder.layers.16.feedforward.c_proj.bias', | |||
| 'h.16.mlp.c_proj.weight': 'gpt2_decoder.layers.16.feedforward.c_proj.weight', | |||
| 'h.17.attn.c_attn.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.17.attn.c_attn.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.17.attn.c_proj.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.17.attn.c_proj.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.17.ln_1.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.17.ln_1.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.17.ln_2.bias': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta', | |||
| 'h.17.ln_2.weight': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.17.mlp.c_fc.bias': 'gpt2_decoder.layers.17.feedforward.c_fc.bias', | |||
| 'h.17.mlp.c_fc.weight': 'gpt2_decoder.layers.17.feedforward.c_fc.weight', | |||
| 'h.17.mlp.c_proj.bias': 'gpt2_decoder.layers.17.feedforward.c_proj.bias', | |||
| 'h.17.mlp.c_proj.weight': 'gpt2_decoder.layers.17.feedforward.c_proj.weight', | |||
| 'h.18.attn.c_attn.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.18.attn.c_attn.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.18.attn.c_proj.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.18.attn.c_proj.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.18.ln_1.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.18.ln_1.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.18.ln_2.bias': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta', | |||
| 'h.18.ln_2.weight': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.18.mlp.c_fc.bias': 'gpt2_decoder.layers.18.feedforward.c_fc.bias', | |||
| 'h.18.mlp.c_fc.weight': 'gpt2_decoder.layers.18.feedforward.c_fc.weight', | |||
| 'h.18.mlp.c_proj.bias': 'gpt2_decoder.layers.18.feedforward.c_proj.bias', | |||
| 'h.18.mlp.c_proj.weight': 'gpt2_decoder.layers.18.feedforward.c_proj.weight', | |||
| 'h.19.attn.c_attn.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.19.attn.c_attn.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.19.attn.c_proj.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.19.attn.c_proj.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.19.ln_1.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.19.ln_1.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.19.ln_2.bias': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta', | |||
| 'h.19.ln_2.weight': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.19.mlp.c_fc.bias': 'gpt2_decoder.layers.19.feedforward.c_fc.bias', | |||
| 'h.19.mlp.c_fc.weight': 'gpt2_decoder.layers.19.feedforward.c_fc.weight', | |||
| 'h.19.mlp.c_proj.bias': 'gpt2_decoder.layers.19.feedforward.c_proj.bias', | |||
| 'h.19.mlp.c_proj.weight': 'gpt2_decoder.layers.19.feedforward.c_proj.weight', | |||
| 'h.20.attn.c_attn.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.20.attn.c_attn.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.20.attn.c_proj.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.20.attn.c_proj.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.20.ln_1.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.20.ln_1.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.20.ln_2.bias': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta', | |||
| 'h.20.ln_2.weight': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.20.mlp.c_fc.bias': 'gpt2_decoder.layers.20.feedforward.c_fc.bias', | |||
| 'h.20.mlp.c_fc.weight': 'gpt2_decoder.layers.20.feedforward.c_fc.weight', | |||
| 'h.20.mlp.c_proj.bias': 'gpt2_decoder.layers.20.feedforward.c_proj.bias', | |||
| 'h.20.mlp.c_proj.weight': 'gpt2_decoder.layers.20.feedforward.c_proj.weight', | |||
| 'h.21.attn.c_attn.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.21.attn.c_attn.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.21.attn.c_proj.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.21.attn.c_proj.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.21.ln_1.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.21.ln_1.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.21.ln_2.bias': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta', | |||
| 'h.21.ln_2.weight': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.21.mlp.c_fc.bias': 'gpt2_decoder.layers.21.feedforward.c_fc.bias', | |||
| 'h.21.mlp.c_fc.weight': 'gpt2_decoder.layers.21.feedforward.c_fc.weight', | |||
| 'h.21.mlp.c_proj.bias': 'gpt2_decoder.layers.21.feedforward.c_proj.bias', | |||
| 'h.21.mlp.c_proj.weight': 'gpt2_decoder.layers.21.feedforward.c_proj.weight', | |||
| 'h.22.attn.c_attn.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.22.attn.c_attn.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.22.attn.c_proj.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.22.attn.c_proj.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.22.ln_1.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.22.ln_1.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.22.ln_2.bias': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta', | |||
| 'h.22.ln_2.weight': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.22.mlp.c_fc.bias': 'gpt2_decoder.layers.22.feedforward.c_fc.bias', | |||
| 'h.22.mlp.c_fc.weight': 'gpt2_decoder.layers.22.feedforward.c_fc.weight', | |||
| 'h.22.mlp.c_proj.bias': 'gpt2_decoder.layers.22.feedforward.c_proj.bias', | |||
| 'h.22.mlp.c_proj.weight': 'gpt2_decoder.layers.22.feedforward.c_proj.weight', | |||
| 'h.23.attn.c_attn.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.23.attn.c_attn.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.23.attn.c_proj.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.23.attn.c_proj.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.23.ln_1.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.23.ln_1.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.23.ln_2.bias': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta', | |||
| 'h.23.ln_2.weight': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.23.mlp.c_fc.bias': 'gpt2_decoder.layers.23.feedforward.c_fc.bias', | |||
| 'h.23.mlp.c_fc.weight': 'gpt2_decoder.layers.23.feedforward.c_fc.weight', | |||
| 'h.23.mlp.c_proj.bias': 'gpt2_decoder.layers.23.feedforward.c_proj.bias', | |||
| 'h.23.mlp.c_proj.weight': 'gpt2_decoder.layers.23.feedforward.c_proj.weight', | |||
| 'h.24.attn.c_attn.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.24.attn.c_attn.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.24.attn.c_proj.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.24.attn.c_proj.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.24.ln_1.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.24.ln_1.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.24.ln_2.bias': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta', | |||
| 'h.24.ln_2.weight': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.24.mlp.c_fc.bias': 'gpt2_decoder.layers.24.feedforward.c_fc.bias', | |||
| 'h.24.mlp.c_fc.weight': 'gpt2_decoder.layers.24.feedforward.c_fc.weight', | |||
| 'h.24.mlp.c_proj.bias': 'gpt2_decoder.layers.24.feedforward.c_proj.bias', | |||
| 'h.24.mlp.c_proj.weight': 'gpt2_decoder.layers.24.feedforward.c_proj.weight', | |||
| 'h.25.attn.c_attn.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.25.attn.c_attn.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.25.attn.c_proj.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.25.attn.c_proj.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.25.ln_1.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.25.ln_1.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.25.ln_2.bias': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta', | |||
| 'h.25.ln_2.weight': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.25.mlp.c_fc.bias': 'gpt2_decoder.layers.25.feedforward.c_fc.bias', | |||
| 'h.25.mlp.c_fc.weight': 'gpt2_decoder.layers.25.feedforward.c_fc.weight', | |||
| 'h.25.mlp.c_proj.bias': 'gpt2_decoder.layers.25.feedforward.c_proj.bias', | |||
| 'h.25.mlp.c_proj.weight': 'gpt2_decoder.layers.25.feedforward.c_proj.weight', | |||
| 'h.26.attn.c_attn.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.26.attn.c_attn.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.26.attn.c_proj.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.26.attn.c_proj.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.26.ln_1.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.26.ln_1.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.26.ln_2.bias': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta', | |||
| 'h.26.ln_2.weight': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.26.mlp.c_fc.bias': 'gpt2_decoder.layers.26.feedforward.c_fc.bias', | |||
| 'h.26.mlp.c_fc.weight': 'gpt2_decoder.layers.26.feedforward.c_fc.weight', | |||
| 'h.26.mlp.c_proj.bias': 'gpt2_decoder.layers.26.feedforward.c_proj.bias', | |||
| 'h.26.mlp.c_proj.weight': 'gpt2_decoder.layers.26.feedforward.c_proj.weight', | |||
| 'h.27.attn.c_attn.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.27.attn.c_attn.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.27.attn.c_proj.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.27.attn.c_proj.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.27.ln_1.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.27.ln_1.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.27.ln_2.bias': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta', | |||
| 'h.27.ln_2.weight': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.27.mlp.c_fc.bias': 'gpt2_decoder.layers.27.feedforward.c_fc.bias', | |||
| 'h.27.mlp.c_fc.weight': 'gpt2_decoder.layers.27.feedforward.c_fc.weight', | |||
| 'h.27.mlp.c_proj.bias': 'gpt2_decoder.layers.27.feedforward.c_proj.bias', | |||
| 'h.27.mlp.c_proj.weight': 'gpt2_decoder.layers.27.feedforward.c_proj.weight', | |||
| 'h.28.attn.c_attn.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.28.attn.c_attn.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.28.attn.c_proj.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.28.attn.c_proj.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.28.ln_1.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.28.ln_1.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.28.ln_2.bias': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta', | |||
| 'h.28.ln_2.weight': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.28.mlp.c_fc.bias': 'gpt2_decoder.layers.28.feedforward.c_fc.bias', | |||
| 'h.28.mlp.c_fc.weight': 'gpt2_decoder.layers.28.feedforward.c_fc.weight', | |||
| 'h.28.mlp.c_proj.bias': 'gpt2_decoder.layers.28.feedforward.c_proj.bias', | |||
| 'h.28.mlp.c_proj.weight': 'gpt2_decoder.layers.28.feedforward.c_proj.weight', | |||
| 'h.29.attn.c_attn.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.29.attn.c_attn.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.29.attn.c_proj.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.29.attn.c_proj.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.29.ln_1.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.29.ln_1.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.29.ln_2.bias': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta', | |||
| 'h.29.ln_2.weight': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.29.mlp.c_fc.bias': 'gpt2_decoder.layers.29.feedforward.c_fc.bias', | |||
| 'h.29.mlp.c_fc.weight': 'gpt2_decoder.layers.29.feedforward.c_fc.weight', | |||
| 'h.29.mlp.c_proj.bias': 'gpt2_decoder.layers.29.feedforward.c_proj.bias', | |||
| 'h.29.mlp.c_proj.weight': 'gpt2_decoder.layers.29.feedforward.c_proj.weight', | |||
| 'h.30.attn.c_attn.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.30.attn.c_attn.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.30.attn.c_proj.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.30.attn.c_proj.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.30.ln_1.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.30.ln_1.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.30.ln_2.bias': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta', | |||
| 'h.30.ln_2.weight': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.30.mlp.c_fc.bias': 'gpt2_decoder.layers.30.feedforward.c_fc.bias', | |||
| 'h.30.mlp.c_fc.weight': 'gpt2_decoder.layers.30.feedforward.c_fc.weight', | |||
| 'h.30.mlp.c_proj.bias': 'gpt2_decoder.layers.30.feedforward.c_proj.bias', | |||
| 'h.30.mlp.c_proj.weight': 'gpt2_decoder.layers.30.feedforward.c_proj.weight', | |||
| 'h.31.attn.c_attn.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.31.attn.c_attn.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.31.attn.c_proj.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.31.attn.c_proj.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.31.ln_1.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.31.ln_1.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.31.ln_2.bias': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta', | |||
| 'h.31.ln_2.weight': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.31.mlp.c_fc.bias': 'gpt2_decoder.layers.31.feedforward.c_fc.bias', | |||
| 'h.31.mlp.c_fc.weight': 'gpt2_decoder.layers.31.feedforward.c_fc.weight', | |||
| 'h.31.mlp.c_proj.bias': 'gpt2_decoder.layers.31.feedforward.c_proj.bias', | |||
| 'h.31.mlp.c_proj.weight': 'gpt2_decoder.layers.31.feedforward.c_proj.weight', | |||
| 'h.32.attn.c_attn.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.32.attn.c_attn.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.32.attn.c_proj.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.32.attn.c_proj.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.32.ln_1.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.32.ln_1.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.32.ln_2.bias': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta', | |||
| 'h.32.ln_2.weight': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.32.mlp.c_fc.bias': 'gpt2_decoder.layers.32.feedforward.c_fc.bias', | |||
| 'h.32.mlp.c_fc.weight': 'gpt2_decoder.layers.32.feedforward.c_fc.weight', | |||
| 'h.32.mlp.c_proj.bias': 'gpt2_decoder.layers.32.feedforward.c_proj.bias', | |||
| 'h.32.mlp.c_proj.weight': 'gpt2_decoder.layers.32.feedforward.c_proj.weight', | |||
| 'h.33.attn.c_attn.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.33.attn.c_attn.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.33.attn.c_proj.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.33.attn.c_proj.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.33.ln_1.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.33.ln_1.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.33.ln_2.bias': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta', | |||
| 'h.33.ln_2.weight': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.33.mlp.c_fc.bias': 'gpt2_decoder.layers.33.feedforward.c_fc.bias', | |||
| 'h.33.mlp.c_fc.weight': 'gpt2_decoder.layers.33.feedforward.c_fc.weight', | |||
| 'h.33.mlp.c_proj.bias': 'gpt2_decoder.layers.33.feedforward.c_proj.bias', | |||
| 'h.33.mlp.c_proj.weight': 'gpt2_decoder.layers.33.feedforward.c_proj.weight', | |||
| 'h.34.attn.c_attn.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.34.attn.c_attn.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.34.attn.c_proj.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.34.attn.c_proj.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.34.ln_1.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.34.ln_1.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.34.ln_2.bias': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta', | |||
| 'h.34.ln_2.weight': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.34.mlp.c_fc.bias': 'gpt2_decoder.layers.34.feedforward.c_fc.bias', | |||
| 'h.34.mlp.c_fc.weight': 'gpt2_decoder.layers.34.feedforward.c_fc.weight', | |||
| 'h.34.mlp.c_proj.bias': 'gpt2_decoder.layers.34.feedforward.c_proj.bias', | |||
| 'h.34.mlp.c_proj.weight': 'gpt2_decoder.layers.34.feedforward.c_proj.weight', | |||
| 'h.35.attn.c_attn.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||
| 'h.35.attn.c_attn.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||
| 'h.35.attn.c_proj.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||
| 'h.35.attn.c_proj.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||
| 'h.35.ln_1.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||
| 'h.35.ln_1.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||
| 'h.35.ln_2.bias': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta', | |||
| 'h.35.ln_2.weight': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma', | |||
| 'h.35.mlp.c_fc.bias': 'gpt2_decoder.layers.35.feedforward.c_fc.bias', | |||
| 'h.35.mlp.c_fc.weight': 'gpt2_decoder.layers.35.feedforward.c_fc.weight', | |||
| 'h.35.mlp.c_proj.bias': 'gpt2_decoder.layers.35.feedforward.c_proj.bias', | |||
| 'h.35.mlp.c_proj.weight': 'gpt2_decoder.layers.35.feedforward.c_proj.weight', | |||
| 'ln_f.bias': 'layer_norm.layer_norm.beta', | |||
| 'ln_f.weight': 'layer_norm.layer_norm.gamma', | |||
| 'wpe.weight': 'gpt2_embedding_postprocess.position_embedding_table', | |||
| 'wte.weight': 'gpt2_embedding_lookup.embedding_table' | |||
| } | |||
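| # A minimal usage sketch (hypothetical helper, not part of this repo): the mapping | |||
| # above renames HuggingFace GPT-2 parameter names to the MindSpore names used here. | |||
| # A converter could apply it to a loaded PyTorch state dict roughly like this: | |||
| # | |||
| #     def rename_parameters(torch_state_dict, name_map): | |||
| #         return {name_map[k]: v.numpy() for k, v in torch_state_dict.items() if k in name_map} | |||
| # | |||
| # Note that HuggingFace stores Conv1D weights (c_attn, c_fc, c_proj) transposed | |||
| # relative to a standard dense layer, so a real converter may also need to | |||
| # transpose some weight arrays before building MindSpore parameters. | |||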
| @@ -0,0 +1,148 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """create mindrecord data for Children's Book Test task""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import argparse | |||
| import collections | |||
| import logging | |||
| import numpy as np | |||
| from mindspore.mindrecord import FileWriter | |||
| from src.utils.tokenization import Tokenizer | |||
| def create_instance(tokenizer, text, max_length=None, num_choice=None): | |||
| """A single sample instance for cbt task.""" | |||
| text = text.replace(" \t ", "\t ") | |||
| sentence = text.strip().split("\t") | |||
| context_length = len(tokenizer.encode(sentence[0])) | |||
| whole_sentence = sentence[0] + sentence[1] | |||
| whole_sentence = whole_sentence.strip() | |||
| assert whole_sentence != "" | |||
| print(" | whole sentence: ", whole_sentence) | |||
| ids = tokenizer.encode(whole_sentence) | |||
| input_length = len(ids) | |||
| pair_ids = None | |||
| output = tokenizer.prepare_for_model(ids=ids, | |||
| pair_ids=pair_ids, | |||
| add_special_tokens=True, | |||
| max_length=max_length, | |||
| padding=True, | |||
| truncate_direction="RIGHT", | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True) | |||
| output["length"] = [context_length + 1] + [input_length + 1] | |||
| gold_answer_id = int(sentence[2]) | |||
| assert gold_answer_id < 10 | |||
| output["mc_labels"] = gold_answer_id | |||
| for name, value in output.items(): | |||
| print(name) | |||
| print(value) | |||
| print("==================================") | |||
| return output | |||
| def write_instance_to_file(writer, instance): | |||
| """write the instance to file""" | |||
| input_ids = instance["input_ids"] | |||
| input_mask = instance["attention_mask"] | |||
| assert len(input_ids) == len(input_mask) | |||
| length = instance["length"] # list | |||
| mc_labels = instance["mc_labels"] | |||
| features = collections.OrderedDict() | |||
| features["input_ids"] = np.asarray(input_ids) | |||
| features["input_mask"] = np.asarray(input_mask) | |||
| features["input_length"] = np.asarray(length) | |||
| features["mc_labels"] = mc_labels | |||
| writer.write_raw_data([features]) | |||
| return features | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--input_file", type=str, required=True, default="", help='Input raw text file. ') | |||
| parser.add_argument("--output_file", type=str, required=True, default="", help='Output MindRecord file. ') | |||
| parser.add_argument("--num_splits", type=int, default=1, | |||
| help='The MindRecord file will be split into the number of partition. ') | |||
| parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ') | |||
| parser.add_argument("--num_choice", type=int, required=True, help='Number of choices. ') | |||
| parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ') | |||
| parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ') | |||
| args = parser.parse_args() | |||
| tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file) | |||
| num_choice = args.num_choice | |||
| input_file = args.input_file | |||
| logging.info("***** Reading from input files *****") | |||
| logging.info("Input File: %s", input_file) | |||
| output_file = args.output_file | |||
| logging.info("***** Writing to output files *****") | |||
| logging.info("Output File: %s", output_file) | |||
| writer = FileWriter(output_file, args.num_splits) | |||
| data_schema = {"input_ids": {"type": "int64", "shape": [-1]}, | |||
| "input_mask": {"type": "int64", "shape": [-1]}, | |||
| "input_length": {"type": "int64", "shape": [-1]}, | |||
| "mc_labels": {"type": "int64"} | |||
| } | |||
| writer.add_schema(data_schema, "cbt-schema") | |||
| total_written = 0 | |||
| total_read = 0 | |||
| logging.info("***** Reading from %s *****", input_file) | |||
| with open(input_file, "r") as f: | |||
| while True: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| total_read += 1 | |||
| if total_read % 500 == 0: | |||
| logging.info("%d ...", total_read) | |||
| output = create_instance(tokenizer, line, args.max_seq_length, num_choice) | |||
| features = write_instance_to_file(writer, instance=output) | |||
| total_written += 1 | |||
| if total_written <= 20: | |||
| logging.info("***** Example *****") | |||
| logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1])) | |||
| logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:])) | |||
| for feature_name in features.keys(): | |||
| feature = features[feature_name] | |||
| logging.info("%s: %s", feature_name, feature) | |||
| writer.commit() | |||
| logging.info("Wrote %d total instances", total_written) | |||
| if __name__ == "__main__": | |||
| main() | |||
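| # Example invocation (script and file names are placeholders): | |||
| # python create_cbt_data.py --input_file=cbt_test.txt --output_file=cbt_test.mindrecord \ | |||
| #     --max_seq_length=1024 --num_choice=10 --vocab_file=gpt2-vocab.json --merge_file=gpt2-merges.txt | |||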
| @@ -0,0 +1,140 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """create mindrecord data for LAMBADA task""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import argparse | |||
| import collections | |||
| import logging | |||
| import numpy as np | |||
| from mindspore.mindrecord import FileWriter | |||
| from src.utils.tokenization import Tokenizer | |||
| def create_instance(tokenizer, text, max_length=None): | |||
| """A single sample instance for LAMBADA task.""" | |||
| text = text.replace(" \t ", "\t ") | |||
| sentence = text.strip().split("\t") | |||
| context_length = len(tokenizer.encode(sentence[0])) | |||
| whole_sentence = sentence[0] + sentence[1] | |||
| whole_sentence = whole_sentence.strip() | |||
| assert whole_sentence != "" | |||
| print(" | whole sentence: ", whole_sentence) | |||
| ids = tokenizer.encode(whole_sentence) | |||
| input_length = len(ids) | |||
| pair_ids = None | |||
| output = tokenizer.prepare_for_model(ids=ids, | |||
| pair_ids=pair_ids, | |||
| add_special_tokens=True, | |||
| max_length=max_length, | |||
| padding=True, | |||
| truncate_direction="RIGHT", | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True) | |||
| # input_length = <bos> + text_length, does not include <eos> | |||
| output["length"] = [context_length + 1] + [input_length + 1] | |||
| for k, v in output.items(): | |||
| print(k) | |||
| print(v) | |||
| print("==================================") | |||
| return output | |||
| def write_instance_to_file(writer, instance): | |||
| """write the instance to file""" | |||
| input_ids = instance["input_ids"] | |||
| input_mask = instance["attention_mask"] | |||
| assert len(input_ids) == len(input_mask) | |||
| length = instance["length"] # list | |||
| features = collections.OrderedDict() | |||
| features["input_ids"] = np.asarray(input_ids) | |||
| features["input_mask"] = np.asarray(input_mask) | |||
| features["input_length"] = np.asarray(length) | |||
| writer.write_raw_data([features]) | |||
| return features | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ') | |||
| parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ') | |||
| parser.add_argument("--num_splits", type=int, default=1, | |||
| help='The MindRecord file will be split into the number of partition. ') | |||
| parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ') | |||
| parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ') | |||
| parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ') | |||
| args = parser.parse_args() | |||
| tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file) | |||
| input_file = args.input_file | |||
| logging.info("***** Reading from input files *****") | |||
| logging.info("Input File: %s", input_file) | |||
| output_file = args.output_file | |||
| logging.info("***** Writing to output files *****") | |||
| logging.info("Output File: %s", output_file) | |||
| writer = FileWriter(output_file, args.num_splits) | |||
| data_schema = {"input_ids": {"type": "int64", "shape": [-1]}, | |||
| "input_mask": {"type": "int64", "shape": [-1]}, | |||
| "input_length": {"type": "int64", "shape": [-1]}, | |||
| } | |||
| writer.add_schema(data_schema, "lambada-schema") | |||
| total_written = 0 | |||
| total_read = 0 | |||
| logging.info("***** Reading from %s *****", input_file) | |||
| with open(input_file, "r") as f: | |||
| while True: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| total_read += 1 | |||
| if total_read % 500 == 0: | |||
| logging.info("%d ...", total_read) | |||
| output = create_instance(tokenizer, line, args.max_seq_length) | |||
| features = write_instance_to_file(writer, instance=output) | |||
| total_written += 1 | |||
| if total_written <= 20: | |||
| logging.info("***** Example *****") | |||
| logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1])) | |||
| logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:])) | |||
| for feature_name in features.keys(): | |||
| feature = features[feature_name] | |||
| logging.info("%s: %s", feature_name, feature) | |||
| writer.commit() | |||
| logging.info("Wrote %d total instances", total_written) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,126 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """create mindrecord data for LM task""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import argparse | |||
| import collections | |||
| import logging | |||
| import numpy as np | |||
| from mindspore.mindrecord import FileWriter | |||
| from src.utils.tokenization import Tokenizer | |||
| def create_instance(tokenizer, text, max_length=None): | |||
| """A single sample instance for LM task.""" | |||
| sentence = text.strip().split("\t") | |||
| ids = tokenizer.encode(sentence[0]) | |||
| pair_ids = None | |||
| if len(sentence) == 2: | |||
| pair_ids = tokenizer.encode(sentence[1]) | |||
| output = tokenizer.prepare_for_model(ids=ids, | |||
| pair_ids=pair_ids, | |||
| add_special_tokens=True, | |||
| max_length=max_length, | |||
| padding=True, | |||
| truncate_direction="LEFT", | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True) | |||
| return output | |||
| def write_instance_to_file(writer, instance): | |||
| """write the instance to file""" | |||
| input_ids = instance["input_ids"] | |||
| input_mask = instance["attention_mask"] | |||
| label_ids = instance["input_ids"] | |||
| assert len(input_ids) == len(label_ids) | |||
| features = collections.OrderedDict() | |||
| features["input_ids"] = np.asarray(input_ids) | |||
| features["input_mask"] = np.asarray(input_mask) | |||
| features["label_ids"] = np.asarray(label_ids) | |||
| writer.write_raw_data([features]) | |||
| return features | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ') | |||
| parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ') | |||
| parser.add_argument("--num_splits", type=int, default=1, | |||
| help='The MindRecord file will be split into the number of partition. ') | |||
| parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ') | |||
| parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ') | |||
| parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ') | |||
| args = parser.parse_args() | |||
| tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file) | |||
| input_file = args.input_file | |||
| logging.info("***** Reading from input files *****") | |||
| logging.info("Input File: %s", input_file) | |||
| output_file = args.output_file | |||
| logging.info("***** Writing to output files *****") | |||
| logging.info("Output File: %s", output_file) | |||
| writer = FileWriter(output_file, args.num_splits) | |||
| data_schema = {"input_ids": {"type": "int64", "shape": [-1]}, | |||
| "input_mask": {"type": "int64", "shape": [-1]}, | |||
| "label_ids": {"type": "int64", "shape": [-1]} | |||
| } | |||
| writer.add_schema(data_schema, "lm-schema") | |||
| total_written = 0 | |||
| total_read = 0 | |||
| logging.info("***** Reading from %s *****", input_file) | |||
| with open(input_file, "r") as f: | |||
| while True: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| total_read += 1 | |||
| if total_read % 500 == 0: | |||
| logging.info("%d ...", total_read) | |||
| output = create_instance(tokenizer, line, args.max_seq_length) | |||
| features = write_instance_to_file(writer, instance=output) | |||
| total_written += 1 | |||
| if total_written <= 20: | |||
| logging.info("***** Example *****") | |||
| logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1])) | |||
| logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:])) | |||
| for feature_name in features.keys(): | |||
| feature = features[feature_name] | |||
| logging.info("%s: %s", feature_name, feature) | |||
| writer.commit() | |||
| logging.info("Wrote %d total instances", total_written) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """create mindrecord data for Summarization task""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import argparse | |||
| import collections | |||
| import logging | |||
| import numpy as np | |||
| from mindspore.mindrecord import FileWriter | |||
| from src.utils import tokenization | |||
| def create_instance(tokenizer, text, max_length=None): | |||
| """A single sample instance for Summarization task.""" | |||
| sentence = text.strip().split("\t") | |||
| ids = tokenizer.encode(sentence[0]) | |||
| pair_ids = None | |||
| if len(sentence) == 2: | |||
| pair_ids = tokenizer.encode(sentence[1]) | |||
| if len(sentence) >= 3: | |||
| article = sentence[0] | |||
| for i in range(1, len(sentence) - 1): | |||
| article += sentence[i] | |||
| ids = tokenizer.encode(article) | |||
| pair_ids = tokenizer.encode(sentence[-1]) | |||
| output = tokenizer.prepare_for_model(ids=ids, | |||
| pair_ids=pair_ids, | |||
| add_special_tokens=True, | |||
| max_length=max_length, | |||
| padding=True, | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True) | |||
| return output | |||
| def write_instance_to_file(writer, instance): | |||
| """write the instance to file""" | |||
| input_ids = instance["input_ids"] | |||
| input_mask = instance["attention_mask"] | |||
| label_ids = instance["input_ids"] | |||
| assert len(input_ids) == len(label_ids) | |||
| features = collections.OrderedDict() | |||
| features["input_ids"] = np.asarray(input_ids) | |||
| features["input_mask"] = np.asarray(input_mask) | |||
| features["label_ids"] = np.asarray(label_ids) | |||
| writer.write_raw_data([features]) | |||
| return features | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.') | |||
| parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.') | |||
| parser.add_argument("--num_splits", type=int, default=1, | |||
| help='The MindRecord file will be split into the number of partition. ') | |||
| parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.') | |||
| parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ') | |||
| parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ') | |||
| parser.add_argument("--mode", type=str, required=True, default='cnn_dailymail', help='mode of dataset creation') | |||
| args = parser.parse_args() | |||
| tokenizer = tokenization.Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file, mode=args.mode) | |||
| input_file = args.input_file | |||
| logging.info("***** Reading from input files *****") | |||
| logging.info("Input File: %s", input_file) | |||
| output_file = args.output_file | |||
| logging.info("***** Writing to output files *****") | |||
| logging.info("Output File: %s", output_file) | |||
| writer = FileWriter(output_file, args.num_splits) | |||
| data_schema = {"input_ids": {"type": "int64", "shape": [-1]}, | |||
| "input_mask": {"type": "int64", "shape": [-1]}, | |||
| "label_ids": {"type": "int64", "shape": [-1]} | |||
| } | |||
| writer.add_schema(data_schema, "wikitext2-schema") | |||
| total_written = 0 | |||
| total_read = 0 | |||
| logging.info("***** Reading from %s *****", input_file) | |||
| with open(input_file, "r") as f: | |||
| while True: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| total_read += 1 | |||
| if total_read % 500 == 0: | |||
| logging.info("%d ...", total_read) | |||
| output = create_instance(tokenizer, line, args.max_seq_length) | |||
| features = write_instance_to_file(writer, instance=output) | |||
| total_written += 1 | |||
| if total_written <= 20: | |||
| logging.info("***** Example *****") | |||
| logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1])) | |||
| logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:])) | |||
| for feature_name in features.keys(): | |||
| feature = features[feature_name] | |||
| logging.info("%s: %s", feature_name, feature) | |||
| writer.commit() | |||
| logging.info("Wrote %d total instances", total_written) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,59 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """download the CNN & DailyMail for Summarization task""" | |||
| import argparse | |||
| from datasets import load_dataset | |||
| def generate_txt(url, split_, number=None, version="3.0.0"): | |||
| """ | |||
| generate txt file of cnn_dailymail dataset | |||
| Args: | |||
| url (str): directory of dataset txt file. | |||
| split_ (str): test or train. | |||
| number (int): number of leading samples to export; -1 means the whole split. | |||
| version (str): "3.0.0" by default | |||
| """ | |||
| cnn = load_dataset("cnn_dailymail", version, split=split_) | |||
| if number == -1: | |||
| number = len(cnn) | |||
| f = open(url + split_ + '.txt', 'w') | |||
| for idx in range(number): | |||
| article = cnn[idx]['article'] | |||
| article = article.replace('\n', ' ') | |||
| highlights = cnn[idx]['highlights'] | |||
| highlights = highlights.replace('\n', ' ') | |||
| f.write(article + "\t" + highlights + '\n') | |||
| f.close() | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='Download CNN/DailyMail 3.0.0 using the HuggingFace datasets library') | |||
| parser.add_argument('--dir', type=str, default="", help="directory of dataset") | |||
| parser.add_argument('--split', type=str, default='test', help="[test,train]") | |||
| parser.add_argument('--num', type=int, default=-1, | |||
| help="Number of samples to export, in dataset order. " | |||
| "If num is -1, the whole split is exported. Default: -1") | |||
| args = parser.parse_args() | |||
| data_directory = args.dir | |||
| split = args.split | |||
| num = args.num | |||
| generate_txt(url=data_directory, split_=split, number=num) | |||
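| # Example invocation (script name and paths are placeholders): | |||
| # python download_cnn_dailymail.py --dir=./data/ --split=test --num=-1 | |||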
| @@ -0,0 +1,135 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Evaluation reading comprehension result with additional answer.""" | |||
| import json | |||
| import re | |||
| import string | |||
| import argparse | |||
| from collections import Counter | |||
| def get_normalize_answer_token(string_): | |||
| """normalize the answer token, Lower text and remove punctuation, article and extra whitespace""" | |||
| def remove_articles(text): | |||
| regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) | |||
| return re.sub(regex, ' ', text) | |||
| def white_space_fix(text): | |||
| return ' '.join(text.split()) | |||
| def remove_punc(text): | |||
| exclude = set(string.punctuation) | |||
| return ''.join(char for char in text if char not in exclude) | |||
| def lower(text): | |||
| return text.lower() | |||
| return white_space_fix(remove_articles(remove_punc(lower(string_)))).split() | |||
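| # e.g. get_normalize_answer_token("The Cat's hat!") returns ['cats', 'hat'] | |||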
| def calculate_f1(pred_answer, gold_answer): | |||
| """ | |||
| calculate final F1 score with addition answer | |||
| """ | |||
| f1_score = 0 | |||
| pred_answer = get_normalize_answer_token(pred_answer) | |||
| gold_answer = get_normalize_answer_token(gold_answer) | |||
| common = Counter(pred_answer) & Counter(gold_answer) | |||
| num_same = sum(common.values()) | |||
| # the number of same tokens between pred_answer and gold_answer | |||
| precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0 | |||
| recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0 | |||
| if not pred_answer and not gold_answer: | |||
| f1_score = 1 | |||
| else: | |||
| f1_score = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0 | |||
| return f1_score | |||
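| # Worked example: pred ['cats', 'hat', 'red'] vs. gold ['hat', 'red', 'blue'] share | |||
| # two tokens, so precision = recall = 2/3 and F1 = 2*(2/3)*(2/3)/(4/3) = 2/3. | |||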
| def main(): | |||
| parser = argparse.ArgumentParser(description="All Task dataset preprocessing") | |||
| parser.add_argument("--input_file", type=str, default="", | |||
| help="The log file path of evaluation in Reading Comprehension. ") | |||
| parser.add_argument("--addition_file", type=str, default="", help="Coqa-dev-v1.0.json path") | |||
| args_opt = parser.parse_args() | |||
| input_file = args_opt.input_file | |||
| addition_file = args_opt.addition_file | |||
| find_word = 'Pred_answer:' | |||
| find_word_length = len(find_word) | |||
| pred_answer_list = [] | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| while True: | |||
| line = f.readline() | |||
| if not line: | |||
| break | |||
| index = line.find(find_word) | |||
| if index != -1: | |||
| pred_answer = line[index + find_word_length:].strip() | |||
| pred_answer_list.append(pred_answer) | |||
| dataset = json.load(open(addition_file)) | |||
| pred_answer_num = 0 | |||
| total_f1score = 0 | |||
| average_f1score = 0 | |||
| data_num = len(pred_answer_list) | |||
| for story in dataset['data']: | |||
| questions = story['questions'] | |||
| multiple_answers = [story['answers']] | |||
| multiple_answers += story['additional_answers'].values() | |||
| for question in questions: | |||
| pred_a = pred_answer_list[pred_answer_num] | |||
| turn_id = question['turn_id'] | |||
| max_score = 0 | |||
| max_group = 0 | |||
| flag = 0 | |||
| for i, answer in enumerate(multiple_answers): | |||
| gold_a = answer[turn_id - 1]['input_text'] | |||
| score = calculate_f1(pred_a, gold_a) | |||
| if score > max_score: | |||
| max_score = score | |||
| max_group = i | |||
| # keep the max score over the multiple reference answers and record its group index | |||
| gold_a = multiple_answers[max_group][turn_id - 1]['input_text'] | |||
| pred_answer_num += 1 | |||
| total_f1score += max_score | |||
| average_f1score = total_f1score / pred_answer_num | |||
| print('==================== data {} ===================='.format(pred_answer_num)) | |||
| print('| Gold_answer:{}'.format(gold_a)) | |||
| print('| Pred_answer:{}'.format(pred_a)) | |||
| print('| F1_Score:{:.8f}'.format(average_f1score)) | |||
| print('=====================================================\n') | |||
| if pred_answer_num >= data_num: | |||
| flag = 1 | |||
| break | |||
| # Stop flag | |||
| if flag: | |||
| print('Finished evaluation with additional answers! \n') | |||
| print("********************** Testing Finished **********************") | |||
| print('| Test file name: {}'.format(input_file)) | |||
| print('| Final F1 score: {:.8f}'.format(average_f1score)) | |||
| print('| Total data num: {}'.format(pred_answer_num)) | |||
| print("**************************************************************") | |||
| break | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,270 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for Children's Book Test task. | |||
| """ | |||
| import argparse | |||
| import time | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CBT | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.utils.metric_method import Accuracy | |||
| from src.dataset import create_cbt_dataset, create_language_model_dataset | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.utils.task_utils import calculate_choice_prob_for_cbt | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||
| epoch_num: the number of epochs. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # checkpoint saving configuration | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_" + "cbt_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" + str(epoch_num) +\ | |||
| "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
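| # the LM head has no pretrained weight of its own: tie it to the token embedding table | |||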
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load pretrained parameter successfully!\n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("==================== Starting Finetuning ====================") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("==================== Finetuning Success ====================") | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, num_choice=None): | |||
| """ | |||
| Do evaluation for CBT task. | |||
| Args: | |||
| dataset: the eval dataset. | |||
| network: the network with loss. | |||
| metric: the evaluation method. | |||
| load_checkpoint_path: the file path which saved finetuned model checkpoint. | |||
| eval_type: the evaluation type, one of [zero-shot, finetuned]. | |||
| num_choice: the number of candidate choices for each CBT question. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| if metric.lower() == "accuracy": | |||
| print("Prepare to calculate the accuracy score ...") | |||
| gpt2_cbt = network(config=gpt2_net_cfg, | |||
| is_training=False, | |||
| use_one_hot_embeddings=False | |||
| ) | |||
| gpt2_cbt.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| if eval_type == "zero-shot": | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(gpt2_cbt, final_param_dict) | |||
| print("load pretrained parameter successfully!\n") | |||
| elif eval_type == "finetuned": | |||
| load_param_into_net(gpt2_cbt, param_dict) | |||
| print("load finetuned parameter successfully!\n") | |||
| else: | |||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||
| model = Model(gpt2_cbt) | |||
| callback = Accuracy() | |||
| columns_list = ["input_ids", "input_mask", "input_length", "mc_labels"] | |||
| print("==================== [ACC] Testing ====================") | |||
| num_data = 1 | |||
| all_choice_prob = [] | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, input_mask, input_length, mc_labels = input_data | |||
| print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||
| # print("mc_labels: {}".format(mc_labels)) # [batch_size] | |||
| logits = model.predict(input_ids, input_mask) | |||
| # choice_prob_list [batch_size] | |||
| choice_prob_list = calculate_choice_prob_for_cbt(logits=logits, | |||
| batch_size=gpt2_net_cfg.batch_size, | |||
| input_length=input_length, | |||
| input_ids=input_ids) | |||
| all_choice_prob.append(choice_prob_list) | |||
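            # Each CBT question is expanded into `num_choice` consecutive samples, one
            # per candidate word; once a full group of choice probabilities is collected,
            # reshape to (-1, num_choice) and score accuracy against the choice label.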
| if (num_data * gpt2_net_cfg.batch_size) % num_choice == 0: | |||
| all_choice_prob_np = np.array(all_choice_prob) | |||
| all_choice_prob_np = all_choice_prob_np.reshape((-1, num_choice)) | |||
| print("| all_choice_prob_np: ", all_choice_prob_np) | |||
| print("| all_choice_prob_np shape: ", all_choice_prob_np.shape) | |||
| mc_labels = np.array([mc_labels.asnumpy()[0]]) | |||
| callback.update(all_choice_prob_np, mc_labels) | |||
| all_choice_prob = [] | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**************************************************************") | |||
| print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | |||
| callback.acc_num / callback.total_num)) | |||
| print("********************** Testing Finished **********************") | |||
| else: | |||
| raise ValueError("metric method not supported, support: [Accuracy]") | |||
| def run_cbt_task(): | |||
| """ | |||
| run Children's Book Test (CBT) task | |||
| """ | |||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate CBT task") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=1, | |||
| help="ID of target device. ") | |||
| parser.add_argument("--num_choice", type=int, default=10, | |||
| help="The number of choice in CBT task. ") | |||
| parser.add_argument("--metric_method", type=str, default="Accuracy", | |||
| help="The eval method including [Accuracy]. Default: Accuracy.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: true.") | |||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--epoch_num", type=int, default=1, | |||
| help="Epoch number. Default: 1.") | |||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the finetuned checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path for train.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path for evaluation.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device_target = args_opt.device_target | |||
| if device_target == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target=device_target, | |||
| device_id=args_opt.device_id, | |||
| max_call_depth=3000) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| gpt2_loss = GPT2CBT(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| if args_opt.do_train.lower() == "true": | |||
| print("============== Start Loading Train Dataset ============") | |||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| print("============== Start Loading Evaluation Dataset ============") | |||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||
| eval_dataset = create_cbt_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.eval_data_file_path) | |||
| do_eval(eval_dataset, GPT2CBT, metric, load_finetune_ckpt_path, args_opt.eval_type, args_opt.num_choice) | |||
| if __name__ == "__main__": | |||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| run_cbt_task() | |||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for Reading Comprehension task. | |||
| """ | |||
| import argparse | |||
| import time | |||
| from mindspore import context | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CoQA | |||
| from src.GPT2ForReadComprehension import GPT2CoQAModel | |||
| from src.utils.metric_method import F1 | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.dataset import create_language_model_dataset | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.utils.tokenization import Tokenizer | |||
| from src.GPT2_generation import GenerateForReadComprehension | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||
| epoch_num: the number of epoch. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # load checkpoint into network | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_rc_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load the pretrained parameter successfully! \n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("=================== Starting Training For Translation Task ====================") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("=================== Translation Training Success ====================") | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="", | |||
| generate_length=1, top_k=1, top_p=1.0, temperature=1.0): | |||
| """ | |||
    Do evaluation for Reading Comprehension task.
    Args:
        dataset: the eval dataset.
        network: the network with loss.
        metric: the evaluation method.
        load_checkpoint_path: the file path which saved finetuned model checkpoint.
        eval_type: the evaluation type, one of [zero-shot, finetuned].
        tokenizer_file_path: the directory holding the vocab and merge files (must end with '/').
        generate_length: the number of tokens generated for each answer.
        top_k, top_p, temperature: sampling hyperparameters for answer generation.
    """
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| if metric.lower() == "f1": | |||
| print("Prepare to calculate the BLEU score ...") | |||
| gpt2_rc = network(config=gpt2_net_cfg, | |||
| is_training=False, | |||
| use_one_hot_embeddings=False) | |||
| gpt2_rc.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| if eval_type == "zero-shot": | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.' + name] = param_dict[name] | |||
| final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(gpt2_rc, final_param_dict) | |||
| print("load pretrained parameter successfully!\n") | |||
| elif eval_type == "finetuned": | |||
| load_param_into_net(gpt2_rc, param_dict) | |||
| print("load finetuned parameter successfully!\n") | |||
| else: | |||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||
| model = Model(gpt2_rc) | |||
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||
| callback = F1() | |||
| rc_generator = GenerateForReadComprehension(decoder=model, | |||
| config=gpt2_net_cfg, | |||
| tokenizer=tokenizer, | |||
| generate_length=generate_length, | |||
| topk_num=top_k, | |||
| topp_prob=float(top_p), | |||
| temperature=float(temperature) | |||
| ) | |||
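        # Note: with the parser defaults (top_k=1, top_p=1.0, temperature=1.0) the
        # sampler degenerates to greedy decoding, i.e. the single most probable
        # token is selected at every generation step.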
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||
| print("==================== [F1] Testing ====================") | |||
| num_data = 0 | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, _, label_ids = input_data | |||
| print("input_ids shape: {}".format(input_ids.shape)) | |||
| print("label_ids shape: {}".format(label_ids.shape)) | |||
| passage, pred_answer, gold_answer = rc_generator.generate_for_read_comprehension(input_ids) | |||
| for batch_id in range(gpt2_net_cfg.batch_size): | |||
| print("============== [F1] {} ================".format(num_data + 1)) | |||
| print(" | Passage:{}".format(passage[batch_id])) | |||
| print(" | Gold_answer:{}".format(gold_answer[batch_id])) | |||
| print(" | Pred_answer:{}".format(pred_answer[batch_id])) | |||
| pred = callback.get_normalize_answer_token(pred_answer[batch_id]) | |||
| gold = callback.get_normalize_answer_token(gold_answer[batch_id]) | |||
| callback.update(pred, gold) | |||
| num_data += 1 | |||
| average_f1_score = callback.f1_score / num_data | |||
| print("============== Evaluation =================") | |||
| print("| Avg F1 Score:{:.8f}".format(average_f1_score)) | |||
| print("=============================================\n\n") | |||
| print("********************** Testing Finished **********************") | |||
| else: | |||
| raise ValueError("metric method not supported in Reading Comprehension task, support: [F1]") | |||
def run_read_comprehension():
    """
    run Reading Comprehension task
    """
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Reading Comprehension task")
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=0, | |||
| help="ID of target device. ") | |||
| parser.add_argument("--metric_method", type=str, default="F1", | |||
| help="The eval method including [F1]. Default: F1.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: false.") | |||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--epoch_num", type=int, default=1, | |||
| help="Epoch number. Default: 1.") | |||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||
| help="pretrained vocab and merge file path.") | |||
| parser.add_argument("--generate_length", type=int, default=55, | |||
| help="The generation length of translation sentence.") | |||
| parser.add_argument("--top_k", type=int, default=1, | |||
| help="Parameter for Top-K sampling.") | |||
| parser.add_argument("--top_p", type=str, default="1.0", | |||
| help="parameter for Top-P sampling.") | |||
| parser.add_argument("--temperature", type=str, default="1.0", | |||
| help="Parameter for generation, greater if generation more diverse. ") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device_target = args_opt.device_target | |||
| if device_target == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target=device_target, | |||
| device_id=args_opt.device_id, | |||
| max_call_depth=3000) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| gpt2_loss = GPT2CoQA(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| if args_opt.do_train.lower() == "true": | |||
| print("============== Start Loading Translation Train Dataset ==============") | |||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| print("============ Start Loading Translation Evaluation Dataset ============") | |||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||
| eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.eval_data_file_path) | |||
| do_eval(eval_dataset, GPT2CoQAModel, metric, load_finetune_ckpt_path, args_opt.eval_type, | |||
| args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p, | |||
| args_opt.temperature) | |||
| if __name__ == "__main__": | |||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
    run_read_comprehension()
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for LAMBADA task. | |||
| """ | |||
| import argparse | |||
| import math | |||
| import time | |||
| from mindspore import context | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Lambada | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.utils.metric_method import LastWordAccuracy | |||
| from src.dataset import create_language_model_dataset, create_lambada_control_dataset | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.utils.task_utils import get_final_word_label | |||
| from src.utils.tokenization import Tokenizer | |||
| from src.GPT2_generation import GenerateForLambada | |||
| from src.utils.CrossEntropy import CrossEntropyCalculationWithMask | |||
| from src.utils.get_config_setting import get_train_setting, get_model_setting | |||
| from src.utils.task_utils import calculate_final_word_loss | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrain model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetune model checkpoint. | |||
| epoch_num: the number of epoch | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # load checkpoint into network | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_" + "lambada_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load pretrained parameter successfully!\n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("==================== Starting Finetuning ====================") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("==================== Finetuning Success ====================") | |||
| def eval_result_print(metric="accuracy", callback=None): | |||
| """ | |||
| Print eval result. | |||
| """ | |||
| if metric.lower() == "accuracy": | |||
| print("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | |||
| callback.acc_num / callback.total_num)) | |||
| else: | |||
| raise ValueError("metric method not supported, support: [accuracy]") | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, stop_word_file="", | |||
| generate_length_dynamic=True, tokenizer_file_path=""): | |||
| """ | |||
| Do eval | |||
| Args: | |||
| dataset: the eval dataset. | |||
| network: the network with loss. | |||
| metric: the evaluation method. | |||
| load_checkpoint_path: the file path which saved finetune model checkpoint. | |||
| eval_type: the eval type, i.e. zero-shot, finetuned. | |||
| generate_length_dynamic (bool): True for the generate length is dynamic, False for fixed. Default: True. | |||
| tokenizer_file_path: the tokenizer file path for vocab file and merge file. | |||
| stop_word_file: stop word file for calculating Accuracy. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||
| gpt2_lambada = network(config=gpt2_net_cfg, | |||
| is_training=False, | |||
| use_one_hot_embeddings=False) | |||
| gpt2_lambada.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| if eval_type == "zero-shot": | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(gpt2_lambada, final_param_dict) | |||
| print("load pretrained parameter successfully!\n") | |||
| elif eval_type == "finetuned": | |||
| load_param_into_net(gpt2_lambada, param_dict) | |||
| print("load finetuned parameter successfully!\n") | |||
| model = Model(gpt2_lambada) | |||
| if metric.lower() == "accuracy": | |||
| print("Prepare to calculate the accuracy score ...") | |||
| callback = LastWordAccuracy() | |||
| columns_list = ["input_ids", "input_mask", "input_length"] | |||
| print("==================== [ACC] Testing ====================") | |||
| lambada_generator = GenerateForLambada(decoder=model, | |||
| config=gpt2_net_cfg, | |||
| tokenizer=tokenizer, | |||
| generate_length_dynamic=generate_length_dynamic, | |||
| max_iterations=200, | |||
| stop_word_file=stop_word_file) | |||
| num_data = 1 | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, input_mask, input_length = input_data | |||
| print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||
| logits = model.predict(input_ids, input_mask) | |||
| predict_str = lambada_generator.generate_for_lambada(input_ids=input_ids, | |||
| logits=logits, | |||
| input_length=input_length) | |||
| label_str = get_final_word_label(input_ids=input_ids, input_length=input_length, tokenizer=tokenizer) | |||
| callback.update(predict_str, label_str) | |||
| eval_result_print(metric, callback) | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**********************************************************") | |||
| eval_result_print(metric, callback) | |||
| print("******************** Testing Finished ********************") | |||
| elif metric.lower() == "ppl": | |||
| print("Prepare to calculate the ppl score ...") | |||
| cross_entropy = CrossEntropyCalculationWithMask(is_training=True, | |||
| num_labels=gpt2_net_cfg.vocab_size, | |||
| config=gpt2_net_cfg) | |||
| columns_list = ["input_ids", "input_mask", "input_length"] | |||
| num_data = 1 | |||
| total_loss = 0.0 | |||
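        # Perplexity here is the exponential of the average last-word cross-entropy:
        # PPL = exp((1 / N) * sum_i loss_i), accumulated batch by batch below.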
| print("==================== [PPL] Testing ====================") | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, input_mask, input_length = input_data | |||
| print("| [PPL] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||
| logits = model.predict(input_ids, input_mask) # (batch_size, seq_len, vocab_size) | |||
| avg_batch_loss = calculate_final_word_loss(logits, | |||
| gpt2_net_cfg.batch_size, | |||
| input_ids, | |||
| input_length, | |||
| cross_entropy) | |||
| total_loss += avg_batch_loss | |||
| avg_total_loss = total_loss / num_data | |||
| print(" | Current AVG loss:", avg_total_loss) | |||
| print(" | Current AVG ppl:", math.exp(avg_total_loss)) | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**********************************************************") | |||
| print("Average PPL: {:.6f}".format(math.exp(avg_total_loss))) | |||
| print("******************** Testing Finished ********************") | |||
| else: | |||
| raise ValueError("metric method not supported, support: [accuracy, ppl]") | |||
| def run_lambada(): | |||
| """ | |||
    Run LAMBADA task.
| """ | |||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate languagemodel") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=2, | |||
| help="ID of target device.") | |||
| parser.add_argument("--metric_method", type=str, default="PPL", | |||
| help="The eval method including [Accuracy, PPL]. Default: Accuracy.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: false.") | |||
| parser.add_argument("--eval_type", type=str, default="finetuned", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--epoch_num", type=int, default=3, | |||
| help="Epoch number. Default: 1.") | |||
| parser.add_argument("--train_data_shuffle", type=str, default="false", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--generate_length_dynamically", type=str, default="true", | |||
| help="Enable generate_length_Dynamically. Default: true.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path.") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path.") | |||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||
| help="pretrained vocab and merge file path.") | |||
| parser.add_argument("--stop_word_file_path", type=str, default="", | |||
| help="The stop word file path.") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device = args_opt.device_target | |||
| if device == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| gpt2_loss = GPT2Lambada(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| if args_opt.do_train.lower() == "true": | |||
| get_train_setting(cfg) | |||
| get_model_setting(cfg, gpt2_net_cfg) | |||
| print("============== Start Loading Train Dataset ============") | |||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| get_model_setting(cfg, gpt2_net_cfg) | |||
| print("============== Start Loading Evaluation Dataset ============") | |||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||
| eval_dataset = create_lambada_control_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.eval_data_file_path) | |||
        do_eval(eval_dataset, GPT2Lambada, metric, load_finetune_ckpt_path, args_opt.eval_type,
                args_opt.stop_word_file_path, (args_opt.generate_length_dynamically.lower() == "true"),
                args_opt.tokenizer_file_path)
| if __name__ == "__main__": | |||
| print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| run_lambada() | |||
| print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for Language Modeling task. | |||
| """ | |||
| import argparse | |||
| import math | |||
| import time | |||
| from mindspore import context | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2LM | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.dataset import create_language_model_dataset | |||
| from src.utils.get_config_setting import get_train_setting, get_model_setting | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||
| epoch_num: the number of epoch. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # load checkpoint into network | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_language_model_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load pretrained parameter successfully!\n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("==================== Starting Finetuning ====================") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("==================== Finetuning Success ====================") | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None): | |||
| """ | |||
| Do eval | |||
| Args: | |||
| dataset: the eval dataset. | |||
| network: the network with loss. | |||
| metric: the evaluation method. | |||
| load_checkpoint_path: the file path which saved finetuned model checkpoint. | |||
        eval_type: the evaluation type, one of [zero-shot, finetuned].
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| if metric.lower() == "ppl": | |||
| print("Prepare to calculate the ppl score ...") | |||
| gpt2_loss = network(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| gpt2_loss.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| if eval_type == "zero-shot": | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(gpt2_loss, final_param_dict) | |||
| print("load pretrained parameter successfully!\n") | |||
| elif eval_type == "finetuned": | |||
| load_param_into_net(gpt2_loss, param_dict) | |||
| print("load finetuned parameter successfully!\n") | |||
| else: | |||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||
| model = Model(gpt2_loss) | |||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||
| print("==================== [PPL] Testing ====================") | |||
| num_data = 1 | |||
| total_loss = 0.0 | |||
| avg_loss = 0.0 | |||
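        # The script reports exp of the running mean per-batch loss as perplexity:
        # PPL = exp(total_loss / num_batches).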
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, input_mask, label_ids = input_data | |||
| loss = model.predict(input_ids, input_mask, label_ids) | |||
| loss = float(loss.asnumpy()) | |||
| total_loss += loss | |||
| avg_loss = float(total_loss / num_data) | |||
| print(" | Current Loss: {:.6f}".format(avg_loss)) | |||
| print(" | Current PPL: {}\n\n".format(math.exp(avg_loss))) | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**************************************************************") | |||
| print("Average Loss: {:.6f}".format(avg_loss)) | |||
| print("Average PPL: {:.6f}".format(math.exp(avg_loss))) | |||
| print("********************** Testing Finished **********************") | |||
| else: | |||
| raise ValueError("metric method not supported, support: [ppl]") | |||
| def run_languagemodel(): | |||
| """ | |||
| run Language Modeling task | |||
| """ | |||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate language modelings task") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=1, | |||
| help="ID of target device. ") | |||
| parser.add_argument("--metric_method", type=str, default="PPL", | |||
| help="The eval method including [PPL]. Default: PPL.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: true.") | |||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--epoch_num", type=int, default=1, | |||
| help="Epoch number. Default: 1.") | |||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the finetuned checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path for train.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path for evaluation.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device_target = args_opt.device_target | |||
| if device_target == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target=device_target, | |||
| device_id=args_opt.device_id, | |||
| max_call_depth=3000) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| gpt2_loss = GPT2LM(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| if args_opt.do_train.lower() == "true": | |||
| get_train_setting(cfg) | |||
| get_model_setting(cfg, gpt2_net_cfg) | |||
| print("==================== Start Loading Train Dataset ==================") | |||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| get_model_setting(cfg, gpt2_net_cfg) | |||
| print("==================== Start Loading Evaluation Dataset ==================") | |||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=args_opt.eval_data_file_path)
| do_eval(eval_dataset, GPT2LM, metric, load_finetune_ckpt_path, args_opt.eval_type) | |||
| if __name__ == "__main__": | |||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| run_languagemodel() | |||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| # -*- coding: utf-8 -*- | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for Summarization task. | |||
| """ | |||
| import time | |||
| import argparse | |||
| from mindspore import context | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.GPT2ForSummarization import GPT2SummarizationModel | |||
| from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.utils.metric_method import Rouge | |||
| from src.dataset import create_language_model_dataset | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.utils.tokenization import Tokenizer | |||
| from src.utils.task_utils import clean_hypo, modify_paramdict | |||
| from src.GPT2_generation import GenerateForSummarization | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrain model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetune model checkpoint. | |||
| epoch_num: the number of epoch | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list( | |||
| filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # load checkpoint into network | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load pretrained parameter successfully!\n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("============== Starting Finetuning ==============") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("============== Finetuning Success ==============") | |||
| def eval_result_print(metric="Rouge", callback=None): | |||
| """ | |||
| print eval result | |||
| """ | |||
| if metric == "Rouge": | |||
| print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG{:.8f}". | |||
| format(callback.Rouge1 / callback.total_num, | |||
| callback.Rouge2 / callback.total_num, | |||
| callback.RougeL / callback.total_num, | |||
| (callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num))) | |||
| else: | |||
| raise ValueError("metric method '{}' not supported, support: [Rouge]. ".format(str(metric))) | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="", | |||
| top_k=None, top_p=None, temperature=None, generate_length=None): | |||
| """ | |||
| Do evaluation on summarization | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| if metric.lower() == "rouge": | |||
| print("Prepare to calculate the Rouge score ...") | |||
| callback = Rouge() | |||
| gpt2_loss = network(config=gpt2_net_cfg, | |||
| is_training=False, | |||
| use_one_hot_embeddings=False) | |||
| gpt2_loss.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.") | |||
| load_param_into_net(gpt2_loss, reorganized_param_dict) | |||
| # load nn.Cell into Model and initiate tokenizer and Sample | |||
| model = Model(gpt2_loss) | |||
| tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json', | |||
| merge_file=tokenizer_file + 'gpt2-merges.txt') | |||
| # load data and process text generation | |||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||
| summarization_generator = GenerateForSummarization(model, | |||
| config=gpt2_net_cfg, | |||
| tokenizer=tokenizer, | |||
| select_sentence=3, | |||
| eval_type=eval_type, | |||
| topk=top_k, | |||
| topp=float(top_p), | |||
| temperature=float(temperature), | |||
| generate_length=generate_length) | |||
| num_data = 1 | |||
| print("==================== [Summrization] Testing ====================") | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for value in columns_list: | |||
| input_data.append(data[value]) | |||
| input_ids, _, label_ids = input_data | |||
| print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||
| print("input_ids shape: {}".format(input_ids.shape)) | |||
| print("label_ids shape: {}".format(label_ids.shape)) | |||
| hypothesis, ref = summarization_generator.generate_for_summarization(input_ids) | |||
| if ref[0] == '' or ref[0] is None: | |||
| print("Sorry ref_list is None, skip it!") | |||
| continue | |||
| print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n") | |||
| for batch_idx in range(gpt2_net_cfg.batch_size): | |||
| hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx]).lower() | |||
| ref[batch_idx] = ref[batch_idx].lower() | |||
| callback.update(hypothesis, ref) | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**********************************************************") | |||
| eval_result_print(metric, callback) | |||
| print("******************** Testing Finished ********************") | |||
| else: | |||
| raise ValueError("metric method not supported in summarization, support: [Rouge]") | |||
| def run_summarization(): | |||
| """ | |||
| Run Summarization task. | |||
| """ | |||
| # set argument parser | |||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate Summrization") | |||
| # context and task settings | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=4, | |||
| help="ID of target device.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: false.") | |||
| parser.add_argument("--eval_type", type=str, default="finetuned", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--metric_method", type=str, default="Rouge", | |||
| help="The eval method including [Rouge(Rouge1,Rouge2,RougeL,Rouge Avg)]. Default: Rouge.") | |||
| parser.add_argument("--epoch_num", type=int, default=2, | |||
| help="Epoch number. Default: 2.") | |||
| # dataset and params_dict file settings | |||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| # sampling settings | |||
| parser.add_argument("--top_k", type=int, default=2, | |||
| help="top k tokens chosen for sampling") | |||
| parser.add_argument("--top_p", type=str, default="1.0", | |||
| help="top p accumulated probability threshold for logit to be counted") | |||
| parser.add_argument("--generate_length", type=int, default=100, | |||
| help="the number of generated tokens.") | |||
| parser.add_argument("--temperature", type=str, default="1.0", | |||
| help="temperature on logits for sampling") | |||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||
| help="vocab & merge file path") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| eval_type = args_opt.eval_type | |||
| tokenizer_file = args_opt.tokenizer_file_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device = args_opt.device_target | |||
| if device == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| if args_opt.do_train.lower() == "true": | |||
| train_data_file_path = args_opt.train_data_file_path | |||
| gpt2_loss = GPT2Summarization(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| print("============== Start Loading Train Dataset ============") | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| eval_dataset_file_path = args_opt.eval_data_file_path | |||
| print("============== Start Loading Evaluation Dataset ============") | |||
| eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||
| dataset_path=eval_dataset_file_path) | |||
| do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file, | |||
| args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length) | |||
| if __name__ == "__main__": | |||
| print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| run_summarization() | |||
| print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| @@ -0,0 +1,298 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 finetune and evaluation script for Translation task. | |||
| """ | |||
| import argparse | |||
| import time | |||
| from mindspore import context | |||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.GPT2ForTranslation import GPT2TranslationModel | |||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Translation | |||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||
| from src.dataset import create_language_model_dataset | |||
| from src.utils.lr_schedule import GPT2LearningRate | |||
| from src.utils.tokenization import Tokenizer | |||
| from src.utils.metric_method import BLEU | |||
| from src.GPT2_generation import GenerateForTranslation | |||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||
| """ | |||
| Do train | |||
| Args: | |||
| dataset: the train dataset. | |||
| network: the network with loss | |||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||
| epoch_num: the number of epoch. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||
| steps_per_epoch = dataset.get_dataset_size() | |||
| # optimizer | |||
| if cfg.optimizer == 'AdamWeightDecay': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.AdamWeightDecay.power) | |||
| params = network.trainable_params() | |||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||
| {'params': other_params, 'weight_decay': 0.0}] | |||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||
| elif cfg.optimizer == 'Lamb': | |||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||
| decay_steps=steps_per_epoch * epoch_num, | |||
| power=cfg.Lamb.power) | |||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||
| elif cfg.optimizer == 'Momentum': | |||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||
| else: | |||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||
| # load checkpoint into network | |||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||
| prefix_name = "gpt2_translation_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||
| config=ckpt_config) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||
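| # initialize the translation head (dense1) from the token embedding table (weight tying) | |||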
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(network, final_param_dict) | |||
| print("Load the pretrained parameter successfully! \n") | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||
| netwithgrads.set_train(True) | |||
| loss_cb = LossMonitor(per_print_times=1) | |||
| model = Model(netwithgrads) | |||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||
| print("=================== Starting Training For Translation Task ====================") | |||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||
| print("=================== Translation Training Success ====================") | |||
| def eval_result_print(metric="BLEU", callback=None): | |||
| """ print eval result""" | |||
| if metric == "BLEU": | |||
| print(" | BLEU: {:.6f}".format(callback.bleu / float(callback.total_num))) | |||
| else: | |||
| raise ValueError("metric method '{}' not supported, support: [BLEU]. ".format(str(metric))) | |||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="", | |||
| generate_length=1, top_k=1, top_p=1.0, temperature=1.0): | |||
| """ | |||
| Do evaluation on Translation | |||
| Args: | |||
| dataset: the eval dataset. | |||
| network: the network with loss. | |||
| metric: the evaluation method. | |||
| load_checkpoint_path: the file path which saved finetune model checkpoint. | |||
| """ | |||
| if load_checkpoint_path == "": | |||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||
| if metric.lower() == "bleu": | |||
| print("Prepare to calculate the BLEU score ...") | |||
| gpt2_translation = network(config=gpt2_net_cfg, | |||
| is_training=False, | |||
| use_one_hot_embeddings=False) | |||
| gpt2_translation.set_train(False) | |||
| param_dict = load_checkpoint(load_checkpoint_path) | |||
| if eval_type == "zero-shot": | |||
| final_param_dict = {} | |||
| for name, _ in param_dict.items(): | |||
| final_param_dict['gpt2.' + name] = param_dict[name] | |||
| final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| load_param_into_net(gpt2_translation, final_param_dict) | |||
| print("load pretrained parameter successfully!\n") | |||
| elif eval_type == "finetuned": | |||
| load_param_into_net(gpt2_translation, param_dict) | |||
| print("load finetuned parameter successfully!\n") | |||
| else: | |||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||
| model = Model(gpt2_translation) | |||
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||
| callback = BLEU(tokenizer) | |||
| translation_generator = GenerateForTranslation(decoder=model, | |||
| config=gpt2_net_cfg, | |||
| tokenizer=tokenizer, | |||
| generate_length=generate_length, | |||
| use_hint=True, | |||
| select_first_sentence=True, | |||
| topk_num=top_k, | |||
| topp_prob=float(top_p), | |||
| temperature=float(temperature) | |||
| ) | |||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||
| print("==================== [BLEU] Testing ====================") | |||
| num_data = 1 | |||
| for data in dataset.create_dict_iterator(): | |||
| input_data = [] | |||
| for i in columns_list: | |||
| input_data.append(data[i]) | |||
| input_ids, input_mask, label_ids = input_data | |||
| print("| Data count: {}".format(num_data * gpt2_net_cfg.batch_size)) | |||
| print("input_ids shape: {}".format(input_ids.shape)) | |||
| print("input_mask shape: {}".format(input_mask.shape)) | |||
| print("label_ids shape: {}".format(label_ids.shape)) | |||
| ts_predict_list, ref_list = translation_generator.generate_for_translation(input_ids) | |||
| print("| Batch Reference translation:\n{}\n".format(ref_list)) | |||
| if not ref_list: | |||
| print("Reference list is empty, skipping this batch!") | |||
| continue | |||
| print(" | Batch Predict translation:\n{}\n".format(ts_predict_list)) | |||
| callback.update(ref_list, ts_predict_list) | |||
| num_data += 1 | |||
| print("\n\n") | |||
| print("**************************************************************") | |||
| eval_result_print(metric, callback) | |||
| print("********************** Testing Finished **********************") | |||
| else: | |||
| raise ValueError("metric method not supported in translation, support: [BLEU]") | |||
| def run_translation(): | |||
| """ | |||
| Run Translation task. | |||
| """ | |||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate translation") | |||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||
| help="Device type. Default: Ascend.") | |||
| parser.add_argument("--device_id", type=int, default=0, | |||
| help="ID of target device. ") | |||
| parser.add_argument("--metric_method", type=str, default="BLEU", | |||
| help="The eval method including [BLEU]. Default: BLEU.") | |||
| parser.add_argument("--do_train", type=str, default="false", | |||
| help="Enable train. Default: false.") | |||
| parser.add_argument("--do_eval", type=str, default="true", | |||
| help="Enable evaluation. Default: false.") | |||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||
| parser.add_argument("--epoch_num", type=int, default=1, | |||
| help="Epoch number. Default: 1.") | |||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||
| help="Enable train data shuffle. Default: true.") | |||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||
| help="Enable eval data shuffle. Default: false.") | |||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||
| help="Save the checkpoint path.") | |||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||
| help="Load the checkpoint file path.") | |||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||
| help="Data path, it is better to use absolute path") | |||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||
| help="pretrained vocab and merge file path.") | |||
| parser.add_argument("--generate_length", type=int, default=150, | |||
| help="The generation length of translation sentence.") | |||
| parser.add_argument("--top_k", type=int, default=1, | |||
| help="Parameter for Top-K sampling.") | |||
| parser.add_argument("--top_p", type=str, default="1.0", | |||
| help="parameter for Top-P sampling.") | |||
| parser.add_argument("--temperature", type=str, default="1.0", | |||
| help="Parameter for generation, greater if generation more diverse. ") | |||
| args_opt = parser.parse_args() | |||
| epoch_num = args_opt.epoch_num | |||
| metric = args_opt.metric_method | |||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||
| device_target = args_opt.device_target | |||
| if device_target == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target=device_target, | |||
| device_id=args_opt.device_id, | |||
| max_call_depth=3000) | |||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||
| else: | |||
| raise Exception("Device target error, Ascend is supported.") | |||
| gpt2_loss = GPT2Translation(config=gpt2_net_cfg, | |||
| is_training=True, | |||
| use_one_hot_embeddings=False) | |||
| if args_opt.do_train.lower() == "true": | |||
| print("============== Start Loading Translation Train Dataset ==============") | |||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.train_data_file_path) | |||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||
| if args_opt.do_eval.lower() == "true": | |||
| print("============ Start Loading Translation Evaluation Dataset ============") | |||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||
| eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||
| dataset_path=args_opt.eval_data_file_path) | |||
| do_eval(eval_dataset, GPT2TranslationModel, metric, load_finetune_ckpt_path, args_opt.eval_type, | |||
| args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p, | |||
| args_opt.temperature) | |||
| if __name__ == "__main__": | |||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| run_translation() | |||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||
| @@ -0,0 +1,60 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_cbt.sh" | |||
| echo "for example: bash scripts/run_cbt.sh" | |||
| echo "metric method: Accuracy" | |||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_cbt.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
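| # GLOG_logtostderr=0 writes MindSpore logs to files under GLOG_log_dir instead of stderr | |||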
| python ${PROJECT_DIR}/../run_CBT_task.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=4 \ | |||
| --num_choice=10 \ | |||
| --metric_method="Accuracy" \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --eval_type="zero-shot" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 & | |||
| @@ -0,0 +1,68 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_lambada.sh" | |||
| echo "for example: bash scripts/run_lambada.sh" | |||
| echo "method metric include: [Accuracy, PPL]" | |||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_lambada.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| # tokenizer path | |||
| tokenizer_file_path="" | |||
| # stopword path | |||
| stop_word_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
| python ${PROJECT_DIR}/../run_lambada.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=1 \ | |||
| --metric_method="PPL" \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --eval_type="zero-shot" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --generate_length_dynamically="true" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path \ | |||
| --tokenizer_file_path=$tokenizer_file_path \ | |||
| --stop_word_file_path=$stop_word_file_path >> $output_log 2>&1 & | |||
| @@ -0,0 +1,59 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_language_model.sh" | |||
| echo "for example: bash scripts/run_language_model.sh" | |||
| echo "metric method: PPL" | |||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_language_model.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
| python ${PROJECT_DIR}/../run_language_model.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=4 \ | |||
| --metric_method="PPL" \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --eval_type="zero-shot" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 & | |||
| @@ -0,0 +1,67 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_read_comprehension.sh" | |||
| echo "for example: bash scripts/run_read_comprehension.sh" | |||
| echo "metric method: F1" | |||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_read_comprehension.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| # tokenizer path | |||
| tokenizer_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
| python ${PROJECT_DIR}/../run_ReadComprehension.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=7 \ | |||
| --metric_method="F1" \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --eval_type="zero-shot" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path \ | |||
| --tokenizer_file_path=$tokenizer_file_path \ | |||
| --generate_length=55 \ | |||
| --top_k=1 \ | |||
| --top_p="1.0" \ | |||
| --temperature="1.0" >> $output_log 2>&1 & | |||
| @@ -0,0 +1,66 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_summarization.sh" | |||
| echo "for example: bash scripts/run_summarization.sh" | |||
| echo "eval_load_param_mode include: [zero-shot, finetuned]. Default: finetuned" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_summarization.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| # tokenizer path | |||
| tokenizer_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
| python ${PROJECT_DIR}/../run_summarization.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=0 \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --metric_method="Rouge" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --top_k=2 \ | |||
| --top_p="1.0" \ | |||
| --generate_length=100 \ | |||
| --temperature="1.0" \ | |||
| --eval_type="finetuned" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path \ | |||
| --tokenizer_file_path=$tokenizer_file_path >> $output_log 2>&1 & | |||
| @@ -0,0 +1,67 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
| echo "bash scripts/run_translation.sh" | |||
| echo "for example: bash scripts/run_translation.sh" | |||
| echo "metric method: BLEU" | |||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||
| echo "==============================================================================================================" | |||
| CUR_DIR=`pwd` | |||
| mkdir -p ms_log | |||
| output_log="${CUR_DIR}/ms_log/gpt2_translation.log" | |||
| # create file and head line | |||
| echo " | Eval log file: " > $output_log | |||
| echo $output_log >> $output_log | |||
| # checkpoint path | |||
| save_finetune_ckpt_path="" | |||
| load_pretrain_ckpt_path="" | |||
| load_eval_ckpt_path="" | |||
| # dataset path | |||
| train_data_file_path="" | |||
| eval_data_file_path="" | |||
| # tokenizer path | |||
| tokenizer_file_path="" | |||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||
| export GLOG_logtostderr=0 | |||
| python ${PROJECT_DIR}/../run_translation.py \ | |||
| --device_target="Ascend" \ | |||
| --device_id=4 \ | |||
| --metric_method="BLEU" \ | |||
| --do_train="false" \ | |||
| --do_eval="true" \ | |||
| --eval_type="zero-shot" \ | |||
| --epoch_num=1 \ | |||
| --train_data_shuffle="true" \ | |||
| --eval_data_shuffle="false" \ | |||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||
| --train_data_file_path=$train_data_file_path \ | |||
| --eval_data_file_path=$eval_data_file_path \ | |||
| --tokenizer_file_path=$tokenizer_file_path \ | |||
| --generate_length=100 \ | |||
| --top_k=1 \ | |||
| --top_p="1.0" \ | |||
| --temperature="1.0" >> $output_log 2>&1 & | |||
| @@ -0,0 +1,84 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (CBT) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2CBTModel(nn.Cell): | |||
| """ | |||
| GPT2CBTModel is responsible for Children's Book Test (CBT) task, i.e. CBT-CN, CBT-NE datasets. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| """ | |||
| Args: | |||
| config: the configuration of GPT-2 model | |||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||
| use_one_hot_embeddings (bool): default False. | |||
| """ | |||
| super(GPT2CBTModel, self).__init__() | |||
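| # dropout is disabled when the model is built for evaluation | |||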
| if not is_training: | |||
| config.summary_first_dropout = 0.0 | |||
| self.is_training = is_training | |||
| self.d_model = config.d_model | |||
| self.batch_size = config.batch_size | |||
| self.seq_length = config.seq_length | |||
| self.vocab_size = config.vocab_size | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.dtype = config.dtype | |||
| self.lm_head = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=TruncatedNormal(config.initializer_range), | |||
| has_bias=False).to_float(config.compute_type) | |||
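| # nn.Dropout in this MindSpore version takes keep_prob, hence 1 - summary_first_dropout | |||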
| self.first_dropout = nn.Dropout(1 - config.summary_first_dropout) | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): shape with [batch_size, seq_len] | |||
| input_mask (Tensor): shape with [batch_size, seq_len] 0 indicates padding mask | |||
| Returns: | |||
| lm_logits (Tensor): language model distribution with log_softmax, | |||
| shape with [batch_size, seq_len, vocab_size] | |||
| """ | |||
| output, _ = self.gpt2(input_ids, input_mask) # output shape is [batch_size, seq_len, d_model] | |||
| output = self.cast(output, self.dtype) | |||
| output = self.reshape(output, (-1, self.d_model)) | |||
| output = self.first_dropout(output) | |||
| lm_logits = self.lm_head(output) # [batch_size * seq_len, vocab_size] | |||
| lm_logits = self.reshape(lm_logits, (self.batch_size, self.seq_length, self.vocab_size)) | |||
| lm_logits = self.cast(lm_logits, self.dtype) | |||
| lm_logits = self.log_softmax(lm_logits) | |||
| return lm_logits | |||
| def get_lm_head(self): | |||
| return self.lm_head.weight | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (LAMBADA) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2LambadaModel(nn.Cell): | |||
| """ | |||
| GPT2LambadaModel is responsible for Lambada task, i.e. Lambada-train, Lambada-test datasets. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| """ | |||
| Args: | |||
| config: the configuration of GPT-2 model | |||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||
| use_one_hot_embeddings (bool): default False. | |||
| """ | |||
| super(GPT2LambadaModel, self).__init__() | |||
| if not is_training: | |||
| config.hidden_dropout = 0.0 | |||
| self.vocab_size = config.vocab_size | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.dtype = config.dtype | |||
| self.dense1 = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=TruncatedNormal(config.initializer_range)).to_float(mstype.float16) | |||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Args: | |||
| input_ids (Tensor): shape with [batch_size, seq_len] | |||
| input_mask (Tensor): shape with [batch_size, seq_len] 0 indicates padding mask | |||
| Returns: | |||
| lm_logits (Tensor): language model distribution with log_softmax, | |||
| shape with [batch_size, seq_len, vocab_size] | |||
| """ | |||
| output, _ = self.gpt2(input_ids, input_mask) | |||
| output = self.cast(output, self.dtype) | |||
| output = self.dropout(output) | |||
| batch_size, seq_length, d_model = self.shape(output) | |||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||
| logits = self.dense1(output_reshape) | |||
| logits = self.cast(logits, self.dtype) | |||
| logits = self.log_softmax(logits) | |||
| lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||
| return lm_logits | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (Language Modeling) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2LanguageModel(nn.Cell): | |||
| """ | |||
| GPT2LanguageModel is responsible for Language Modeling task, i.e. WikiText2, WikiText103, PTB, 1BW datasets. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| """ | |||
| Args: | |||
| config: the configuration of GPT-2 model | |||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||
| use_one_hot_embeddings (bool): default False. | |||
| """ | |||
| super(GPT2LanguageModel, self).__init__() | |||
| if not is_training: | |||
| config.hidden_dropout = 0.0 | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
| self.vocab_size = config.vocab_size | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| self.dtype = config.dtype | |||
| self.dense1 = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=TruncatedNormal(config.initializer_range), | |||
| has_bias=False).to_float(config.compute_type) | |||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||
| where 0 indicates padding position. | |||
| Returns: | |||
| lm_logits (Tensor): language model distribution with log_softmax, shape with [batch_size, seq_len, vocab_size]. | |||
| """ | |||
| output, _ = self.gpt2(input_ids, input_mask) | |||
| output = self.cast(output, self.dtype) | |||
| batch_size, seq_length, d_model = self.shape(output) | |||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||
| logits = self.dense1(output_reshape) | |||
| logits = self.cast(logits, self.dtype) | |||
| logits = self.log_softmax(logits) | |||
| lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) # [batch_size, seq_len, vocab] | |||
| return lm_logits | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (Reading Comprehension) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from mindspore.ops import operations as P | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2CoQAModel(nn.Cell): | |||
| """ | |||
| This class is responsible for the CoQA reading comprehension task. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| super(GPT2CoQAModel, self).__init__() | |||
| if not is_training: | |||
| config.hidden_dropout = 0.0 | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||
| self.dense1 = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=self.weight_init, | |||
| has_bias=False).to_float(config.compute_type) | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.vocab_size = config.vocab_size | |||
| self.dtype = config.dtype | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||
| where 0 indicates padding position. | |||
| Returns: | |||
| logits (Tensor): language model distribution with log_softmax, shape with [batch_size, seq_len, vocab_size]. | |||
| """ | |||
| decoder_output, _ = self.gpt2(input_ids, input_mask) | |||
| decoder_output = P.Cast()(decoder_output, self.dtype) | |||
| batch_size, seq_length, d_model = P.Shape()(decoder_output) | |||
| reshaped_output = P.Reshape()(decoder_output, (-1, d_model)) # [batch_size * seq_length, d_model] | |||
| logits = self.dense1(reshaped_output) | |||
| logits = P.Cast()(logits, self.dtype) | |||
| logits = self.log_softmax(logits) | |||
| logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||
| return logits | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (Summarization) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2SummarizationModel(nn.Cell): | |||
| """ | |||
| GPT2SummarizationModel is responsible for summary task, i.e. cnn_dailymail datasets. | |||
| Args: | |||
| config: the configuration of GPT-2 model | |||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||
| use_one_hot_embeddings (bool): default False. | |||
| """ | |||
| def __init__(self, config, is_training=True, use_one_hot_embeddings=False): | |||
| super(GPT2SummarizationModel, self).__init__() | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
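| # the LM head projects hidden states to vocabulary logits in float16; construct casts them back to config.dtype | |||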
| self.lm_head = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=TruncatedNormal(config.initializer_range), | |||
| has_bias=False).to_float(mstype.float16) | |||
| self.reshape = P.Reshape() | |||
| self.dtype = config.dtype | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||
| where 0 indicates padding position. | |||
| Returns: | |||
| lm_logits (Tensor): language model distribution without log_softmax, | |||
| shape with [batch_size, seq_len, vocab_size]. | |||
| """ | |||
| output, _ = self.gpt2(input_ids, input_mask) | |||
| output = self.cast(output, self.dtype) | |||
| batch_size, seq_length, d_model = self.shape(output) | |||
| hidden_state = self.reshape(output, (-1, d_model)) | |||
| hidden_state = self.cast(hidden_state, self.dtype) | |||
| lm_logits = self.lm_head(hidden_state) | |||
| lm_logits = self.cast(lm_logits, self.dtype) | |||
| lm_logits = self.reshape(lm_logits, (batch_size, seq_length, -1)) | |||
| return lm_logits | |||
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 downstream task (Translation) model script. | |||
| """ | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import TruncatedNormal | |||
| from .GPT2_model import GPT2Model | |||
| class GPT2TranslationModel(nn.Cell): | |||
| """ | |||
| GPT2TranslationModel is responsible for translation task, i.e. WMT-14 En-Fr, WMT-14 Fr-En datasets. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| """ | |||
| Args: | |||
| config: the configuration of GPT-2 model | |||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||
| use_one_hot_embeddings (bool): default False. | |||
| """ | |||
| super(GPT2TranslationModel, self).__init__() | |||
| if not is_training: | |||
| config.hidden_dropout = 0.0 | |||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||
| self.vocab_size = config.vocab_size | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| self.dtype = config.dtype | |||
| self.dense1 = nn.Dense(config.d_model, | |||
| config.vocab_size, | |||
| weight_init=TruncatedNormal(config.initializer_range), | |||
| has_bias=True).to_float(config.compute_type) | |||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| input_mask (Tensor): padding mask with shape [batch_size, seq_len], where 0 indicates a padding position. | |||
| Returns: | |||
| translation_logits (Tensor): language model distribution without log_softmax, | |||
| shape with [batch_size, seq_len, vocab_size] | |||
| """ | |||
| output, _ = self.gpt2(input_ids, input_mask) | |||
| output = self.cast(output, self.dtype) | |||
| output = self.dropout(output) | |||
| batch_size, seq_length, d_model = self.shape(output) | |||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||
| logits = self.dense1(output_reshape) | |||
| logits = self.cast(logits, self.dtype) | |||
| translation_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||
| return translation_logits | |||
| @@ -0,0 +1,366 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| generation class for downstream tasks (Summarization, LAMBADA, Translation, Reading Comprehension) | |||
| """ | |||
| import numpy as np | |||
| from .utils.task_utils import extract_logits | |||
| from .utils.generation_utils import Sample | |||
| from .utils.tensor_manipulations import extract_string_from_tensor | |||
| INF = 1. * 1e9 | |||
| class GenerateForSummarization(): | |||
| """ | |||
| Generation class for the summarization task. | |||
| """ | |||
| def __init__(self, | |||
| decoder, | |||
| config=None, | |||
| tokenizer=None, | |||
| select_sentence=3, | |||
| eval_type="finetuned", | |||
| temperature=1.0, | |||
| generate_length=100, | |||
| topk=2, | |||
| topp=1.0): | |||
| self.decoder = decoder | |||
| self.config = config | |||
| self.tokenizer = tokenizer | |||
| self.select_sentence = select_sentence | |||
| self.eval_type = eval_type | |||
| self.generator = Sample(decoder, | |||
| tokenizer=tokenizer, | |||
| config=config, | |||
| topk_num=topk, | |||
| topp_prob=topp, | |||
| min_tokens_to_keep=1, | |||
| demo_mode=False, | |||
| temperature=temperature) | |||
| self.generate_length = generate_length | |||
| def generate_for_summarization(self, input_ids): | |||
| """generation function for summarization task""" | |||
| # prepare input_str | |||
| article_str, summary_str = extract_string_from_tensor(input_ids=input_ids, | |||
| mode="pair", | |||
| config=self.config, | |||
| tokenizer=self.tokenizer) | |||
| generated_summary_list = [""] * self.config.batch_size | |||
| # clip overflow | |||
| for batch_idx in range(self.config.batch_size): | |||
| last_dot_pos = max(article_str[batch_idx].rfind(' .'), article_str[batch_idx].rfind('. ')) + 2 | |||
| article_str[batch_idx] = article_str[batch_idx][:last_dot_pos] | |||
| # append the "TL;DR:" prompt after the article string. | |||
| tldr_str = "TL;DR:" | |||
| if self.eval_type == "finetuned": | |||
| for batch_idx in range(self.config.batch_size): | |||
| article_str[batch_idx] += (" " + tldr_str) | |||
| generate_str_list, _ = self.generator.generate(input_str=article_str, generate_length=self.generate_length) | |||
| for batch_idx in range(self.config.batch_size): | |||
| generate_str = generate_str_list[batch_idx] | |||
| generated_summary = "" | |||
| if self.select_sentence > 0: | |||
| # check whether the generated text contains `select_sentence` sentences; | |||
| # if not, the full generated string is returned. | |||
| len_generate_str = len(generate_str) | |||
| search_index = -1 | |||
| for _ in range(self.select_sentence): | |||
| search_index = generate_str.find('.', search_index + 1) | |||
| if search_index == -1 or search_index >= len_generate_str: | |||
| search_index = len_generate_str | |||
| break | |||
| # increase search_index to add period token('.') if search_index does not overflow. | |||
| search_index = search_index + 1 if search_index < len_generate_str else len_generate_str | |||
| generated_summary = generate_str[:search_index] | |||
| if generated_summary.find(self.tokenizer.eos_token) != -1: | |||
| cut_pos = generated_summary.find(self.tokenizer.eos_token, 0) | |||
| generated_summary = generated_summary[:cut_pos] | |||
| else: | |||
| generated_summary = generate_str | |||
| # if the whole string has been clipped away, restore it to its original state. | |||
| if generated_summary == '': | |||
| generated_summary = generate_str | |||
| # empty str check | |||
| if generated_summary == '': | |||
| generated_summary = '<empty>' | |||
| generated_summary_list[batch_idx] = generated_summary | |||
| return generated_summary_list, summary_str # Hypo and Ref | |||
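| # The sentence-selection loop above, isolated as a pure-Python sketch | |||
| # (an illustrative helper, not called anywhere in this pipeline): | |||
| def _first_k_sentences(text, k): | |||
|     """Return the first `k` '.'-terminated sentences of `text`, or all of it.""" | |||
|     search_index = -1 | |||
|     for _ in range(k): | |||
|         search_index = text.find('.', search_index + 1) | |||
|         if search_index == -1: | |||
|             return text  # fewer than k sentences: keep the full string | |||
|     return text[:search_index + 1] | |||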
| class GenerateForLambada(): | |||
| """ | |||
| Generation class for the LAMBADA task, which predicts the final word of a sentence. | |||
| """ | |||
| def __init__(self, | |||
| decoder, | |||
| config=None, | |||
| tokenizer=None, | |||
| generate_length_dynamic=True, | |||
| generate_length=1, | |||
| max_iterations=200, | |||
| stop_word_file=""): | |||
| """ | |||
| Args: | |||
| decoder (Model): GPT2 model to do generation. | |||
| config (object): configuration of given GPT2 model. | |||
| tokenizer (object): a tokenizer is required when the input_str parameter of self.generate() is used. | |||
| generate_length_dynamic (bool): If True, the generation length follows the token length of the | |||
| target word; if False, `generate_length` is used. Default: True. | |||
| max_iterations (int): try the top k candidate tokens by predicted probability, where k = `max_iterations`. | |||
| generate_length (int): the maximum number of tokens generated for the final word. | |||
| stop_word_file (str): path to the stop-word file used as a stop-word filter. | |||
| """ | |||
| self.decoder = decoder | |||
| self.config = config | |||
| self.batch_size = config.batch_size | |||
| self.tokenizer = tokenizer | |||
| self.generate_length_dynamic = generate_length_dynamic | |||
| self.generate_length = generate_length | |||
| self.max_iterations = max_iterations | |||
| self.stop_word_set = self.build_stop_word(stop_word_file) | |||
| self.generator = Sample(decoder=decoder, | |||
| config=config, | |||
| batch_size=1, | |||
| tokenizer=tokenizer, | |||
| topk_num=1, | |||
| topp_prob=1, | |||
| return_ids=True | |||
| ) | |||
| self.stop_eos = ['.', ',', '!', '?', '"', " '", " and", " says", " said"] | |||
| def build_stop_word(self, stop_word_file): | |||
| stop_words_set = set() | |||
| with open(stop_word_file, 'r', encoding="utf8") as file: | |||
| for line in file.readlines(): | |||
| line = line.strip('\n') | |||
| stop_words_set.add(line) | |||
| return stop_words_set | |||
| def is_stop_word(self, word): | |||
| return word in self.stop_word_set | |||
| def generate_for_lambada(self, input_ids, logits, input_length): | |||
| """ | |||
| generation function for lambada task | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| logits (Tensor): the language model distribution. | |||
| input_length (Tensor): stores the context length (excluding the final word) and the whole sentence length. | |||
| Returns: | |||
| batch_predict_words (list): the list of predicted words. | |||
| """ | |||
| batch_predict_words = ["" for _ in range(self.batch_size)] | |||
| input_len_np = input_length.asnumpy() | |||
| input_ids_list = input_ids.asnumpy().tolist() | |||
| extracted_logits = extract_logits(logits=logits, position=input_len_np) # [batch_size, vocab_size] | |||
| extracted_logits = extracted_logits.asnumpy() | |||
| sorted_ids = np.argsort(-extracted_logits, axis=-1)[::, :self.max_iterations] # [batch_size, max_iterations] | |||
| for batch_idx in range(self.batch_size): | |||
| final_word_spos = input_len_np[batch_idx, 0] | |||
| context_ids = input_ids_list[batch_idx][1:final_word_spos] # 1 for dropping <bos> token | |||
| last_word_token_num = input_len_np[batch_idx, 1] - input_len_np[batch_idx, 0] | |||
| if self.generate_length_dynamic: | |||
| generate_length = last_word_token_num | |||
| else: | |||
| generate_length = self.generate_length | |||
| for num in range(self.max_iterations): | |||
| id_ = sorted_ids[batch_idx][num] | |||
| source_ids = context_ids + [id_] | |||
| source_string = self.tokenizer.decode(source_ids) | |||
| generated_ids_list = self.generator.generate(input_str=source_string, | |||
| generate_length=generate_length, | |||
| do_sample=False) | |||
| predict_tokens_ids = [id_] + generated_ids_list[0] | |||
| predict_word = self.tokenizer.decode(predict_tokens_ids) | |||
| eos_pos = min(predict_word.find(word) if predict_word.find(word) >= 0 | |||
| else INF for word in self.stop_eos) | |||
| if eos_pos == INF: | |||
| continue | |||
| else: | |||
| predict_word = predict_word[:eos_pos] | |||
| predict_word = predict_word.strip() | |||
| if predict_word.find(" ") == -1: | |||
| if self.is_stop_word(word=predict_word.lower()): | |||
| continue | |||
| batch_predict_words[batch_idx] = predict_word | |||
| print("predict word: {}".format(predict_word)) | |||
| break | |||
| return batch_predict_words | |||
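| # The candidate-word truncation used above, isolated as a pure-Python sketch | |||
| # (illustrative helper; `stop_eos` mirrors self.stop_eos): | |||
| def _cut_at_first_stop(predict_word, stop_eos=('.', ',', '!', '?', '"')): | |||
|     """Cut `predict_word` at its earliest stop mark; return None if none occurs.""" | |||
|     positions = [predict_word.find(mark) for mark in stop_eos] | |||
|     found = [pos for pos in positions if pos >= 0] | |||
|     if not found: | |||
|         return None  # the caller skips this candidate and tries the next one | |||
|     return predict_word[:min(found)].strip() | |||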
| class GenerateForTranslation(): | |||
| """ | |||
| Generation class for the translation task. | |||
| """ | |||
| def __init__(self, | |||
| decoder, | |||
| config=None, | |||
| tokenizer=None, | |||
| generate_length=1, | |||
| use_hint=True, | |||
| select_first_sentence=True, | |||
| topk_num=None, | |||
| topp_prob=None, | |||
| temperature=None | |||
| ): | |||
| self.decoder = decoder | |||
| self.config = config | |||
| self.batch_size = config.batch_size | |||
| self.tokenizer = tokenizer | |||
| self.generate_length = generate_length | |||
| self.use_hint = use_hint | |||
| self.select_first_sentence = select_first_sentence | |||
| self.generator = Sample(decoder=decoder, | |||
| config=config, | |||
| tokenizer=tokenizer, | |||
| topk_num=topk_num, | |||
| topp_prob=topp_prob, | |||
| temperature=temperature, | |||
| min_tokens_to_keep=1, | |||
| early_stop=False) | |||
| def generate_for_translation(self, input_ids): | |||
| """generation function for translation task""" | |||
| source_str_list, ref_str_list = extract_string_from_tensor(input_ids=input_ids, | |||
| mode="pair", | |||
| config=self.config, | |||
| tokenizer=self.tokenizer) | |||
| final_predict_translation_list = [""] * self.batch_size | |||
| if self.use_hint: | |||
| for index in range(self.batch_size): | |||
| source_str_list[index] += " =" # now source_str is "english sentence =" | |||
| translation_str_list, _ = self.generator.generate(input_str=source_str_list, | |||
| generate_length=self.generate_length, | |||
| do_sample=False) | |||
| for index in range(self.batch_size): | |||
| generate_str = translation_str_list[index].replace('<|endoftext|>', '') | |||
| predict_translation = "" | |||
| # Following the GPT-2 paper, select_first_sentence is set to True. | |||
| if self.select_first_sentence: | |||
| # keep only the first sentence of the generated text; | |||
| # if no period is found, the full generated string is returned. | |||
| search_index = generate_str.find('.', 0, len(generate_str)) | |||
| if search_index == -1: | |||
| search_index = len(generate_str) | |||
| else: | |||
| search_index = search_index + 1 | |||
| predict_translation = generate_str[:search_index] | |||
| else: | |||
| predict_translation = generate_str | |||
| if predict_translation == '': | |||
| predict_translation = '<empty>' | |||
| final_predict_translation_list[index] = predict_translation | |||
| return final_predict_translation_list, ref_str_list | |||
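| # The zero-shot translation prompt follows the GPT-2 paper's | |||
| # "source sentence = target sentence" format; a sketch of the prompt | |||
| # construction used above (illustrative helper): | |||
| def _make_translation_prompt(source_sentence, use_hint=True): | |||
|     """Append the ' =' hint so the model continues with the translation.""" | |||
|     return source_sentence + " =" if use_hint else source_sentence | |||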
| class GenerateForReadComprehension(): | |||
| """ | |||
| Generation class for the Reading Comprehension task. | |||
| Args: | |||
| decoder (Model): GPT2 model to do generation. | |||
| config (object): configuration of given GPT2 model. | |||
| tokenizer (object): a tokenizer is required when the input_str parameter of self.generate() is used. | |||
| generate_length (int): the number of tokens to generate for the answer. Default: 1. | |||
| """ | |||
| def __init__(self, | |||
| decoder, | |||
| config=None, | |||
| tokenizer=None, | |||
| generate_length=1, | |||
| topk_num=None, | |||
| topp_prob=None, | |||
| temperature=None | |||
| ): | |||
| self.decoder = decoder | |||
| self.config = config | |||
| self.batch_size = config.batch_size | |||
| self.tokenizer = tokenizer | |||
| self.generate_length = generate_length | |||
| self.generator = Sample(decoder=decoder, | |||
| config=config, | |||
| tokenizer=tokenizer, | |||
| topk_num=topk_num, | |||
| topp_prob=topp_prob, | |||
| temperature=temperature, | |||
| min_tokens_to_keep=1, | |||
| ) | |||
| def generate_for_read_comprehension(self, input_ids): | |||
| """generation function for reading comprehension task""" | |||
| passage_str_list, answer_str_list = extract_string_from_tensor(input_ids=input_ids, | |||
| mode="pair", | |||
| config=self.config, | |||
| tokenizer=self.tokenizer) | |||
| passage = passage_str_list[:] | |||
| generate_str_list, _ = self.generator.generate(input_str=passage_str_list, | |||
| generate_length=self.generate_length, | |||
| do_sample=False) | |||
| pred_answer = [] | |||
| for batch_id in range(self.batch_size): | |||
| new_str = generate_str_list[batch_id].replace('<|endoftext|>', '') | |||
| index_a = new_str.find('.') | |||
| index_b = new_str.find('Q:') | |||
| if index_a != -1 or index_b != -1: | |||
| index = max(index_a, index_b) | |||
| pred_answer += [new_str[1:index]] # drop the leading space of the generated answer | |||
| else: | |||
| pred_answer += [new_str] | |||
| return passage, pred_answer, answer_str_list | |||
| @@ -0,0 +1,896 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| GPT-2 base model | |||
| """ | |||
| import math | |||
| import copy | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.functional as F | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| from .weight_init import normal_weight, zero_weight | |||
| class GPT2Config: | |||
| """ | |||
| Configuration for `GPT2Model`. | |||
| Args: | |||
| batch_size (int): Batch size of input dataset. Default: 512. | |||
| seq_length (int): Length of input sequence. Default: 1024. | |||
| vocab_size (int): Size of the vocabulary. Default: 50257. | |||
| d_model (int): Size of the GPT-2 decoder layers. Default: 768. | |||
| num_hidden_layers (int): Number of hidden layers in the GPT2Transformer decoder block. Default: 12. | |||
| num_attention_heads (int): Number of attention heads in the GPT2Transformer decoder block. Default: 12. | |||
| intermediate_size (int): Size of intermediate layer in the GPT2Transformer decoder block. Default: 3072. | |||
| hidden_act (str): Activation function used in the GPT2Transformer decoder block. Default: "gelu". | |||
| hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1. | |||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1. | |||
| max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024. | |||
| initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | |||
| input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from dataset. | |||
| Default: True. | |||
| summary_first_dropout (float): The dropout probability for GPT2CBTModel. Default: 0.1. | |||
| dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=512, | |||
| seq_length=1024, | |||
| vocab_size=50257, | |||
| d_model=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act="gelu", | |||
| hidden_dropout=0.1, | |||
| attention_dropout=0.1, | |||
| max_position_embeddings=1024, | |||
| initializer_range=0.02, | |||
| input_mask_from_dataset=True, | |||
| summary_first_dropout=0.1, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| ): | |||
| self.batch_size = batch_size | |||
| self.seq_length = seq_length | |||
| self.vocab_size = vocab_size | |||
| self.d_model = d_model | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_act = hidden_act | |||
| self.hidden_dropout = hidden_dropout | |||
| self.attention_dropout = attention_dropout | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.initializer_range = initializer_range | |||
| self.input_mask_from_dataset = input_mask_from_dataset | |||
| self.summary_first_dropout = summary_first_dropout | |||
| self.dtype = dtype | |||
| self.compute_type = compute_type | |||
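| # A configuration sketch: any field can be overridden while the others keep | |||
| # their GPT-2 small defaults (the values below are illustrative only): | |||
| #     config = GPT2Config(batch_size=1, seq_length=128) | |||
| #     assert config.d_model == 768 and config.num_hidden_layers == 12 | |||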
| class EmbeddingLookup(nn.Cell): | |||
| """ | |||
| An embedding lookup table with a fixed dictionary and size. | |||
| Args: | |||
| vocab_size (int): Size of the dictionary of embeddings. | |||
| embedding_dim (int): The size of each embedding vector. | |||
| use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. | |||
| compute_type (:class:`mindspore.dtype`): Compute type for the one-hot matmul. Default: mstype.float16. | |||
| """ | |||
| def __init__(self, | |||
| vocab_size, | |||
| embedding_dim, | |||
| use_one_hot_embeddings=False, | |||
| compute_type=mstype.float16): | |||
| super(EmbeddingLookup, self).__init__() | |||
| self.vocab_size = vocab_size | |||
| self.embedding_dim = embedding_dim | |||
| self.use_one_hot_embeddings = use_one_hot_embeddings | |||
| self.compute_type = compute_type | |||
| self.embedding_table = Parameter(normal_weight([vocab_size, embedding_dim], embedding_dim), | |||
| name='embedding_table') | |||
| self.expand = P.ExpandDims() | |||
| self.shape_flat = (-1,) | |||
| self.gather = P.GatherV2() | |||
| self.one_hot = P.OneHot() | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.array_mul = P.MatMul() | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.cast = P.Cast() | |||
| def construct(self, input_ids): | |||
| """ | |||
| get embedding according to input_ids. | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| Returns: | |||
| output (Tensor): the embedding matrix according to the input_ids. | |||
| self.embedding_table (Parameter): the whole embedding table of GPT-2 model. | |||
| """ | |||
| input_shape = self.shape(input_ids) # [batch_size, seq_length] | |||
| flat_ids = self.reshape(input_ids, self.shape_flat) # [batch_size * seq_length] | |||
| if self.use_one_hot_embeddings: | |||
| one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) | |||
| # precision transition fp32 -> fp16 | |||
| one_hot_ids = self.cast(one_hot_ids, self.compute_type) | |||
| self.embedding_table = self.cast(self.embedding_table, self.compute_type) | |||
| output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) | |||
| output_for_reshape = self.cast(output_for_reshape, mstype.float32) | |||
| else: | |||
| # [batch_size * seq_length * embedding_dim] | |||
| output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) | |||
| out_shape = input_shape + (self.embedding_dim,) | |||
| output = self.reshape(output_for_reshape, out_shape) # [batch_size, seq_length, embedding_dim] | |||
| return output, self.embedding_table | |||
| class EmbeddingPostprocessor(nn.Cell): | |||
| """ | |||
| Postprocessors apply positional embeddings to word embeddings. | |||
| Args: | |||
| embedding_dim (int): The size of each embedding vector. | |||
| seq_length (int): the length of input sequence. | |||
| max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024. | |||
| dropout_prob (float): The dropout probability. Default: 0.1. | |||
| """ | |||
| def __init__(self, | |||
| embedding_dim=None, | |||
| seq_length=None, | |||
| max_position_embeddings=1024, | |||
| dropout_prob=0.1): | |||
| super(EmbeddingPostprocessor, self).__init__() | |||
| self.position_embedding_table = Parameter( | |||
| normal_weight([max_position_embeddings, embedding_dim], embedding_dim), name='position_embeddings') | |||
| self.expand_dims = P.ExpandDims() | |||
| self.add = P.TensorAdd() | |||
| self.gather = P.GatherV2() | |||
| self.input_indices = Tensor(np.arange(seq_length), mindspore.int32) | |||
| self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32) | |||
| self.use_dropout = dropout_prob > 0 | |||
| def construct(self, word_embeddings): | |||
| """ | |||
| Add position embeddings to the token embeddings. | |||
| Args: | |||
| word_embeddings (Tensor): the token embedding matrix | |||
| Returns: | |||
| output (Tensor): the final embedding matrix by adding the position embedding table | |||
| to token embedding table. | |||
| """ | |||
| position_embeddings = self.gather(self.position_embedding_table, self.input_indices, 0) | |||
| position_embeddings = self.expand_dims(position_embeddings, 0) | |||
| output = self.add(word_embeddings, position_embeddings) | |||
| if self.use_dropout: | |||
| output = self.dropout(output) | |||
| return output | |||
| class CastWrapper(nn.Cell): | |||
| """ | |||
| Cast wrapper | |||
| """ | |||
| def __init__(self, | |||
| dst_type=mstype.float32): | |||
| super(CastWrapper, self).__init__() | |||
| self.cast = P.Cast() | |||
| self.dst_type = dst_type | |||
| def construct(self, x): | |||
| """ | |||
| type cast | |||
| Args: | |||
| x (Tensor): the input which need to be cast. | |||
| Returns: | |||
| Tensor, the cast output. | |||
| """ | |||
| return self.cast(x, self.dst_type) | |||
| class LayerNorm(nn.Cell): | |||
| """ | |||
| Do layer norm | |||
| Args: | |||
| in_channels (int): In channels number of layer norm | |||
| """ | |||
| def __init__(self, | |||
| in_channels=None): | |||
| super(LayerNorm, self).__init__() | |||
| self.layer_norm = nn.LayerNorm((in_channels,)) | |||
| self.cast = P.Cast() | |||
| self.get_dtype = P.DType() | |||
| def construct(self, input_tensor): | |||
| """ | |||
| layer norm | |||
| Args: | |||
| input_tensor (Tensor): the input of layernorm. | |||
| Returns: | |||
| Tensor, the output after layernorm. | |||
| """ | |||
| output = self.cast(input_tensor, mstype.float32) | |||
| output = self.layer_norm(output) | |||
| output = self.cast(output, self.get_dtype(input_tensor)) | |||
| return output | |||
| class ResidualConnection(nn.Cell): | |||
| """ | |||
| Add residual to output. | |||
| Args: | |||
| dropout_prob (float): Dropout rate. | |||
| """ | |||
| def __init__(self, dropout_prob=0.0): | |||
| super(ResidualConnection, self).__init__() | |||
| self.add = P.TensorAdd() | |||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||
| self.use_dropout = dropout_prob > 0 | |||
| def construct(self, hidden_tensor, input_tensor): | |||
| """ | |||
| Args: | |||
| hidden_tensor (Tensor): the output of sublayer. | |||
| input_tensor (Tensor): the input tensor. | |||
| Returns: | |||
| output (Tensor): with the same shape of hidden_tensor. | |||
| """ | |||
| output = hidden_tensor | |||
| if self.use_dropout: | |||
| output = self.dropout(output) | |||
| output = self.add(output, input_tensor) | |||
| return output | |||
| class Conv1D(nn.Cell): | |||
| """ | |||
| 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). | |||
| Basically works like a linear layer but the weights are transposed. | |||
| Args: | |||
| nx (int): The number of input features. | |||
| nf (int): The number of output features. | |||
| """ | |||
| def __init__(self, | |||
| nx, | |||
| nf): | |||
| super(Conv1D, self).__init__() | |||
| self.nx = nx | |||
| self.nf = nf | |||
| self.weight = Parameter(normal_weight([nx, nf], nf), name='projection_weight') | |||
| self.bias = Parameter(zero_weight(nf), name='projection_bias') | |||
| self.matmul = P.MatMul() | |||
| self.bias_add = P.BiasAdd() | |||
| self.cast = P.Cast() | |||
| def construct(self, input_tensor): | |||
| """ | |||
| Args: | |||
| input_tensor (Tensor): the input tensor of Conv1D with shape [batch_size * seq_length, nx] | |||
| Returns: | |||
| output_tensor (Tensor): the output tensor with shape [batch_size * seq_length, self.nf] | |||
| """ | |||
| # precision transition fp32 -> fp16 | |||
| input_tensor = self.cast(input_tensor, mstype.float16) | |||
| fp16_weight = self.cast(self.weight, mstype.float16) | |||
| output_tensor = self.matmul(input_tensor, fp16_weight) # [batch_size * seq_length, self.nf] | |||
| output_tensor = self.cast(output_tensor, mstype.float32) | |||
| output_tensor = self.bias_add(output_tensor, self.bias) | |||
| return output_tensor | |||
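| # Numerically, Conv1D is a plain affine map y = x @ W + b with the weight | |||
| # stored as [nx, nf]; a NumPy reference of the same computation (illustrative): | |||
| def _conv1d_reference(x, weight, bias): | |||
|     """x: [n, nx], weight: [nx, nf], bias: [nf] -> [n, nf].""" | |||
|     return np.matmul(x, weight) + bias | |||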
| class MaskedSelfAttention(nn.Cell): | |||
| """ | |||
| Apply masked multi-head attention. | |||
| Args: | |||
| batch_size (int): Batch size of input datasets. Default: 512. | |||
| d_model (int): Size of last dim of input tensor. Default: 768. | |||
| seq_length (int): Length of input tensor sequence. Default: 1024. | |||
| num_attention_heads (int): Number of attention heads. Default: 12. | |||
| dim_per_head (int): Size of each attention head. Default: 64. | |||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||
| do_return_2d_tensor (bool): If True, return shape [batch_size * seq_length, d_model]. Default: True. | |||
| attention_dropout (float): The dropout probability for MultiheadAttention. Default: 0.0. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16. | |||
| Returns: | |||
| Tensor, shape with [batch_size * seq_length, d_model] when `do_return_2d_tensor` is True, | |||
| otherwise [batch_size, seq_length, d_model]. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=512, | |||
| d_model=768, | |||
| seq_length=1024, | |||
| num_attention_heads=12, | |||
| dim_per_head=64, | |||
| has_attention_mask=True, | |||
| do_return_2d_tensor=True, | |||
| attention_dropout=0.0, | |||
| compute_type=mstype.float16): | |||
| super(MaskedSelfAttention, self).__init__() | |||
| self.batch_size = batch_size | |||
| self.d_model = d_model | |||
| self.seq_length = seq_length | |||
| self.num_heads = num_attention_heads | |||
| self.dim_per_head = dim_per_head | |||
| self.has_attention_mask = has_attention_mask | |||
| self.compute_type = compute_type | |||
| assert has_attention_mask | |||
| self.scale = Tensor([1.0 / math.sqrt(float(self.dim_per_head))], dtype=compute_type) # attention scale | |||
| self.mask_data = Tensor([-10000.0,], dtype=compute_type) | |||
| self.split_head_shape = (-1, self.seq_length, self.num_heads, self.dim_per_head) | |||
| self.c_attn = Conv1D(d_model, d_model * 3) | |||
| self.c_proj = Conv1D(d_model, d_model) | |||
| self.split_for_qkv = P.Split(1, 3) | |||
| self.reshape = P.Reshape() | |||
| self.transpose = P.Transpose() | |||
| self.trans_shape = (0, 2, 1, 3) | |||
| self.matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||
| self.matmul = P.BatchMatMul() | |||
| self.multiply = P.Mul() | |||
| if self.has_attention_mask: | |||
| self.expand_dims = P.ExpandDims() | |||
| self.sub = P.Sub() | |||
| self.add = P.TensorAdd() | |||
| self.cast = P.Cast() | |||
| self.get_dtype = P.DType() | |||
| if do_return_2d_tensor: | |||
| self.shape_return = (-1, d_model) | |||
| else: | |||
| self.shape_return = (-1, seq_length, d_model) | |||
| self.softmax = nn.Softmax() | |||
| self.softmax_cast = P.Cast() | |||
| self.dropout = nn.Dropout(1 - attention_dropout) | |||
| self.use_attention_dropout = attention_dropout > 0 | |||
| def construct(self, input_tensor, attention_mask): | |||
| """ | |||
| do masked self-attention | |||
| Args: | |||
| input_tensor (Tensor): the embedding of input sequence tokens, | |||
| shape with [batch_size * seq_length, d_model] | |||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||
| shape with [batch_size, seq_len, seq_len]. | |||
| Returns: | |||
| outputs (Tensor): the output of masked self-attention, shape with [batch_size * seq_len, d_model]. | |||
| """ | |||
| input_tensor = self.c_attn(input_tensor) # [batch_size * seq_length, d_model*3]---> eg.[1 * 3, 2304] | |||
| input_tensor = self.split_for_qkv(input_tensor) | |||
| query = input_tensor[0] # [batch_size * seq_length, d_model] ---> eg. [1 * 3, 768] | |||
| key = input_tensor[1] | |||
| value = input_tensor[2] | |||
| # split head | |||
| query = self.reshape(query, self.split_head_shape) | |||
| # query shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||
| query = self.transpose(query, self.trans_shape) | |||
| key = self.reshape(key, self.split_head_shape) | |||
| # key shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||
| key = self.transpose(key, self.trans_shape) | |||
| value = self.reshape(value, self.split_head_shape) | |||
| # value shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||
| value = self.transpose(value, self.trans_shape) | |||
| # attention and mask | |||
| # precision transition fp32 -> fp16 | |||
| query = self.cast(query, self.compute_type) | |||
| key = self.cast(key, self.compute_type) | |||
| attention_scores = self.matmul_trans_b(query, key) # [batch_size, num_heads, seq_len, seq_len] | |||
| attention_scores = self.cast(attention_scores, self.compute_type) | |||
| attention_scores = self.multiply(attention_scores, self.scale) | |||
| if self.has_attention_mask: | |||
| attention_mask = self.expand_dims(attention_mask, 1) # [batch_size, 1, seq_length, seq_length] | |||
| multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), | |||
| self.cast(attention_mask, self.get_dtype(attention_scores))) # fp16 | |||
| adder = self.multiply(multiply_out, self.mask_data) | |||
| adder = self.cast(adder, mstype.float32) | |||
| attention_scores = self.cast(attention_scores, mstype.float32) | |||
| attention_scores = self.add(adder, attention_scores) | |||
| attention_scores = self.softmax_cast(attention_scores, mstype.float32) | |||
| attention_probs = self.softmax(attention_scores) # [batch_size, num_heads, seq_len, seq_len] | |||
| attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key)) | |||
| if self.use_attention_dropout: | |||
| attention_probs = self.dropout(attention_probs) | |||
| value = self.cast(value, mstype.float16) | |||
| attention_probs = self.cast(attention_probs, self.compute_type) | |||
| outputs = self.matmul(attention_probs, value) # [batch_size, num_heads, seq_len, dim_per_head] | |||
| outputs = self.cast(outputs, mstype.float32) | |||
| # merge heads | |||
| outputs = self.transpose(outputs, self.trans_shape) # [batch_size, seq_len, num_heads, dim_per_head] | |||
| outputs = self.reshape(outputs, | |||
| self.shape_return) # default True, the outputs shape [batch_size * seq_len, d_model] | |||
| # project | |||
| outputs = self.c_proj(outputs) | |||
| return outputs | |||
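| # The core of the block above is scaled dot-product attention where masked | |||
| # positions receive large negative logits before the softmax. A single-head | |||
| # NumPy reference (illustrative, same -10000.0 trick as the cell above): | |||
| def _masked_attention_reference(query, key, value, mask): | |||
|     """query/key/value: [seq, dim]; mask: [seq, seq] with 1 = attend, 0 = hide.""" | |||
|     scores = np.matmul(query, key.T) / math.sqrt(query.shape[-1]) | |||
|     scores = scores + (1.0 - mask) * -10000.0 | |||
|     probs = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax | |||
|     probs = probs / probs.sum(axis=-1, keepdims=True) | |||
|     return np.matmul(probs, value) | |||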
| class FeedForward(nn.Cell): | |||
| """ | |||
| Apply two-layer feed forward | |||
| Args: | |||
| in_channels (int): Size of the input layer. Default: 768. | |||
| out_channels (int): Size of the output layers. Default: 768. | |||
| hidden_size (int): Size of the hidden layer. Default: 3072. | |||
| hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1. | |||
| """ | |||
| def __init__(self, | |||
| in_channels=768, | |||
| out_channels=768, | |||
| hidden_size=3072, | |||
| hidden_dropout=0.1): | |||
| super(FeedForward, self).__init__() | |||
| self.c_fc = Conv1D(in_channels, hidden_size) | |||
| self.c_proj = Conv1D(hidden_size, out_channels) | |||
| self.layernorm = LayerNorm(in_channels=in_channels) | |||
| self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout) | |||
| self.gelu_act = P.Gelu() | |||
| self.dropout = nn.Dropout(1 - hidden_dropout) | |||
| self.use_dropout = hidden_dropout > 0 | |||
| self.reshape = P.Reshape() | |||
| def construct(self, input_tensor): | |||
| """ | |||
| FeedForward construct function with layernorm and residual connection. | |||
| Args: | |||
| input_tensor (Tensor): the input of FeedForward layer, shape with [batch_size * seq_len, d_model]. | |||
| Returns: | |||
| output (Tensor): the output of FeedForward layer, shape with [batch_size * seq_len, d_model] | |||
| """ | |||
| # LayerNorm | |||
| output = self.layernorm(input_tensor) | |||
| # Feed Forward | |||
| output = self.c_fc(output) # [batch_size * seq_len, d_model * 4] | |||
| output = self.gelu_act(output) | |||
| output = self.c_proj(output) # [batch_size * seq_len, d_model] | |||
| if self.use_dropout: | |||
| output = self.dropout(output) | |||
| # Add | |||
| output = self.residual_connect(output, input_tensor) | |||
| return output | |||
| class MaskedMultiHeadAttention(nn.Cell): | |||
| """ | |||
| Masked multi-head attention block. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=512, | |||
| seq_length=1024, | |||
| d_model=768, | |||
| num_attention_heads=12, | |||
| attention_dropout=0.02, | |||
| hidden_dropout=0.1, | |||
| has_attention_mask=True, | |||
| compute_type=mstype.float16 | |||
| ): | |||
| super(MaskedMultiHeadAttention, self).__init__() | |||
| if d_model % num_attention_heads != 0: | |||
| raise ValueError("The hidden size (%d) is not a multiple of the number " | |||
| "of attention heads (%d)" % (d_model, num_attention_heads)) | |||
| self.dim_per_head = int(d_model / num_attention_heads) # 64 | |||
| self.masked_self_attention = MaskedSelfAttention( | |||
| batch_size=batch_size, | |||
| d_model=d_model, | |||
| seq_length=seq_length, | |||
| num_attention_heads=num_attention_heads, | |||
| dim_per_head=self.dim_per_head, | |||
| has_attention_mask=has_attention_mask, | |||
| do_return_2d_tensor=True, | |||
| attention_dropout=attention_dropout, | |||
| compute_type=compute_type | |||
| ) | |||
| self.layer_norm = LayerNorm(in_channels=d_model) | |||
| self.residual_connection = ResidualConnection() | |||
| self.reshape = P.Reshape() | |||
| self.new_shape = (-1, d_model) | |||
| def construct(self, input_tensor, attention_mask): | |||
| """ | |||
| do masked multi head self-attention with layernorm and residual_connection. | |||
| Args: | |||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||
| shape with [batch_size * seq_length, d_model] | |||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||
| shape with [batch_size, seq_len, seq_len]. | |||
| Returns: | |||
| outputs (Tensor): the output of MaskedMultiHeadAttention, shape with [batch_size * seq_len, d_model]. | |||
| """ | |||
| # LayerNorm | |||
| output_tensor = self.layer_norm(input_tensor) | |||
| # masked multi-head attention | |||
| # attention_output shape [batch_size * seq_length, d_model] | |||
| attention_output = self.masked_self_attention(output_tensor, attention_mask) | |||
| # residual connection | |||
| output = self.residual_connection(attention_output, input_tensor) | |||
| return output | |||
| class DecoderBlock(nn.Cell): | |||
| """ | |||
| decoder block used in GPT2. | |||
| Args: | |||
| batch_size (int): Batch size of input dataset. Default: 512. | |||
| seq_length (int): Length of input sequence. Default: 1024. | |||
| d_model (int): Size of the GPT2 decoder layers. Default: 768. | |||
| num_attention_heads (int): Number of attention heads. Default: 12. | |||
| intermediate_size (int): Size of intermediate layer. Default: 3072. | |||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.02. | |||
| hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1. | |||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float16. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=512, | |||
| seq_length=1024, | |||
| d_model=768, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| attention_dropout=0.02, | |||
| hidden_dropout=0.1, | |||
| has_attention_mask=True, | |||
| compute_type=mstype.float16 | |||
| ): | |||
| super(DecoderBlock, self).__init__() | |||
| if d_model % num_attention_heads != 0: | |||
| raise ValueError("The hidden size (%d) is not a multiple of the number " | |||
| "of attention heads (%d)" % (d_model, num_attention_heads)) | |||
| self.dim_per_head = int(d_model / num_attention_heads) # 64 | |||
| self.masked_multi_head_attention = MaskedMultiHeadAttention( | |||
| batch_size=batch_size, | |||
| seq_length=seq_length, | |||
| d_model=d_model, | |||
| num_attention_heads=num_attention_heads, | |||
| attention_dropout=attention_dropout, | |||
| hidden_dropout=hidden_dropout, | |||
| has_attention_mask=has_attention_mask, | |||
| compute_type=compute_type | |||
| ) | |||
| self.feedforward = FeedForward( | |||
| in_channels=d_model, | |||
| out_channels=d_model, | |||
| hidden_size=intermediate_size, | |||
| hidden_dropout=hidden_dropout | |||
| ) | |||
| self.reshape = P.Reshape() | |||
| self.new_shape = (-1, d_model) | |||
| def construct(self, input_tensor, attention_mask): # input tensor shape[batch_size, seq_length, d_model] | |||
| """ | |||
| DecoderBlock with masked_multi_head_attention and feedforward. | |||
| Args: | |||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||
| shape with [batch_size * seq_length, d_model] | |||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||
| shape with [batch_size, seq_len, seq_len]. | |||
| Returns: | |||
| outputs (Tensor): the output of DecoderBlock, shape with [batch_size * seq_len, d_model]. | |||
| """ | |||
| input_tensor = self.reshape(input_tensor, self.new_shape) | |||
| # masked multi head attention with ln, res | |||
| attention_output = self.masked_multi_head_attention(input_tensor, attention_mask) | |||
| # feed forward with ln, res | |||
| output = self.feedforward(attention_output) | |||
| return output | |||
| class GPT2Transformer(nn.Cell): | |||
| """ | |||
| Multi-layer GPT2 transformer. | |||
| Args: | |||
| batch_size (int): Batch size of input dataset. Default: 512. | |||
| d_model (int): Size of the decoder layers. Default: 768. | |||
| seq_length (int): Length of input sequence. Default: 1024. | |||
| num_hidden_layers (int): Number of hidden layers in decoder cells. Default: 12. | |||
| num_attention_heads (int): Number of attention heads in decoder cells. Default: 12. | |||
| intermediate_size (int): Size of intermediate layer in decoder cells. Default: 3072. | |||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1. | |||
| hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1. | |||
| compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=512, | |||
| d_model=768, | |||
| seq_length=1024, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| has_attention_mask=True, | |||
| attention_dropout=0.1, | |||
| hidden_dropout=0.1, | |||
| compute_type=mstype.float16): | |||
| super(GPT2Transformer, self).__init__() | |||
| layers = [] | |||
| for _ in range(num_hidden_layers): | |||
| layer = DecoderBlock(batch_size=batch_size, | |||
| seq_length=seq_length, | |||
| d_model=d_model, | |||
| num_attention_heads=num_attention_heads, | |||
| intermediate_size=intermediate_size, | |||
| attention_dropout=attention_dropout, | |||
| hidden_dropout=hidden_dropout, | |||
| has_attention_mask=has_attention_mask, | |||
| compute_type=compute_type) | |||
| layers.append(layer) | |||
| self.layers = nn.CellList(layers) | |||
| self.reshape = P.Reshape() | |||
| self.new_shape = (-1, d_model) | |||
| self.out_shape = (-1, seq_length, d_model) | |||
| def construct(self, input_tensor, attention_mask): | |||
| """ | |||
| Do Multi DecoderBlock. | |||
| Args: | |||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||
| shape with [batch_size * seq_length, d_model] | |||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||
| shape with [batch_size, seq_len, seq_len]. | |||
| Returns: | |||
| outputs (Tensor): the output of GPT2Transformer, shape with [batch_size, seq_len, d_model]. | |||
| """ | |||
| prev_output = self.reshape(input_tensor, self.new_shape) | |||
| for layer_module in self.layers: | |||
| layer_output = layer_module(prev_output, attention_mask) | |||
| prev_output = layer_output | |||
| output = self.reshape(prev_output, self.out_shape) | |||
| return output | |||
| class CreateAttentionMaskFromInputMask(nn.Cell): | |||
| """ | |||
| Create attention mask according to input mask. | |||
| Args: | |||
| config (Class): Configuration for GPT2Model. | |||
| """ | |||
| def __init__(self, config): | |||
| super(CreateAttentionMaskFromInputMask, self).__init__() | |||
| self.input_mask_from_dataset = config.input_mask_from_dataset | |||
| self.input_mask = None | |||
| self.compute_type = config.compute_type | |||
| assert self.input_mask_from_dataset | |||
| self.cast = P.Cast() | |||
| self.shape = P.Shape() | |||
| self.reshape = P.Reshape() | |||
| self.matmul = P.BatchMatMul() | |||
| self.multiply = P.Mul() | |||
| # mask future positions | |||
| ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length)) | |||
| self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32) | |||
| def construct(self, input_mask, mask_future=True): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_mask (Tensor): Tensor mask vectors with shape [batch_size, seq_len]. | |||
| mask_future (bool): Whether mask future (for decoder training). Default: True. | |||
| Returns: | |||
| attention_mask (Tensor): shape [batch_size, seq_len, seq_len]. | |||
| """ | |||
| input_shape = self.shape(input_mask) | |||
| shape_right = (input_shape[0], 1, input_shape[1]) # [batch_size, 1, seq_len] | |||
| shape_left = input_shape + (1,) # [batch_size, seq_len, 1] | |||
| input_mask = self.cast(input_mask, mstype.float32) | |||
| mask_left = self.reshape(input_mask, shape_left) | |||
| mask_right = self.reshape(input_mask, shape_right) | |||
| # precision transition fp32 -> fp16 | |||
| mask_left = self.cast(mask_left, self.compute_type) | |||
| mask_right = self.cast(mask_right, self.compute_type) | |||
| attention_mask = self.matmul(mask_left, mask_right) # [batch_size, seq_len, seq_len] | |||
| attention_mask = self.cast(attention_mask, mstype.float32) | |||
| if mask_future: | |||
| attention_mask = self.multiply(attention_mask, self.lower_triangle_mask) | |||
| return attention_mask | |||
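| # What the cell above computes, as a NumPy reference (illustrative): the | |||
| # outer product of the padding mask with itself, optionally multiplied by a | |||
| # lower-triangular matrix so position i only attends to positions <= i: | |||
| def _attention_mask_reference(input_mask, mask_future=True): | |||
|     """input_mask: [batch, seq] of 0/1 -> attention mask [batch, seq, seq].""" | |||
|     mask = input_mask[:, :, None] * input_mask[:, None, :] | |||
|     if mask_future: | |||
|         mask = mask * np.tril(np.ones(mask.shape[1:])) | |||
|     return mask | |||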
| class GPT2Model(nn.Cell): | |||
| """ | |||
| GPT-2 decoder-only Transformer model. | |||
| Args: | |||
| config (Class): Configuration for GPT2Model. | |||
| is_training (bool): True for training mode. False for eval mode. | |||
| use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. | |||
| """ | |||
| def __init__(self, | |||
| config, | |||
| is_training, | |||
| use_one_hot_embeddings=False | |||
| ): | |||
| super(GPT2Model, self).__init__() | |||
| self.config = copy.deepcopy(config) | |||
| self.is_training = is_training | |||
| if not is_training: | |||
| self.config.hidden_dropout = 0.0 | |||
| self.config.attention_dropout = 0.0 | |||
| self.input_mask_from_dataset = self.config.input_mask_from_dataset | |||
| self.batch_size = self.config.batch_size | |||
| self.seq_length = self.config.seq_length | |||
| self.d_model = self.config.d_model | |||
| self.num_hidden_layers = self.config.num_hidden_layers | |||
| self.embedding_dim = self.config.d_model | |||
| self.last_idx = self.num_hidden_layers - 1 | |||
| self.gpt2_embedding_lookup = EmbeddingLookup( | |||
| vocab_size=self.config.vocab_size, | |||
| embedding_dim=self.embedding_dim, | |||
| use_one_hot_embeddings=use_one_hot_embeddings, | |||
| compute_type=self.config.compute_type | |||
| ) | |||
| self.gpt2_embedding_postprocess = EmbeddingPostprocessor( | |||
| embedding_dim=self.embedding_dim, | |||
| seq_length=self.seq_length, | |||
| max_position_embeddings=self.config.max_position_embeddings, | |||
| dropout_prob=self.config.hidden_dropout | |||
| ) | |||
| self.gpt2_decoder = GPT2Transformer( | |||
| batch_size=self.batch_size, | |||
| d_model=self.d_model, | |||
| seq_length=self.seq_length, | |||
| num_hidden_layers=self.num_hidden_layers, | |||
| num_attention_heads=self.config.num_attention_heads, | |||
| intermediate_size=self.config.intermediate_size, | |||
| has_attention_mask=True, | |||
| attention_dropout=self.config.attention_dropout, | |||
| hidden_dropout=self.config.hidden_dropout, | |||
| compute_type=self.config.compute_type | |||
| ) | |||
| self.cast_compute_type = CastWrapper(dst_type=self.config.compute_type) | |||
| self.layer_norm = LayerNorm(in_channels=self.d_model) | |||
| self.dropout = nn.Dropout(1 - self.config.hidden_dropout) | |||
| self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.config) | |||
| self.reshape = P.Reshape() | |||
| self.new_shape = (-1, self.d_model) | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| Construct network. | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||
| where 0 indicates padding position. | |||
| Returns: | |||
| decoder_output (Tensor): shape [batch_size, seq_len, d_model]. | |||
| embedding_tables (Tensor): word embeddings with shape [vocab_size, d_model] | |||
| """ | |||
| # Embedding | |||
| word_embeddings, embedding_tables = self.gpt2_embedding_lookup(input_ids) | |||
| embedding_output = self.gpt2_embedding_postprocess(word_embeddings) | |||
| embedding_output = self.dropout(embedding_output) | |||
| # Attention mask with shape [batch_size, seq_len, seq_len] | |||
| attention_mask = self._create_attention_mask_from_input_mask(input_mask, True) | |||
| # GPT2 decoder | |||
| decoder_output = self.gpt2_decoder( | |||
| self.cast_compute_type(embedding_output), | |||
| self.cast_compute_type(attention_mask) | |||
| ) | |||
| # LayerNorm | |||
| decoder_output = self.reshape(decoder_output, self.new_shape) | |||
| decoder_output = self.layer_norm(decoder_output) | |||
| decoder_output = self.reshape(decoder_output, (-1, self.seq_length, self.d_model)) | |||
| return decoder_output, embedding_tables | |||
| def get_token_embeddings(self): | |||
| return self.gpt2_embedding_lookup.embedding_table.asnumpy() | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """clip gradient""" | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| GRADIENT_CLIP_TYPE = 1 | |||
| GRADIENT_CLIP_VALUE = 1.0 | |||
| clip_grad = C.MultitypeFuncGraph("clip_grad") | |||
| # pylint: disable=consider-using-in | |||
| @clip_grad.register("Number", "Number", "Tensor") | |||
| def _clip_grad(clip_type, clip_value, grad): | |||
| """ | |||
| Clip gradients. | |||
| Inputs: | |||
| clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. | |||
| clip_value (float): Specifies how much to clip. | |||
| grad (tuple[Tensor]): Gradients. | |||
| Outputs: | |||
| tuple[Tensor], clipped gradients. | |||
| """ | |||
| if clip_type != 0 and clip_type != 1: | |||
| return grad | |||
| dt = F.dtype(grad) | |||
| if clip_type == 0: | |||
| new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), | |||
| F.cast(F.tuple_to_array((clip_value,)), dt)) | |||
| else: | |||
| new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) | |||
| return new_grad | |||
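| # Typical usage inside a training wrapper cell (a sketch of the common | |||
| # MindSpore pattern; `grads` is the gradient tuple from C.GradOperation): | |||
| #     hyper_map = C.HyperMap() | |||
| #     grads = hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||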
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Data operations""" | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| from .finetune_eval_config import gpt2_net_cfg | |||
| def create_language_model_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""): | |||
| """create dataset like language model task""" | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| ds = de.MindDataset(dataset_path, | |||
| columns_list=["input_ids", "input_mask", "label_ids"], | |||
| shuffle=do_shuffle, | |||
| num_shards=device_num, | |||
| shard_id=rank_id) | |||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||
| ds = ds.map(operations=type_cast_op, input_columns="label_ids") | |||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||
| print("output shape: {}".format(ds.output_shapes())) | |||
| print("output type: {}".format(ds.output_types())) | |||
| print("============== create dataset successful ===============") | |||
| return ds | |||
| def create_cbt_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=False, dataset_path=""): | |||
| """create dataset for cbt task""" | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| ds = de.MindDataset(dataset_path, | |||
| columns_list=["input_ids", "input_mask", "input_length", "mc_labels"], | |||
| shuffle=do_shuffle, | |||
| num_shards=device_num, | |||
| shard_id=rank_id) | |||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_length") | |||
| ds = ds.map(operations=type_cast_op, input_columns="mc_labels") | |||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||
| print("output shape: {}".format(ds.output_shapes())) | |||
| print("output type: {}".format(ds.output_types())) | |||
| print("============== create CBT LM dataset successful ===============") | |||
| return ds | |||
| def create_lambada_control_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""): | |||
| """create dataset for lambada task""" | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| ds = de.MindDataset(dataset_path, | |||
| columns_list=["input_ids", "input_mask", "input_length"], | |||
| shuffle=do_shuffle, | |||
| num_shards=device_num, | |||
| shard_id=rank_id) | |||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||
| ds = ds.map(operations=type_cast_op, input_columns="input_length") | |||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||
| ds = ds.repeat(repeat_count) | |||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||
| print("output shape: {}".format(ds.output_shapes())) | |||
| print("output type: {}".format(ds.output_types())) | |||
| print("============== create dataset successful ===============") | |||
| return ds | |||
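| # A minimal usage sketch (illustrative; the MindRecord path is hypothetical): | |||
| def _example_read_one_batch(mindrecord_path): | |||
|     """Build the LM dataset and fetch one batch; every column is int32 of | |||
|     shape [batch_size, seq_length].""" | |||
|     ds = create_language_model_dataset(dataset_path=mindrecord_path) | |||
|     for item in ds.create_dict_iterator(): | |||
|         return item["input_ids"], item["input_mask"], item["label_ids"] | |||
|     return None | |||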
| @@ -0,0 +1,104 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """GPT-2 finetune config and GPT-2 model config""" | |||
| from easydict import EasyDict as edict | |||
| import mindspore.common.dtype as mstype | |||
| from .GPT2_model import GPT2Config | |||
| cfg = edict({ | |||
| 'gpt2_network': 'large', | |||
| 'optimizer': 'Lamb', | |||
| 'AdamWeightDecay': edict({ | |||
| 'learning_rate': 1e-5, | |||
| 'end_learning_rate': 1e-7, | |||
| 'power': 1.0, | |||
| 'weight_decay': 0.01, | |||
| 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), | |||
| 'eps': 1e-6, | |||
| }), | |||
| 'Lamb': edict({ | |||
| 'learning_rate': 1e-5, | |||
| 'end_learning_rate': 1e-7, | |||
| 'power': 1.0, | |||
| 'weight_decay': 0.01, | |||
| 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), | |||
| }), | |||
| 'Momentum': edict({ | |||
| 'learning_rate': 2e-5, | |||
| 'momentum': 0.9, | |||
| }), | |||
| }) | |||
| """ | |||
| Three GPT-2 model configurations: small, medium and large. | |||
| """ | |||
| if cfg.gpt2_network == 'small': | |||
| gpt2_net_cfg = GPT2Config( | |||
| batch_size=8, | |||
| seq_length=1024, | |||
| vocab_size=50257, | |||
| d_model=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act="gelu", | |||
| hidden_dropout=0.1, | |||
| attention_dropout=0.1, | |||
| max_position_embeddings=1024, | |||
| initializer_range=0.02, | |||
| input_mask_from_dataset=True, | |||
| summary_first_dropout=0.1, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| ) | |||
| elif cfg.gpt2_network == 'medium': | |||
| gpt2_net_cfg = GPT2Config( | |||
| batch_size=8, | |||
| seq_length=1024, | |||
| vocab_size=50257, | |||
| d_model=1024, | |||
| num_hidden_layers=24, | |||
| num_attention_heads=16, | |||
| intermediate_size=4096, | |||
| hidden_act="gelu", | |||
| hidden_dropout=0.1, | |||
| attention_dropout=0.1, | |||
| max_position_embeddings=1024, | |||
| initializer_range=0.02, | |||
| input_mask_from_dataset=True, | |||
| summary_first_dropout=0.1, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| ) | |||
| elif cfg.gpt2_network == 'large': | |||
| gpt2_net_cfg = GPT2Config( | |||
| batch_size=6, | |||
| seq_length=1024, | |||
| vocab_size=50257, | |||
| d_model=1280, | |||
| num_hidden_layers=36, | |||
| num_attention_heads=20, | |||
| intermediate_size=5120, | |||
| hidden_act="gelu", | |||
| hidden_dropout=0.1, | |||
| attention_dropout=0.1, | |||
| max_position_embeddings=1024, | |||
| initializer_range=0.02, | |||
| input_mask_from_dataset=True, | |||
| summary_first_dropout=0.1, | |||
| dtype=mstype.float32, | |||
| compute_type=mstype.float16, | |||
| ) | |||
| else: | |||
| raise ValueError("unsupported gpt2_network: {}, expect one of [small, medium, large]".format(cfg.gpt2_network)) | |||
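| # Illustrative sketch of how these configs are typically consumed (cfg is an | |||
| # EasyDict, so the optimizer sub-config can be selected by key): | |||
| def _demo_select_optimizer_cfg(): | |||
| """print the network size and the hyper-parameters of the chosen optimizer""" | |||
| opt_cfg = cfg[cfg.optimizer] | |||
| print("gpt2_network: {}, optimizer: {}, learning_rate: {}".format( | |||
| cfg.gpt2_network, cfg.optimizer, opt_cfg.learning_rate)) | |||
| return opt_cfg | |||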
| @@ -0,0 +1,464 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """GPT-2 finetune for downstream task""" | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| from mindspore import context | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import get_group_size | |||
| from .utils.CrossEntropy import CrossEntropyCalculationWithMask | |||
| from .clip_grad_utils import clip_grad | |||
| from .GPT2ForLanguageModel import GPT2LanguageModel | |||
| from .GPT2ForLambada import GPT2LambadaModel | |||
| from .GPT2ForCBT import GPT2CBTModel | |||
| from .GPT2ForTranslation import GPT2TranslationModel | |||
| from .GPT2ForReadComprehension import GPT2CoQAModel | |||
| from .GPT2ForSummarization import GPT2SummarizationModel | |||
| GRADIENT_CLIP_TYPE = 1 | |||
| GRADIENT_CLIP_VALUE = 1.0 | |||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
| reciprocal = P.Reciprocal() | |||
| @grad_scale.register("Tensor", "Tensor") | |||
| def tensor_grad_scale(scale, grad): | |||
| return grad * reciprocal(scale) | |||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||
| grad_overflow = P.FloatStatus() | |||
| @_grad_overflow.register("Tensor") | |||
| def _tensor_grad_overflow(grad): | |||
| return grad_overflow(grad) | |||
| class GPT2FinetuneCell(nn.Cell): | |||
| """ | |||
| Specifically defined for finetuning, where only three input tensors are needed. | |||
| Args: | |||
| network (Cell): The training network. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| scale_update_cell (Cell): Cell to do the loss scale. Default: None. | |||
| """ | |||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||
| super(GPT2FinetuneCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation(get_by_list=True, | |||
| sens_param=True) | |||
| self.reducer_flag = False | |||
| self.allreduce = P.AllReduce() | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| self.grad_reducer = None | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("gradients_mean") | |||
| degree = get_group_size() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||
| self.cast = P.Cast() | |||
| self.gpu_target = False | |||
| if context.get_context("device_target") == "GPU": | |||
| self.gpu_target = True | |||
| self.float_status = P.FloatStatus() | |||
| self.addn = P.AddN() | |||
| self.reshape = P.Reshape() | |||
| else: | |||
| self.alloc_status = P.NPUAllocFloatStatus() | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||
| self.base = Tensor(1, mstype.float32) | |||
| self.less_equal = P.LessEqual() | |||
| self.hyper_map = C.HyperMap() | |||
| self.loss_scale = None | |||
| self.loss_scaling_manager = scale_update_cell | |||
| if scale_update_cell: | |||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||
| name="loss_scale") | |||
| def construct(self, | |||
| input_ids, | |||
| input_mask, | |||
| label_ids, | |||
| sens=None): | |||
| """ | |||
| GPT2 Finetune. | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. | |||
| """ | |||
| weights = self.weights | |||
| init = False | |||
| loss = self.network(input_ids, | |||
| input_mask, | |||
| label_ids) | |||
| if sens is None: | |||
| scaling_sens = self.loss_scale | |||
| else: | |||
| scaling_sens = sens | |||
| if not self.gpu_target: | |||
| init = self.alloc_status() | |||
| clear_before_grad = self.clear_before_grad(init) | |||
| F.control_depend(loss, init) | |||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| label_ids, | |||
| self.cast(scaling_sens, | |||
| mstype.float32)) | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| if self.reducer_flag: | |||
| grads = self.grad_reducer(grads) | |||
| if not self.gpu_target: | |||
| flag = self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| F.control_depend(grads, flag) | |||
| F.control_depend(flag, flag_sum) | |||
| else: | |||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||
| flag_sum = self.addn(flag_sum) | |||
| flag_sum = self.reshape(flag_sum, (())) | |||
| if self.is_distributed: | |||
| flag_reduce = self.allreduce(flag_sum) | |||
| cond = self.less_equal(self.base, flag_reduce) | |||
| else: | |||
| cond = self.less_equal(self.base, flag_sum) | |||
| overflow = cond | |||
| if sens is None: | |||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||
| if overflow: | |||
| succ = False | |||
| else: | |||
| succ = self.optimizer(grads) | |||
| ret = (loss, cond) | |||
| return F.depend(ret, succ) | |||
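| # A sketch of wiring GPT2FinetuneCell together (GPT2LM and gpt2_net_cfg come | |||
| # from this repository; the Momentum hyper-parameters below mirror cfg.Momentum | |||
| # in finetune_eval_config.py and the dynamic loss-scale defaults are assumptions): | |||
| def _demo_build_finetune_cell(): | |||
| """wrap a language-modeling network with an optimizer and loss scaling""" | |||
| from mindspore.nn import Momentum | |||
| from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||
| from .finetune_eval_config import gpt2_net_cfg | |||
| network = GPT2LM(config=gpt2_net_cfg, is_training=True) | |||
| optimizer = Momentum(network.trainable_params(), learning_rate=2e-5, momentum=0.9) | |||
| update_cell = DynamicLossScaleManager().get_update_cell() | |||
| return GPT2FinetuneCell(network, optimizer, scale_update_cell=update_cell) | |||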
| class GPT2LM(nn.Cell): | |||
| """ | |||
| Train interface for Language Modeling finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||
| super(GPT2LM, self).__init__() | |||
| self.gpt2 = GPT2LanguageModel(config, is_training, use_one_hot_embeddings) | |||
| self.num_labels = config.vocab_size | |||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||
| num_labels=self.num_labels, | |||
| config=config) | |||
| self.is_training = is_training | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.cast = P.Cast() | |||
| def construct(self, input_ids, input_mask, label_ids): | |||
| """ | |||
| construct function for Language Modeling | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. | |||
| Returns: | |||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||
| otherwise, return the computed loss. | |||
| """ | |||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||
| if self.is_training: | |||
| shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||
| label_ids = label_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| return lm_logits | |||
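| # Note on the shift above: the logits at position t predict token t + 1, so | |||
| # logits for positions [0, n-2] are scored against labels at positions | |||
| # [1, n-1]; e.g. with input ids [a, b, c, d], the logits of [a, b, c] are | |||
| # compared with labels [b, c, d]. | |||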
| class GPT2Lambada(nn.Cell): | |||
| """ | |||
| Train interface for Lambada finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| super(GPT2Lambada, self).__init__() | |||
| self.gpt2 = GPT2LambadaModel(config, is_training, use_one_hot_embeddings) | |||
| self.num_labels = config.vocab_size | |||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||
| num_labels=self.num_labels, | |||
| config=config) | |||
| self.is_training = is_training | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.cast = P.Cast() | |||
| def construct(self, input_ids, input_mask, label_ids=None): | |||
| """ | |||
| construct function for Lambada task | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. Default: None. | |||
| Returns: | |||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||
| otherwise, return the computed loss. | |||
| """ | |||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||
| if self.is_training: | |||
| shift_logits = lm_logits[:, :-1, :] # [batch_size, seq_length - 1, vocab_size] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||
| label_ids = label_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| return lm_logits | |||
| class GPT2CBT(nn.Cell): | |||
| """ | |||
| Train interface for Children's Book Test finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||
| super(GPT2CBT, self).__init__() | |||
| self.gpt2 = GPT2CBTModel(config, is_training, use_one_hot_embeddings) | |||
| self.num_labels = config.vocab_size | |||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||
| num_labels=self.num_labels, | |||
| config=config) | |||
| self.is_training = is_training | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.cast = P.Cast() | |||
| def construct(self, input_ids, input_mask): | |||
| """ | |||
| construct function for CBT task | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| Returns: | |||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||
| otherwise, return the computed loss. | |||
| """ | |||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||
| if self.is_training: | |||
| shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||
| label_ids = input_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| return lm_logits | |||
| class GPT2Translation(nn.Cell): | |||
| """ | |||
| Train interface for Translation finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| super(GPT2Translation, self).__init__() | |||
| self.gpt2 = GPT2TranslationModel(config, is_training, use_one_hot_embeddings) | |||
| self.num_labels = config.vocab_size | |||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||
| num_labels=self.num_labels, | |||
| config=config) | |||
| self.is_training = is_training | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| def construct(self, input_ids, input_mask, label_ids): | |||
| """ | |||
| construct function for Translation task | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. | |||
| Returns: | |||
| translation_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||
| otherwise, return the computed loss. | |||
| """ | |||
| translation_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||
| translation_logits = self.log_softmax(translation_logits) | |||
| if self.is_training: | |||
| shift_logits = translation_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||
| label_ids = label_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| return translation_logits | |||
| class GPT2Summarization(nn.Cell): | |||
| """ | |||
| Train interface for Summary finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||
| super(GPT2Summarization, self).__init__() | |||
| self.gpt2 = GPT2SummarizationModel(config, is_training, use_one_hot_embeddings) | |||
| self.is_training = is_training | |||
| self.last_idx = (-1,) | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.reshape = P.Reshape() | |||
| self.shape = P.Shape() | |||
| self.batch_size = config.batch_size | |||
| self.seq_length = config.seq_length | |||
| self.vocab_size = config.vocab_size | |||
| self.cast = P.Cast() | |||
| self.loss_function = CrossEntropyCalculationWithMask(num_labels=self.vocab_size, | |||
| is_training=self.is_training, | |||
| config=config) | |||
| def construct(self, input_ids, input_mask, label_ids): | |||
| """ | |||
| construct function for Summarization task | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. | |||
| Returns: | |||
| loss (mstype.float32): the computed loss. | |||
| """ | |||
| output = self.gpt2(input_ids, input_mask) | |||
| shift_logits = output[::, :-1, ::] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.vocab_size)) | |||
| shift_logits = self.log_softmax(shift_logits) | |||
| label_ids = label_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss_function(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| class GPT2CoQA(nn.Cell): | |||
| """ | |||
| Train interface for Reading Comprehension finetuning task. | |||
| Args: | |||
| config (class): the configuration of GPT-2 model. | |||
| is_training (bool): whether to train. | |||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||
| """ | |||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||
| super(GPT2CoQA, self).__init__() | |||
| self.gpt2 = GPT2CoQAModel(config, is_training, use_one_hot_embeddings) | |||
| self.num_labels = config.vocab_size | |||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||
| num_labels=self.num_labels, | |||
| config=config) | |||
| self.is_training = is_training | |||
| self.reshape = P.Reshape() | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| def construct(self, input_ids, input_mask, label_ids=None): | |||
| """ | |||
| construct function for reading comprehension task | |||
| Args: | |||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||
| label_ids (Tensor): the indices of label sequence tokens in the vocabulary. | |||
| Returns: | |||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||
| otherwise, return the computed loss. | |||
| """ | |||
| lm_logits = self.gpt2(input_ids, input_mask) | |||
| lm_logits = self.log_softmax(lm_logits) | |||
| if self.is_training: | |||
| shift_logits = lm_logits[::, :-1, ::] | |||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) | |||
| label_ids = label_ids[::, 1:] | |||
| input_mask = input_mask[::, 1:] | |||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||
| return loss | |||
| return lm_logits | |||
| @@ -0,0 +1,82 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Calculate Cross Entropy With Mask""" | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| import mindspore.nn as nn | |||
| class CrossEntropyCalculationWithMask(nn.Cell): | |||
| """ | |||
| Cross Entropy loss | |||
| """ | |||
| def __init__(self, is_training=None, num_labels=None, config=None): | |||
| super(CrossEntropyCalculationWithMask, self).__init__() | |||
| self.onehot = P.OneHot() | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.reduce_sum = P.ReduceSum() | |||
| self.reduce_mean = P.ReduceMean() | |||
| self.reshape = P.Reshape() | |||
| self.last_idx = (-1,) | |||
| self.neg = P.Neg() | |||
| self.cast = P.Cast() | |||
| self.is_training = is_training | |||
| self.num_labels = num_labels | |||
| if config is not None: | |||
| # for PPL calculation in evaluation | |||
| self.input_mask_length = Tensor(config.batch_size * (config.seq_length - 1), mstype.float32) | |||
| def construct(self, logits, label_ids, input_mask=None): | |||
| """ | |||
| Calculate loss | |||
| Args: | |||
| logits (Tensor): the probability distribution over vocabulary. | |||
| label_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||
| input_mask (Tensor): input sentences padding mask, where 0 indicates padding position. | |||
| Returns: | |||
| return_value (Tensor, mstype.float32): if is_training is False, directly return the logits, otherwise, | |||
| return the computed loss. | |||
| """ | |||
| # logits [batch * (seq_length-1), vocab_size] label_ids [batch, seq_length-1] | |||
| if self.is_training: | |||
| label_ids = self.reshape(label_ids, self.last_idx) # label_ids [batch * (seq_length-1)] | |||
| one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, | |||
| self.off_value) # [batch * (seq_length-1), vocab_size] | |||
| per_example_loss = self.neg( | |||
| self.reduce_sum(one_hot_labels * logits, self.last_idx)) # [batch * (seq_length-1)] | |||
| # for PPL calculation in evaluation | |||
| if input_mask is not None: | |||
| input_mask = self.cast(self.reshape(input_mask, self.last_idx), | |||
| mstype.float32) # [batch * (seq_length-1)] | |||
| valid_loss_sum = self.reduce_sum(input_mask * per_example_loss, ()) | |||
| valid_element_sum = self.reduce_sum(input_mask, ()) + self.cast(F.tuple_to_array((1e-5,)), | |||
| mstype.float32) | |||
| loss = valid_loss_sum / valid_element_sum | |||
| else: | |||
| loss = self.reduce_mean(per_example_loss, self.last_idx) # a number | |||
| return_value = self.cast(loss, mstype.float32) | |||
| else: | |||
| return_value = logits * 1.0 # [batch * (seq_length-1), vocab_size] | |||
| return return_value | |||
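| # A small NumPy cross-check of the masked negative log-likelihood computed in | |||
| # construct() above (purely illustrative; a toy vocabulary of size 3): | |||
| def _demo_masked_nll(): | |||
| import numpy as np | |||
| logits = np.log(np.array([[0.7, 0.2, 0.1], | |||
| [0.1, 0.8, 0.1]]))  # log-probabilities, shape [2, 3] | |||
| label_ids = np.array([0, 1]) | |||
| input_mask = np.array([1.0, 0.0])  # the second position is padding | |||
| per_example_loss = -logits[np.arange(2), label_ids] | |||
| loss = (input_mask * per_example_loss).sum() / (input_mask.sum() + 1e-5) | |||
| print(loss)  # ~0.3567 = -log(0.7) | |||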
| @@ -0,0 +1,488 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """data preprocess for downstream task""" | |||
| import re | |||
| import json | |||
| import random | |||
| def lambada_detokenizer(string): | |||
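| # undo tokenization artifacts (split quotes, contractions, "@-@" number | |||
| # separators, censored words) so the text matches what GPT-2's BPE expects | |||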
| string = re.sub(r"``", "-DQ-", string) | |||
| string = re.sub(r"`", "-SQ-", string) | |||
| string = re.sub(r"''", "-DQ-", string) | |||
| string = re.sub(r" '", "-SQ-", string) | |||
| string = re.sub("-DQ-", '"', string) | |||
| string = re.sub("-SQ-", "'", string) | |||
| string = re.sub(r"([,?!.]['\"])(\w)", "\g<1> \g<2>", string) | |||
| # contractions | |||
| string = string.replace("s '", "s'") | |||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||
| # number separators | |||
| string = string.replace(" @-@ ", "-") | |||
| string = string.replace(" @,@ ", ",") | |||
| string = string.replace(" @.@ ", ".") | |||
| # miscellaneous | |||
| string = string.replace("= = = =", "====") | |||
| string = string.replace("= = =", "===") | |||
| string = string.replace("= =", "==") | |||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||
| string = string.replace(" \n", "\n") | |||
| string = string.replace("\n ", "\n") | |||
| string = string.replace(" N ", " 1 ") | |||
| string = string.replace(" 's", "'s") | |||
| string = string.replace(" 'd", "'d") | |||
| string = string.replace(" '", "'") | |||
| string = string.replace(" n't", "n't") | |||
| string = string.replace(" .", ".") | |||
| string = string.replace(" ,", ",") | |||
| string = string.replace(" !", "!") | |||
| string = string.replace(" ?", "?") | |||
| string = string.replace(" :", ":") | |||
| string = string.replace(" ;", ";") | |||
| string = string.replace(" : ", ": ") | |||
| string = string.replace(" ; ", "; ") | |||
| string = string.replace(" ,'", ",'") | |||
| string = string.replace(" .'", ".'") | |||
| string = string.replace(" !'", "!'") | |||
| string = string.replace(" ?'", "?'") | |||
| string = string.replace("~", "") | |||
| string = string.replace("---", "") | |||
| string = string.replace("<", "") | |||
| string = string.replace(">", "") | |||
| string = string.replace("#", "") | |||
| string = string.replace(', "', ',"') | |||
| string = string.replace('. "', '."') | |||
| string = string.replace('! "', '!"') | |||
| string = string.replace('? "', '?"') | |||
| string = string.replace('"" ', '" "') | |||
| string = string.replace('• • •', '') | |||
| # sensitive word process | |||
| string = string.replace("f ** k", "fuck") | |||
| string = string.replace("f ** king", "fucking") | |||
| string = string.replace("f ** ked", "fucked") | |||
| string = string.replace("c ** k", "cock") | |||
| string = string.replace("br ** sts", "breasts") | |||
| string = string.replace("n ** ples", "nipples") | |||
| string = string.replace("ni ** les", "nipples") | |||
| string = string.replace("a ** hole", "asshole") | |||
| string = string.replace("ass ** le", "asshole") | |||
| string = string.replace("p ** sy", "pussy") | |||
| string = string.replace("pu ** y", "pussy") | |||
| string = string.replace("na ** d", "naked") | |||
| string = string.replace("nak * d", "naked") | |||
| string = string.replace("cli ** x", "climax") | |||
| string = string.replace("h * ps", "hips") | |||
| string = string.replace("c * ck", "cock") | |||
| string = string.replace("coc ** ne", "cocaine") | |||
| string = string.replace("*", "") | |||
| string = re.sub(" "," ",string) | |||
| string = re.sub(" "," ",string) | |||
| string = re.sub(" "," ",string) | |||
| return string | |||
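| # Quick illustration of the detokenizer on a made-up tokenized sentence: | |||
| def _demo_lambada_detokenizer(): | |||
| sample = "he said , `` i ca n't do it . ''" | |||
| print(lambada_detokenizer(sample)) | |||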
| def lambada_dataset_preprocess(input_file, output_file): | |||
| sentences = [] | |||
| count = 0 | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| line = lambada_detokenizer(line) | |||
| split_sentence_list = line.split() | |||
| final_word = split_sentence_list[-1] | |||
| context = split_sentence_list[:-1] | |||
| new_sentence = ' '.join(context) + '\t' + ' ' + final_word | |||
| sentences.append(new_sentence) | |||
| count += 1 | |||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for sentence in sentences: | |||
| sentence = sentence.strip() | |||
| if sentence: | |||
| f.write(sentence) | |||
| f.write('\n') | |||
| count -= 1 | |||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||
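| # Each output line above has the form "context<TAB> final_word": the LAMBADA | |||
| # target word is separated from its context by a tab (followed by a space). | |||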
| def get_gold_answer_id(gold_answer, candidate_answer_list): | |||
| for id_, candidate in enumerate(candidate_answer_list): | |||
| if gold_answer == candidate: | |||
| return id_ | |||
| return None | |||
| def get_passage_string(passage_string, candidate_answer, final_sentence, gold_answer_id): | |||
| """ | |||
| concat each candidate answer to the passage and its final sentence | |||
| Args: | |||
| passage_string (str): the first 20 context sentences of a CBT sample. | |||
| candidate_answer (list): the candidate answers. | |||
| final_sentence (str): the 21st sentence string containing the "XXXXX" blank. | |||
| gold_answer_id (int): the id of the correct answer. | |||
| Returns: | |||
| candidate_passage (list): one completed passage per candidate answer, so len(candidate_passage) == len(candidate_answer). | |||
| """ | |||
| candidate_passage = [] | |||
| for answer in candidate_answer: | |||
| passage = passage_string + " " + final_sentence | |||
| passage = passage.replace(" XXXXX", "\t XXXXX") | |||
| final_passage = passage.replace("XXXXX", answer) | |||
| whole_passage = final_passage + "\t" + str(gold_answer_id) | |||
| candidate_passage.append(whole_passage) | |||
| return candidate_passage | |||
| def cbt_dataset_preprocess(input_file, output_file): | |||
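| # each CBT sample is a 21-line block: lines 1-20 are context sentences and | |||
| # line 21 holds the question (with an "XXXXX" blank), the gold answer and | |||
| # ten candidate answers separated by "|" | |||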
| passages = [] | |||
| passage_string = "" | |||
| count = 0 | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| single_sentence = line.split(' ', 1) | |||
| line_id = int(single_sentence[0]) | |||
| string = single_sentence[1] | |||
| if line_id == 21: | |||
| string = string.replace("\t\t", "\t") | |||
| mini_string = string.split("\t") | |||
| candidate_answer = mini_string[-1] | |||
| candidate_answer_list = candidate_answer.split("|") | |||
| gold_answer_id = get_gold_answer_id(mini_string[-2], candidate_answer_list) | |||
| candidate_passage = get_passage_string(passage_string, | |||
| candidate_answer_list, | |||
| mini_string[0], | |||
| gold_answer_id) | |||
| assert len(candidate_passage) == 10 | |||
| count += 10 | |||
| else: | |||
| passage_string = passage_string + " " + string | |||
| else: | |||
| passages.append(candidate_passage) | |||
| passage_string = "" | |||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for passage in passages: | |||
| for candidate_passage in passage: | |||
| candidate_passage = candidate_passage.replace(" \t ", "\t ") | |||
| candidate_passage = candidate_passage.strip() | |||
| f.write(candidate_passage) | |||
| f.write("\n") | |||
| count -= 1 | |||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||
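| # The resulting file therefore contains ten lines per question, one per | |||
| # candidate answer substituted into the blank, each suffixed with a tab | |||
| # plus the gold answer id. | |||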
| def wikitext_detokenizer(string): | |||
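| # reverse wikitext tokenization: reattach punctuation and quotes, restore | |||
| # "@-@"-style number separators and "====" section markers | |||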
| # contractions | |||
| string = string.replace("s '", "s'") | |||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||
| # number separators | |||
| string = string.replace(" @-@ ", "-") | |||
| string = string.replace(" @,@ ", ",") | |||
| string = string.replace(" @.@ ", ".") | |||
| # punctuation | |||
| string = string.replace(" : ", ": ") | |||
| string = string.replace(" ; ", "; ") | |||
| string = string.replace(" . ", ". ") | |||
| string = string.replace(" .", ".") | |||
| string = string.replace(" ! ", "! ") | |||
| string = string.replace(" ? ", "? ") | |||
| string = string.replace(" , ", ", ") | |||
| # double brackets | |||
| string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) | |||
| string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) | |||
| string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) | |||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||
| # miscellaneous | |||
| string = string.replace("= = = =", "====") | |||
| string = string.replace("= = =", "===") | |||
| string = string.replace("= =", "==") | |||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||
| string = string.replace(" \n", "\n") | |||
| string = string.replace("\n ", "\n") | |||
| string = string.replace(" N ", " 1 ") | |||
| string = string.replace(" 's", "'s") | |||
| return string | |||
| def wikitext_dataset_preprocess(input_file, output_file): | |||
| dataset_test = [] | |||
| passage = [] | |||
| count = 0 | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| if line.startswith('=') and line.endswith('=') and len(passage) != 0: | |||
| dataset_test.append(passage) | |||
| count += 1 | |||
| passage = [] | |||
| elif line.startswith('=') and line.endswith('='): | |||
| continue | |||
| else: | |||
| passage.append(line) | |||
| # after the loop: keep the final passage, which has no following heading | |||
| if passage: | |||
| dataset_test.append(passage) | |||
| count += 1 | |||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for line in dataset_test: | |||
| text = "" | |||
| for sentence in line: | |||
| sentence = wikitext_detokenizer(sentence) | |||
| text = text + " " + sentence | |||
| text = text.strip() | |||
| f.write(text) | |||
| f.write("\n") | |||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||
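| # Example invocation (file names are placeholders for the raw WikiText | |||
| # test split and the cleaned output): | |||
| #     wikitext_dataset_preprocess("wiki.test.tokens", "wikitext.test.txt") | |||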
| def ptb_detokenizer(string): | |||
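| # reverse PTB tokenization: reattach contractions and punctuation and strip | |||
| # PTB-specific escape tokens | |||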
| string = string.replace(" '", "'") | |||
| string = string.replace(" \n", "\n") | |||
| string = string.replace("\n ", "\n") | |||
| string = string.replace(" n't", "n't") | |||
| string = string.replace(" N ", "1 ") | |||
| string = string.replace("$ 1", "$1") | |||
| string = string.replace("# 1", "#1") | |||
| string = string.replace("\/abc", "") | |||
| string = string.replace("\/ua", "") | |||
| string = string.replace("s '", "s'") | |||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||
| # punctuation | |||
| string = string.replace(" : ", ": ") | |||
| string = string.replace(" ; ", "; ") | |||
| string = string.replace(" . ", ". ") | |||
| string = string.replace(" ! ", "! ") | |||
| string = string.replace(" ? ", "? ") | |||
| string = string.replace(" , ", ", ") | |||
| string = string.replace(" 's", "'s") | |||
| return string | |||
| def ptb_dataset_preprocess(input_file, output_file): | |||
| sentences = [] | |||
| count = 0 | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| line = ptb_detokenizer(line) | |||
| sentences.append(line) | |||
| count += 1 | |||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for sentence in sentences: | |||
| sentence = sentence.strip() | |||
| if sentence: | |||
| f.write(sentence) | |||
| f.write("\n") | |||
| count -= 1 | |||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||
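| # Example invocation (file names are placeholders for the raw PTB test split): | |||
| #     ptb_dataset_preprocess("ptb.test.txt", "ptb.test.clean.txt") | |||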
| def onebw_detokenizer(string): | |||
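| # reverse 1BW (One Billion Word) tokenization, similar to the wikitext rules | |||
| # plus a few corpus-specific fixes | |||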
| # contractions | |||
| string = string.replace("s '", "s'") | |||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||
| # number separators | |||
| string = string.replace(" @-@ ", "-") | |||
| string = string.replace(" @,@ ", ",") | |||
| string = string.replace(" @.@ ", ".") | |||
| # punctuation | |||
| string = string.replace(" : ", ": ") | |||
| string = string.replace(" ; ", "; ") | |||
| string = string.replace(" . ", ". ") | |||
| string = string.replace(" ! ", "! ") | |||
| string = string.replace(" ? ", "? ") | |||
| string = string.replace(" , ", ", ") | |||
| # double brackets | |||
| string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) | |||
| string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) | |||
| string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) | |||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||
| # miscellaneous | |||
| string = string.replace("= = = =", "====") | |||
| string = string.replace("= = =", "===") | |||
| string = string.replace("= =", "==") | |||
| string = string.replace(" --", "") | |||
| string = string.replace("--", "") | |||
| string = string.replace("? ? ?", " ?") | |||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||
| string = string.replace(" \n", "\n") | |||
| string = string.replace("\n ", "\n") | |||
| string = string.replace(" 't", "'t") | |||
| string = string.replace(" N ", " 1 ") | |||
| string = string.replace(" 's", "'s") | |||
| string = string.replace(" '", "'") | |||
| string = string.replace(" n't", "n't") | |||
| string = string.replace("$ 1", "$1") | |||
| string = string.replace("# 1", "#1") | |||
| return string | |||
| def test_length(string): | |||
| string_list = string.split() | |||
| return len(string_list) | |||
| def onebw_dataset_preprocess(condition, input_file, output_file): | |||
| sentences = [] | |||
| count = 0 | |||
| if condition.lower() == "test": | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| sentences.append(line) | |||
| count += 1 | |||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for sentence in sentences: | |||
| sentence = sentence.strip() | |||
| if sentence: | |||
| sentence = onebw_detokenizer(sentence) | |||
| f.write(sentence) | |||
| f.write("\n") | |||
| count -= 1 | |||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||
| elif condition.lower() == "train": | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| line = onebw_detokenizer(line) | |||
| length = test_length(line) | |||
| if 10 < length < 60: | |||
| sentences.append(line) | |||
| count += 1 | |||
| print('read finished! count = ', count) | |||
| sample_result_list = random.sample(range(0, count), 30000) | |||
| sample_result_list.sort() | |||
| count_sample = 0 | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for sample_idx in sample_result_list: | |||
| chosen_sentence = sentences[sample_idx] | |||
| f.write(chosen_sentence) | |||
| f.write("\n") | |||
| count_sample += 1 | |||
| print('write finished! ', count_sample) | |||
| else: | |||
| raise ValueError("condition error support: [train, test]") | |||
| def coqa_dataset_preprocess(input_file, output_file): | |||
| with open(input_file, 'r', encoding='utf-8') as f: | |||
| source_data = json.load(f) | |||
| stories = [] | |||
| instances = [] | |||
| end_sep = [',', '.', ';'] | |||
| question_before_sep = " " | |||
| question_after_sep = " A: " | |||
| answer_sep = " A:\t" | |||
| for dialog in source_data["data"]: | |||
| story = dialog["story"].replace("\n", "") | |||
| stories.append(story) | |||
| concat_ = "" | |||
| concat_ += story | |||
| for question, answer in zip(dialog["questions"], dialog["answers"]): | |||
| question = question["input_text"] | |||
| answer = answer["input_text"] | |||
| concat_ += question_before_sep | |||
| concat_ += question | |||
| tmp = concat_ + question_after_sep | |||
| concat_ += answer_sep | |||
| concat_ += answer | |||
| instances.append(concat_) | |||
| concat_ = tmp + answer | |||
| if concat_[-1] not in end_sep: | |||
| concat_ += "." | |||
| instances.append("") | |||
| with open(output_file, 'w', encoding='utf-8') as f: | |||
| for i in range(len(instances)): | |||
| if instances[i]: | |||
| f.write(instances[i]) | |||
| f.write("\n") | |||
| print('write {} file finished!\n total count = {}'.format(output_file, len(instances))) | |||
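| # Each written line is the dialog history so far plus the current question, | |||
| # followed by " A:" and a tab before the gold answer, so the tab marks the | |||
| # prompt/answer boundary for finetuning. | |||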
| def wmt14_en_fr_preprocess(input_file, output_file): | |||
| input_file = input_file + "/newstest2014-fren-ref" | |||
| output_file = output_file + "/wmt14" | |||
| language = ['.en.sgm', '.fr.sgm'] | |||
| count = 0 | |||
| # en-fr | |||
| with open(input_file + language[0], "r", encoding='utf-8') as english, \ | |||
| open(input_file + language[1], "r", encoding='utf-8') as french, \ | |||
| open(output_file + '.en_fr.txt', "a", encoding='utf-8') as enfr_f, \ | |||
| open(output_file + '.fr_en.txt', "a", encoding='utf-8') as fren_f: | |||
| line_id = 0 | |||
| for en, fr in zip(english, french): | |||
| line_id += 1 | |||
| if en.startswith('<seg id'): | |||
| print("=" * 20, "\n", line_id, "\n", "=" * 20) | |||
| en_start = en.find('>', 0) | |||
| en_end = en.find('</seg>', 0) | |||
| print(en[en_start + 1:en_end]) | |||
| en_ = en[en_start + 1:en_end] | |||
| fr_start = fr.find('>', 0) | |||
| fr_end = fr.find('</seg>', 0) | |||
| print(fr[fr_start + 1:fr_end]) | |||
| fr_ = fr[fr_start + 1:fr_end] | |||
| en_fr_str = en_ + "\t" + fr_ + "\n" | |||
| enfr_f.write(en_fr_str) | |||
| fr_en_str = fr_ + "\t" + en_ + "\n" | |||
| fren_f.write(fr_en_str) | |||
| count += 1 | |||
| print('write {} file finished!\n total count = {}'.format(output_file + '.en_fr.txt', count)) | |||
| print('write {} file finished!\n total count = {}'.format(output_file + '.fr_en.txt', count)) | |||
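| # Example invocation (directory paths are placeholders); it reads | |||
| # newstest2014-fren-ref.en.sgm / .fr.sgm under input_file and writes | |||
| # tab-separated pairs to wmt14.en_fr.txt and wmt14.fr_en.txt under output_file: | |||
| #     wmt14_en_fr_preprocess("./wmt14/raw", "./wmt14/clean") | |||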
| @@ -0,0 +1,542 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| generation utils | |||
| """ | |||
| import numpy as np | |||
| from scipy.special import softmax | |||
| from mindspore.ops import operations as P | |||
| from mindspore import dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| from .tensor_manipulations import extract_single_token_logits, add_last_token | |||
| INF = 1. * 1e9 | |||
| class TopKTopP_Filter(): | |||
| """ | |||
| Top K sampling along with Top P sampling(Nucleus Sampling) | |||
| Keep the ids with top-K probabilities and those within the top-P cumulative probability as the candidate set. | |||
| Use np.random.multinomial to sample | |||
| Args: | |||
| batch_size (int): batch size of input dataset. | |||
| vocab_size (int): the shape of each embedding vector. | |||
| k (int): parameter for Top-K sampling, k should be in range [0, vocab_size]. | |||
| 0 disables Top-K filtering. Default: 0. | |||
| p (float) [Optional]: parameter for Top-P sampling, a.k.a. Nucleus Sampling, in range [0.0, 1.0]. | |||
| Default: 1.0. | |||
| temperature (float) [Optional]: generation temperature; larger values make generation more diverse. Default: 1.0. | |||
| """ | |||
| def __init__(self, | |||
| batch_size=None, | |||
| vocab_size=None, | |||
| k=0, | |||
| p=1.0, | |||
| temperature=1.0, | |||
| min_tokens_to_keep=1, | |||
| ): | |||
| self.k = k | |||
| self.p = p | |||
| self.temp = temperature | |||
| self.batch_size = batch_size | |||
| self.vocab_size = vocab_size | |||
| self.min_tokens_to_keep = min_tokens_to_keep | |||
| assert self.temp > 0.0, 'temperature must be positive' | |||
| assert self.k >= 0, 'the top_k number must be non-negative.' | |||
| if self.k > 0: | |||
| assert self.min_tokens_to_keep <= self.k, 'k must be larger than or equal to min_tokens_to_keep ' \ | |||
| 'for Top-p sampling' | |||
| if self.k == 0: | |||
| self.k = self.vocab_size | |||
| self.safety_mask = np.concatenate((np.ones((self.batch_size, self.min_tokens_to_keep)), | |||
| np.zeros((self.batch_size, self.k - self.min_tokens_to_keep))), | |||
| axis=1).astype(bool) | |||
| def calculate(self, distribution): | |||
| """ | |||
| run the sampling procedure with the settings initialized above and return a list of sampled ids. | |||
| Args: | |||
| distribution (numpy.ndarray): with shape (batch_size,vocab_size) | |||
| Returns: | |||
| sampled ids: a list, with length of batch_size | |||
| """ | |||
| if self.temp != 1.0: | |||
| distribution = distribution / float(self.temp) | |||
| distribution_sorted = -np.sort(-distribution, axis=1) | |||
| index_sorted = np.argsort(-distribution, axis=1) | |||
| topk_distribution = distribution_sorted[::, :self.k if self.k > 0 else self.vocab_size] | |||
| topk_indices = index_sorted[::, :self.k if self.k > 0 else self.vocab_size] | |||
| # safety check of probability | |||
| self.p = max(0.0, min(1.0, self.p)) | |||
| cum_sum = np.cumsum(softmax(topk_distribution, axis=1), axis=1) | |||
| bool_map = np.logical_or((cum_sum <= self.p), self.safety_mask).astype(np.float32) | |||
| topk_distribution = topk_distribution * bool_map + np.float32(-1e5) * (1.0 - bool_map) | |||
| topk_distribution = softmax(topk_distribution, axis=1) | |||
| # normalize for np.float64 | |||
| # choose np.float64 to avoid overflow in softmax operation | |||
| topk_distribution = topk_distribution.astype(np.float64) | |||
| for batch_idx in range(self.batch_size): | |||
| topk_distribution[batch_idx] = topk_distribution[batch_idx] / np.sum(topk_distribution[batch_idx]) | |||
| ret_ids = [] | |||
| for batch_idx in range(self.batch_size): | |||
| select_index = np.argmax(np.random.multinomial(1, topk_distribution[batch_idx])) | |||
| ret_ids.append(topk_indices[batch_idx][select_index]) | |||
| return ret_ids | |||
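| # Quick illustration: sample one id per batch row from random logits | |||
| # (shapes and the k/p values here are arbitrary): | |||
| def _demo_topk_topp_filter(): | |||
| logits = np.random.randn(2, 50)  # [batch_size=2, vocab_size=50] | |||
| filter_ = TopKTopP_Filter(batch_size=2, vocab_size=50, k=5, p=0.9) | |||
| print(filter_.calculate(logits))  # e.g. [13, 7] | |||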
| class Sample(): | |||
| """ | |||
| Initiate a Sample object for sampling next token(s) from previous text. | |||
| Args: | |||
| decoder (Model): GPT2 model to do generation. | |||
| config (GPT2Config): configuration of given GPT2 model. | |||
| tokenizer (GPT2Tokenizer): a tokenizer is required if input_str is passed to self.generate(). | |||
| generate_length (int): number of tokens which should be generated. Default: 1. | |||
| topk_num (int): number of k in Top-K sampling; 0 disables the Top-K filter, | |||
| equivalent to k = self.vocab_size. Default: 0. | |||
| topp_prob (float): probability parameter of Top-p sampling. | |||
| if p = 1.0, it equals to do nothing. (nucleus sampling). Default: 1.0 | |||
| temperature (float): temperature for Top-k sampling. Default: 1.0 | |||
| min_tokens_to_keep (int): guarantee that at least min_tokens_to_keep candidate token(s) are kept for sampling. Default: 1. | |||
| early_stop (bool): whether to stop when the model generates the <EOS> token. | |||
| It only takes effect when batch_size is 1. Default: False | |||
| demo_mode (bool): True if input_str is a str rather than a list of str. | |||
| self.batch_size should be 1 if it is True. Default: False | |||
| return_ids (bool): whether return ids generated from Sample. Default: False | |||
| return_last_token_logits (bool): whether return logits of last token for each time step during generation. | |||
| Default: False | |||
| append_eos (bool): whether append <EOS> token id to input_ids pass directly to GPT2Model class. Default: False | |||
| """ | |||
| def __init__(self, | |||
| decoder, | |||
| config=None, | |||
| batch_size=None, | |||
| tokenizer=None, | |||
| generate_length=1, | |||
| topk_num=0, | |||
| topp_prob=1.0, | |||
| temperature=1.0, | |||
| min_tokens_to_keep=1, | |||
| early_stop=False, | |||
| demo_mode=False, | |||
| return_ids=False, | |||
| return_last_token_logits=False, | |||
| append_eos=False, | |||
| ): | |||
| assert config is not None, 'Config is a must for sampling.' | |||
| self.decoder = decoder | |||
| self.config = config | |||
| self.tokenizer = tokenizer | |||
| self.generate_length = generate_length | |||
| self.topk_num = topk_num | |||
| self.topp_prob = topp_prob | |||
| self.temperature = temperature | |||
| self.min_tokens_to_keep = min_tokens_to_keep | |||
| self.early_stop = early_stop | |||
| self.demo_mode = demo_mode | |||
| self.return_ids = return_ids | |||
| self.return_last_token_logits = return_last_token_logits | |||
| self.append_eos = append_eos | |||
| self.seq_length = config.seq_length | |||
| self.batch_size = config.batch_size if batch_size is None else batch_size | |||
| self.vocab_size = config.vocab_size | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.reshape = P.Reshape() | |||
| self.cumsum = P.CumSum() | |||
| self.onehot = P.OneHot() | |||
| self.cast = P.Cast() | |||
| self.concat = P.Concat() | |||
| self.sample_function = P.RandomCategorical(mstype.int32) | |||
| self.filter_distribution = TopKTopP_Filter(batch_size=self.batch_size, | |||
| vocab_size=self.vocab_size, | |||
| k=self.topk_num, | |||
| p=self.topp_prob, | |||
| temperature=self.temperature, | |||
| min_tokens_to_keep=self.min_tokens_to_keep) | |||
| if self.tokenizer is not None: | |||
| self.eos_id = self.tokenizer.eos_token_id | |||
| else: | |||
| self.eos_id = config.vocab_size - 1 | |||
| if self.tokenizer is not None: | |||
| self.eos_text = self.tokenizer.eos_token | |||
| else: | |||
| self.eos_text = "<|endoftext|>" | |||
| if self.demo_mode is True: | |||
| assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format( | |||
| self.batch_size) | |||
| def _extract_string_from_tensor(self, input_ids, mode="pair"): | |||
| """ | |||
| Args: | |||
| input_ids(Tensor): input sentences with shape [self.batch_size, self.seq_len] | |||
| mode (str): ["pair", "single"] | |||
| "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`, | |||
| such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`, | |||
| reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`. | |||
| "single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task. | |||
| Returns: | |||
| source_list (list): the list of source_text or first part of text. | |||
| target_list (list): the list of target_text or second part of text. | |||
| If self.batch_size is 1 and self.demo_mode is True, the single string is returned instead of a list. | |||
| Example: | |||
| for pair mode, if self.demo_mode is True, it will return source_list[0], target_list[0] | |||
| """ | |||
| assert self.tokenizer is not None, 'There is no tokenizer' | |||
| source_list = [""] * self.batch_size | |||
| target_list = [""] * self.batch_size | |||
| eos_text = self.tokenizer.eos_token | |||
| len_eos_text = len(eos_text) | |||
| input_ids_np = input_ids.asnumpy() | |||
| input_ids_np = input_ids_np.reshape((self.batch_size, self.seq_length)) | |||
| if mode == "pair": | |||
| for batch_idx in range(self.batch_size): | |||
| sentence_tensor = input_ids_np[batch_idx] | |||
| sentence_list = sentence_tensor.tolist()[1:] | |||
| sentence = self.tokenizer.decode(sentence_list) | |||
| source_start = 0 | |||
| source_end = sentence.find(eos_text, 0) | |||
| target_start = source_end + len_eos_text | |||
| target_end = sentence[target_start:].find(eos_text, 0) + target_start | |||
| source_list[batch_idx] = sentence[source_start:source_end] | |||
| target_list[batch_idx] = sentence[target_start:target_end] | |||
| if self.batch_size == 1 and self.demo_mode is True: | |||
| return source_list[0], target_list[0] | |||
| return source_list, target_list | |||
| if mode == "single": | |||
| for batch_idx in range(self.batch_size): | |||
| sentence_tensor = input_ids_np[batch_idx] | |||
| sentence_list = sentence_tensor.tolist()[1:] | |||
| sentence = self.tokenizer.decode(sentence_list) | |||
| source_start = 0 | |||
| source_end = sentence.find(eos_text, 0) | |||
| source_list[batch_idx] = sentence[source_start:source_end] | |||
| if self.batch_size == 1 and self.demo_mode is True: | |||
| return source_list[0] | |||
| else: | |||
| raise ValueError('mode:{} not supported, only support [pair, single].'.format(mode)) | |||
| return source_list | |||
| def _tensorize_ids_with_masks(self, src_str): | |||
| """ | |||
| Transform from string to tensor | |||
| Args: | |||
| src_str: string or list of strings | |||
| Returns: | |||
| input_ids (Tensor): shape with [self.batch_size, self.seq_length] | |||
| input_mask (Tensor): shape with [self.batch_size, self.seq_length] | |||
| src_len_list (list): the token lengths of src_str after being encoded by self.tokenizer | |||
| """ | |||
| if isinstance(src_str, str): | |||
| src_str = [src_str] | |||
| single_sentence_shape = (1, self.seq_length) | |||
| src_len_list = list() | |||
| input_ids = None | |||
| input_mask = None | |||
| for batch_idx in range(self.batch_size): | |||
| src_ids_list = self.tokenizer.encode(src_str[batch_idx]) | |||
| src_ids_len = len(src_ids_list) | |||
| if src_ids_len > self.seq_length: | |||
| src_ids_list = src_ids_list[:self.seq_length] | |||
| src_ids_len = self.seq_length | |||
| src_len_list.append(src_ids_len) | |||
| return_dict = self.tokenizer.prepare_for_model(src_ids_list, | |||
| max_length=self.config.seq_length, | |||
| add_special_tokens=False) | |||
| input_ids_list = return_dict['input_ids'] | |||
| input_mask_list = return_dict['attention_mask'] | |||
| input_ids_np = np.array(input_ids_list, dtype=int) | |||
| input_mask_np = np.array(input_mask_list, dtype=int) | |||
| input_ids_np = input_ids_np.reshape(single_sentence_shape) | |||
| input_mask_np = input_mask_np.reshape(single_sentence_shape) | |||
| if batch_idx == 0: | |||
| input_ids_np_ = input_ids_np | |||
| input_mask_np_ = input_mask_np | |||
| else: | |||
| input_ids_np_ = np.concatenate((input_ids_np_, input_ids_np), axis=0) | |||
| input_mask_np_ = np.concatenate((input_mask_np_, input_mask_np), axis=0) | |||
| input_ids = Tensor(input_ids_np_, dtype=mstype.int32) | |||
| input_mask = Tensor(input_mask_np_, dtype=mstype.int32) | |||
| return input_ids, input_mask, src_len_list | |||
| class LastTokenPos(): | |||
| """ | |||
| class for recording input strings (or an input mask) and the positions of their last tokens | |||
| Args: | |||
| input_ (Union[list, Tensor]): list if input is a list containing strings, | |||
| Tensor with shape (batch_size, seq_length) representing input_mask. | |||
| """ | |||
| def __init__(self, input_, seq_length=1024): | |||
| if isinstance(input_, list): | |||
| self.input_strs = input_ | |||
| self.input_mask = None | |||
| else: | |||
| self.input_strs = None | |||
| self.input_mask = input_ | |||
| self.seq_length = seq_length | |||
| if self.input_strs is not None: | |||
| self.pos_list = [len(input_str) - 1 for input_str in self.input_strs] | |||
| else: | |||
| input_mask_ = P.Cast()(self.input_mask, mstype.float32) | |||
| temp_pos_list = P.ReduceSum(keep_dims=False)(input_mask_, 1).asnumpy().astype(np.int32).tolist() | |||
| # minimum value is always 0 for safety | |||
| self.pos_list = [max(0, pos - 1) for pos in temp_pos_list] | |||
| def get_pos(self, shift: int = 0): | |||
| # return last token if overflow | |||
| shift_list = [min(self.seq_length - 1, pos + shift) for pos in self.pos_list] | |||
| return shift_list | |||
| def _sample_from_distribution(self, distribution): | |||
| """ | |||
| sample one token per batch from self.sample_function(). | |||
Args:
| distribution (Tensor): the distribution or logits of the last token of different batches. | |||
| shape with [batch_size, vocab_size] | |||
| Return: | |||
| word_index (Tensor): shape with [batch_size, ] | |||
| """ | |||
| distribution = self.reshape(distribution, (self.vocab_size, self.batch_size)) | |||
| topk_distribution = distribution[:self.topk_num, ::] | |||
| topk_distribution = self.reshape(topk_distribution, (self.batch_size, -1)) | |||
| word_index = self.sample_function(P.Softmax()(topk_distribution), 1, 1) | |||
| word_index = self.reshape(word_index, (-1,)) | |||
| return word_index | |||
| def _demo_mode_check(self, input_str): | |||
| """ | |||
Type check for demo mode: batch size must be 1 and input_str must not be None; normalizes input_str to a single prompt string.
| """ | |||
| if self.batch_size == 1 and self.demo_mode is True: | |||
| assert input_str is not None, "demo mode should have input str" | |||
| # type check | |||
| if isinstance(input_str, list): | |||
| assert isinstance(input_str[0], str), "type of input_str is {}, " \ | |||
| "which should be str instead.".format(type(input_str[0])) | |||
| if len(input_str) != 1: | |||
| print("[WARNING] Sample.generate: length of input_str is larger than 1, " | |||
| "choose input_str[0] as input_str.") | |||
| input_str = input_str[0] | |||
assert isinstance(input_str, str), "type of input_str is {}, " \
"which should be str instead.".format(type(input_str))
| input_str = [input_str] | |||
| return input_str | |||
| def _input_check_and_normalize(self, input_str=None, input_ids=None, input_mask=None, generate_length=None): | |||
| """ | |||
| input check function | |||
| """ | |||
| if input_str is not None: | |||
| assert self.tokenizer is not None, 'if choose to give input_str, a tokenizer is necessary.' | |||
| input_str = self._demo_mode_check(input_str) | |||
| if input_ids is not None: | |||
assert input_mask is not None, 'if input_ids is given, input_mask is required as well.'
| if input_str is not None and input_ids is not None and input_mask is not None: | |||
| print('[WARNING] Sample.generate got input_str, input_ids and input_mask, ' | |||
| 'choose input_str as default for input') | |||
| if input_ids is None and input_mask is None: | |||
| input_ids, input_mask, _ = self._tensorize_ids_with_masks(input_str) | |||
| else: | |||
| if input_str is None: | |||
| if input_ids is not None: | |||
| input_str = self._extract_string_from_tensor(input_ids, mode="full") | |||
| if generate_length is not None: | |||
| # reload generate_length | |||
| generate_length = int(generate_length) | |||
| assert generate_length >= 0, 'generate_length can not be negative.' | |||
| else: | |||
| generate_length = self.generate_length | |||
| return input_str, input_ids, input_mask, generate_length | |||
| def generate(self, input_str=None, input_ids=None, input_mask=None, generate_length=None, do_sample=True): | |||
| """ | |||
Base function for text generation, given a batch-size list of strings (or a single string when demo mode is on).
Args:
| input_str (list(str) or str): prompt string. | |||
| generate_length: number of tokens to generate. | |||
| Returns: | |||
| generate_str: string generated by the GPT-2 model. | |||
| full_str: input_str appended with generate_str. | |||
| """ | |||
| input_str, input_ids, input_mask, generate_length = self._input_check_and_normalize(input_str, | |||
| input_ids, | |||
| input_mask, | |||
| generate_length) | |||
return_ids_list = [[] for _ in range(self.batch_size)]  # independent lists, not batch_size aliases of one list
| last_token = self.LastTokenPos(input_mask, seq_length=self.seq_length) | |||
| for i in range(generate_length): | |||
| last_token_pos_list = last_token.get_pos(shift=i) | |||
| early_stop_mask = [0] * self.batch_size | |||
| # unsorted logits (distribution) of next word | |||
| logits = self.decoder.predict(input_ids, input_mask) | |||
| if self.return_last_token_logits is True: | |||
| if i == 0: | |||
| # [batch_size, 1, vocab_size] | |||
| return_last_logits = extract_single_token_logits(logits, last_token_pos_list) | |||
| else: | |||
| # [batch_size, 1, vocab_size] + [batch_size, i, vocab_size] --> [batch_size, i+1, vocab_size] | |||
| return_last_logits = P.Concat(axis=1)((return_last_logits, | |||
| extract_single_token_logits(logits, last_token_pos_list))) | |||
| nextword_distribution = self.reshape(logits[0, last_token_pos_list[0]:last_token_pos_list[0]+1:1, ::], | |||
| (1, -1)) | |||
| # stack up nextword_distribution if batch_size is larger than 1 | |||
| if self.batch_size > 1: | |||
| for batch_idx in range(1, self.batch_size): | |||
| nextword_distribution_rest = self.reshape( | |||
| logits[batch_idx, last_token_pos_list[batch_idx]:last_token_pos_list[batch_idx] + 1:1, ::], | |||
| (1, -1)) | |||
| nextword_distribution = self.concat((nextword_distribution, nextword_distribution_rest)) | |||
| if do_sample: | |||
| # get sampled ids | |||
| nextword_distribution = nextword_distribution.asnumpy().astype(np.float32) | |||
| real_next_word_index_list = self.filter_distribution.calculate(nextword_distribution) | |||
| else: | |||
| np_nextword_distribution = nextword_distribution.asnumpy() | |||
| next_word_index = np.argmax(np_nextword_distribution, axis=-1) | |||
| real_next_word_index_list = next_word_index.tolist() | |||
| append_ids = [] | |||
# decode and early stop: if every batch has generated an EOS token, stop generation
| for batch_idx in range(self.batch_size): | |||
| next_word_index = real_next_word_index_list[batch_idx] | |||
# early stop if the model generates an EOS token.
| if self.early_stop is True: | |||
| if next_word_index == self.eos_id: | |||
| if self.batch_size == 1: | |||
| break | |||
| else: | |||
| early_stop_mask[batch_idx] = 1 | |||
| continue | |||
| return_ids_list[batch_idx].append(next_word_index) | |||
| append_ids.append(next_word_index) | |||
| # check early_stop mask at the end of each loop | |||
| if 0 not in early_stop_mask: | |||
| break | |||
| input_ids, input_mask = add_last_token(input_ids, | |||
| input_mask, | |||
| overflow_strategy="shift", | |||
| append_ids=append_ids, | |||
| next_token_pos=last_token.get_pos(shift=i + 1)) | |||
| # add str to full str | |||
| generate_str = [""] * self.batch_size | |||
| full_str = [""] * self.batch_size | |||
| text_cnt = 0 | |||
| for text_ids in return_ids_list: | |||
| text = self.tokenizer.decode(text_ids) | |||
| generate_str[text_cnt] = text | |||
| text_cnt += 1 | |||
| for batch_idx in range(self.batch_size): | |||
| full_str[batch_idx] = input_str[batch_idx] + generate_str[batch_idx] | |||
| # return by several conditions | |||
| if self.batch_size == 1 and self.demo_mode is True: | |||
| if self.return_ids: | |||
| return generate_str[0], input_str[0], return_ids_list[0] | |||
| return generate_str[0], input_str[0] | |||
| if self.return_ids: | |||
| if self.return_last_token_logits: | |||
| return return_ids_list, return_last_logits | |||
| return return_ids_list | |||
| return generate_str, full_str | |||
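# Illustrative usage sketch (hedged: the `Sample` wrapper name, the network object
# and the prompt below are assumptions for demonstration, not taken from this file):
#
#   sampler = Sample(decoder=gpt2_net, tokenizer=tokenizer,
#                    batch_size=1, demo_mode=True, generate_length=32)
#   generate_str, prompt_str = sampler.generate(input_str="Once upon a time")
#
# In demo mode (batch_size == 1) `generate` returns the generated continuation and
# the original prompt; otherwise it returns per-batch lists of strings.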
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """get config setting""" | |||
| def get_train_setting(finetune_config): | |||
| """get train config setting""" | |||
| cfg = finetune_config | |||
| print("Loading GPT2 Finetune Config setting......") | |||
| print(" | optimizer: {}".format(cfg.optimizer)) | |||
| opt = cfg['optimizer'] | |||
| print(" | learning rate: {}".format(cfg[opt]['learning_rate'])) | |||
| print(" | end learning rate: {}".format( | |||
| cfg[opt]['end_learning_rate'] if 'end_learning_rate' in cfg[opt] else 'None')) | |||
| print(" | weight decay: {}\n".format(cfg[opt]['weight_decay'] if 'weight_decay' in cfg[opt] else 'None')) | |||
| def get_model_setting(finetune_config, model_config): | |||
| """get GPT-2 model config setting""" | |||
| cfg = finetune_config | |||
| gpt2_net_cfg = model_config | |||
| print("Loading GPT2 Model Config setting......") | |||
| print(" | model size: {}".format(cfg.gpt2_network)) | |||
| print(" | batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||
| print(" | seq_length: {}".format(gpt2_net_cfg.seq_length)) | |||
| print(" | vocab_size: {}".format(gpt2_net_cfg.vocab_size)) | |||
| print(" | d_model: {}".format(gpt2_net_cfg.d_model)) | |||
| print(" | num_hidden_layers: {}".format(gpt2_net_cfg.num_hidden_layers)) | |||
| print(" | num_attention_heads: {}".format(gpt2_net_cfg.num_attention_heads)) | |||
| print(" | hidden_dropout: {}".format(gpt2_net_cfg.hidden_dropout)) | |||
| print(" | attention_dropout: {}".format(gpt2_net_cfg.attention_dropout)) | |||
| print(" | summary_first_dropout: {}\n".format(gpt2_net_cfg.summary_first_dropout)) | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """learning schedule""" | |||
| import numpy as np | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR | |||
| class GPT2LearningRate(LearningRateSchedule): | |||
| """ | |||
Implementation of a warmup + polynomial-decay learning rate scheduler.
| Args: | |||
| learning_rate (float): The initial value of learning rate. | |||
| end_learning_rate (float): The end value of learning rate. | |||
| warmup_steps (int): The warm up steps of learning rate. | |||
decay_steps (int): number of steps over which the learning rate decays.
power (float): exponent of the polynomial decay.
| Returns: | |||
| lr (Tensor): The learning rate value for the current step. | |||
| """ | |||
| def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): | |||
| super(GPT2LearningRate, self).__init__() | |||
| self.warmup_flag = False | |||
| if warmup_steps > 0: | |||
| self.warmup_flag = True | |||
| self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) | |||
| self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) | |||
| self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) | |||
| self.greater = P.Greater() | |||
| self.one = Tensor(np.array([1.0]).astype(np.float32)) | |||
| self.cast = P.Cast() | |||
| def construct(self, global_step): | |||
| decay_lr = self.decay_lr(global_step) | |||
| if self.warmup_flag: | |||
| is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) | |||
| warmup_lr = self.warmup_lr(global_step) | |||
| lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr | |||
| else: | |||
| lr = decay_lr | |||
| return lr | |||
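# A minimal wiring sketch (the hyper-parameter values below are illustrative
# assumptions, not defaults taken from this repository):
#
#   lr_schedule = GPT2LearningRate(learning_rate=1e-4, end_learning_rate=1e-7,
#                                  warmup_steps=100, decay_steps=1000, power=1.0)
#   optimizer = nn.AdamWeightDecay(network.trainable_params(), learning_rate=lr_schedule)
#
# During the first `warmup_steps` steps the LR ramps linearly toward `learning_rate`;
# afterwards it decays polynomially toward `end_learning_rate`.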
| @@ -0,0 +1,185 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """metric method for downstream task""" | |||
| import string | |||
| import re | |||
| from collections import Counter | |||
| import numpy as np | |||
| from .rouge_score import get_rouge_score | |||
| from .bleu import compute_bleu | |||
| class LastWordAccuracy(): | |||
| """ | |||
LastWordAccuracy is used for the LAMBADA task (predicting the final word of a sentence).
| """ | |||
| def __init__(self): | |||
| self.acc_num = 0 | |||
| self.total_num = 0 | |||
| def normalize(self, word): | |||
| """normalization""" | |||
| word = word.lstrip() | |||
| word = word.rstrip() | |||
| def remove_punc(text): | |||
| exclude = set(string.punctuation) | |||
| return ''.join(ch for ch in text if ch not in exclude) | |||
| def lower(text): | |||
| return text.lower() | |||
| return remove_punc(lower(word)) | |||
| def update(self, predict_label, gold_label): | |||
| if isinstance(predict_label, str) and isinstance(gold_label, str): | |||
| predict_label = [predict_label] | |||
| gold_label = [gold_label] | |||
| for predict_word, gold_word in zip(predict_label, gold_label): | |||
| self.total_num += 1 | |||
| if self.normalize(predict_word) == self.normalize(gold_word): | |||
| self.acc_num += 1 | |||
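# Worked example: update("The cat.", "the cat") counts as correct, since
# normalize() lowercases and strips punctuation/whitespace before comparing;
# the running accuracy is acc_num / total_num.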
| class Accuracy(): | |||
| """ | |||
| calculate accuracy | |||
| """ | |||
| def __init__(self): | |||
| self.acc_num = 0 | |||
| self.total_num = 0 | |||
| def update(self, logits, labels): | |||
| """accuracy update""" | |||
| labels = np.reshape(labels, -1) | |||
| logits_id = np.argmax(logits, axis=-1) | |||
| print(" | Preict Label: {} Gold Label: {}".format(logits_id, labels)) | |||
| self.acc_num += np.sum(labels == logits_id) | |||
| self.total_num += len(labels) | |||
| print("\n| Accuracy = {} \n".format(self.acc_num / self.total_num)) | |||
| class F1(): | |||
| """calculate F1 score""" | |||
| def __init__(self): | |||
| self.f1_score = 0.0 | |||
| def get_normalize_answer_token(self, string_): | |||
| """Lower text and remove punctuation, article and extra whitespace.""" | |||
| def remove_articles(text): | |||
| regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) | |||
| return re.sub(regex, ' ', text) | |||
| def white_space_fix(text): | |||
| return ' '.join(text.split()) | |||
| def remove_punc(text): | |||
| exclude = set(string.punctuation) | |||
| return ''.join(char for char in text if char not in exclude) | |||
| def lower(text): | |||
| return text.lower() | |||
| return white_space_fix(remove_articles(remove_punc(lower(string_)))).split() | |||
| def update(self, pred_answer, gold_answer): | |||
| """F1 update""" | |||
| common = Counter(pred_answer) & Counter(gold_answer) | |||
| num_same = sum(common.values()) | |||
| # the number of same tokens between pred_answer and gold_answer | |||
| precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0 | |||
| recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0 | |||
| if ' '.join(pred_answer).strip() == "" and ' '.join(gold_answer).strip() == "": | |||
| self.f1_score += 1 | |||
| else: | |||
| self.f1_score += 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0 | |||
| print('| precision: {}, recall: {}\n'.format(precision, recall)) | |||
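# Worked example: pred_answer = ["the", "cat"] and gold_answer = ["a", "cat"]
# give num_same = 1, precision = recall = 0.5, so this sample adds
# 2 * 0.5 * 0.5 / (0.5 + 0.5) = 0.5 to the accumulated f1_score.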
| class BLEU(): | |||
| """calculate BLEU score""" | |||
| def __init__(self, tokenizer=None, max_order=4, smooth=True): | |||
| self.bleu = 0.0 | |||
| self.total_num = 0 | |||
| self.tokenizer = tokenizer | |||
| self.max_order = max_order | |||
| self.smooth = smooth | |||
| def sum_bleu(self, references, translations, max_order, smooth): | |||
| """calculate the sum of bleu score""" | |||
| all_result = [] | |||
| bleu_avg = 0.0 | |||
| for refer, trans in zip(references, translations): | |||
| result = compute_bleu([[refer]], [trans], max_order, smooth) | |||
| all_result.append(result) | |||
| bleu_avg += result[0] | |||
| bleu_avg /= len(references) | |||
| return bleu_avg, all_result | |||
| def update(self, hypotheses, references): | |||
| """BLEU update""" | |||
| hypo_l = [] | |||
| ref_l = [] | |||
| if self.tokenizer is not None: | |||
| for hypo, ref in zip(hypotheses, references): | |||
| if ref.strip() == '': | |||
| print("Reference is None, skip it !") | |||
| continue | |||
| if hypo.strip() == '': | |||
| print("translation is None, skip it !") | |||
| continue | |||
| hypo_l.append(self.tokenizer.encode(hypo)) | |||
| ref_l.append(self.tokenizer.encode(ref)) | |||
| if hypo_l and ref_l: | |||
| hypotheses = hypo_l | |||
| references = ref_l | |||
| bleu_avg, _ = self.sum_bleu(references, hypotheses, self.max_order, self.smooth) | |||
| self.bleu += bleu_avg * 100 | |||
| self.total_num += 1 | |||
| print("============== BLEU: {} ==============".format(float(self.bleu / self.total_num))) | |||
| class Rouge(): | |||
| ''' | |||
| Get Rouge Score | |||
| ''' | |||
| def __init__(self): | |||
| self.Rouge1 = 0.0 | |||
| self.Rouge2 = 0.0 | |||
| self.RougeL = 0.0 | |||
| self.total_num = 0 | |||
| def update(self, hypothesis, targets): | |||
| scores = get_rouge_score(hypothesis, targets) | |||
| self.Rouge1 += scores['rouge-1']['f'] * 100 | |||
| self.Rouge2 += scores['rouge-2']['f'] * 100 | |||
| self.RougeL += scores['rouge-l']['f'] * 100 | |||
| self.total_num += 1 | |||
| print("=============== ROUGE: {} ===============".format( | |||
| (self.Rouge1 + self.Rouge2 + self.RougeL) / float(3.0 * self.total_num))) | |||
| @@ -0,0 +1,466 @@ | |||
| , | |||
| . | |||
| ? | |||
| ! | |||
| # | |||
| ~ | |||
| = | |||
| - | |||
| " | |||
| ' | |||
| : | |||
| - | |||
| … | |||
| -- | |||
| | | |||
| a | |||
| about | |||
| above | |||
| across | |||
| after | |||
| again | |||
| against | |||
| all | |||
| almost | |||
| alone | |||
| along | |||
| already | |||
| also | |||
| although | |||
| always | |||
| among | |||
| an | |||
| and | |||
| another | |||
| any | |||
| anybody | |||
| anyone | |||
| anything | |||
| anywhere | |||
| are | |||
| area | |||
| areas | |||
| around | |||
| as | |||
| ask | |||
| asked | |||
| asking | |||
| asks | |||
| at | |||
| away | |||
| b | |||
| back | |||
| backed | |||
| backing | |||
| backs | |||
| be | |||
| became | |||
| because | |||
| become | |||
| becomes | |||
| been | |||
| before | |||
| began | |||
| behind | |||
| being | |||
| beings | |||
| best | |||
| better | |||
| between | |||
| big | |||
| both | |||
| bro | |||
| but | |||
| by | |||
| c | |||
| came | |||
| can | |||
| cannot | |||
| case | |||
| cases | |||
| certain | |||
| certainly | |||
| clear | |||
| clearly | |||
| come | |||
| could | |||
| d | |||
| did | |||
| differ | |||
| different | |||
| differently | |||
| do | |||
| does | |||
| done | |||
| down | |||
| downed | |||
| downing | |||
| downs | |||
| during | |||
| dr | |||
| e | |||
| each | |||
| early | |||
| eh | |||
| either | |||
| end | |||
| ended | |||
| ending | |||
| ends | |||
| enough | |||
| even | |||
| evenly | |||
| ever | |||
| every | |||
| everybody | |||
| everyone | |||
| everything | |||
| everywhere | |||
| f | |||
| fact | |||
| facts | |||
| far | |||
| felt | |||
| few | |||
| find | |||
| finds | |||
| first | |||
| for | |||
| four | |||
| from | |||
| full | |||
| fully | |||
| further | |||
| furthered | |||
| furthering | |||
| furthers | |||
| g | |||
| gave | |||
| general | |||
| generally | |||
| get | |||
| gets | |||
| give | |||
| given | |||
| gives | |||
| going | |||
| good | |||
| goods | |||
| got | |||
| great | |||
| greater | |||
| greatest | |||
| group | |||
| grouped | |||
| grouping | |||
| groups | |||
| h | |||
| had | |||
| has | |||
| have | |||
| having | |||
| he | |||
| her | |||
| here | |||
| herself | |||
| hey | |||
| high | |||
| higher | |||
| highest | |||
| him | |||
| himself | |||
| his | |||
| house | |||
| how | |||
| however | |||
| i | |||
| if | |||
| important | |||
| in | |||
| interest | |||
| interested | |||
| interesting | |||
| interests | |||
| into | |||
| is | |||
| it | |||
| its | |||
| itself | |||
| j | |||
| just | |||
| k | |||
| kae | |||
| keep | |||
| keeps | |||
| kind | |||
| knew | |||
| know | |||
| known | |||
| knows | |||
| kya | |||
| l | |||
| lads | |||
| large | |||
| largely | |||
| last | |||
| later | |||
| latest | |||
| least | |||
| less | |||
| let | |||
| lets | |||
| like | |||
| likely | |||
| long | |||
| longer | |||
| longest | |||
| m | |||
| made | |||
| make | |||
| making | |||
| man | |||
| many | |||
| may | |||
| me | |||
| member | |||
| members | |||
| men | |||
| might | |||
| mister | |||
| more | |||
| most | |||
| mostly | |||
| mr | |||
| Mr | |||
| mrs | |||
| much | |||
| must | |||
| my | |||
| myself | |||
| n | |||
| na | |||
| necessary | |||
| need | |||
| needed | |||
| needing | |||
| needs | |||
| never | |||
| new | |||
| newer | |||
| newest | |||
| next | |||
| no | |||
| nobody | |||
| non | |||
| noone | |||
| not | |||
| nothing | |||
| now | |||
| nowhere | |||
| number | |||
| numbers | |||
| nt | |||
| nn | |||
| nope | |||
| ny | |||
| o | |||
| oi | |||
| of | |||
| off | |||
| often | |||
| old | |||
| older | |||
| oldest | |||
| on | |||
| once | |||
| one | |||
| only | |||
| open | |||
| opened | |||
| opening | |||
| opens | |||
| or | |||
| order | |||
| ordered | |||
| ordering | |||
| orders | |||
| other | |||
| others | |||
| our | |||
| out | |||
| over | |||
| oh | |||
| p | |||
| part | |||
| parted | |||
| parting | |||
| parts | |||
| per | |||
| perhaps | |||
| place | |||
| places | |||
| please | |||
| point | |||
| pointed | |||
| pointing | |||
| points | |||
| possible | |||
| present | |||
| presented | |||
| presenting | |||
| presents | |||
| problem | |||
| problems | |||
| put | |||
| puts | |||
| q | |||
| quite | |||
| r | |||
| rather | |||
| really | |||
| right | |||
| room | |||
| rooms | |||
| s | |||
| said | |||
| same | |||
| saw | |||
| say | |||
| says | |||
| second | |||
| seconds | |||
| see | |||
| seem | |||
| seemed | |||
| seeming | |||
| seems | |||
| sees | |||
| several | |||
| shall | |||
| she | |||
| should | |||
| show | |||
| showed | |||
| showing | |||
| shows | |||
| side | |||
| sides | |||
| since | |||
| small | |||
| smaller | |||
| smallest | |||
| so | |||
| some | |||
| somebody | |||
| someone | |||
| something | |||
| somewhere | |||
| state | |||
| states | |||
| still | |||
| such | |||
| sure | |||
| t | |||
| take | |||
| taken | |||
| than | |||
| that | |||
| the | |||
| their | |||
| them | |||
| then | |||
| there | |||
| therefore | |||
| these | |||
| they | |||
| thing | |||
| things | |||
| think | |||
| thinks | |||
| this | |||
| those | |||
| though | |||
| thought | |||
| thoughts | |||
| three | |||
| through | |||
| thus | |||
| to | |||
| today | |||
| together | |||
| too | |||
| took | |||
| toward | |||
| turn | |||
| turned | |||
| turning | |||
| turns | |||
| two | |||
| u | |||
| uh | |||
| um | |||
| under | |||
| until | |||
| up | |||
| upon | |||
| us | |||
| use | |||
| used | |||
| uses | |||
| v | |||
| very | |||
| w | |||
| want | |||
| wanted | |||
| wanting | |||
| wants | |||
| was | |||
| way | |||
| ways | |||
| we | |||
| well | |||
| wells | |||
| went | |||
| were | |||
| what | |||
| when | |||
| where | |||
| whether | |||
| which | |||
| while | |||
| who | |||
| whole | |||
| whose | |||
| why | |||
| will | |||
| with | |||
| within | |||
| without | |||
| work | |||
| worked | |||
| working | |||
| works | |||
| would | |||
| x | |||
| y | |||
| ya | |||
| ye | |||
| year | |||
| years | |||
| yet | |||
| you | |||
| young | |||
| younger | |||
| youngest | |||
| your | |||
| yours | |||
| z | |||
| @@ -0,0 +1,39 @@ | |||
| """Calculate ROUGE score.""" | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from typing import List | |||
| from rouge import Rouge | |||
| def get_rouge_score(hypothesis: List[str], target: List[str]): | |||
| """ | |||
| Calculate ROUGE score. | |||
| Args: | |||
| hypothesis (List[str]): Inference result. | |||
| target (List[str]): Reference. | |||
| """ | |||
| if not hypothesis or not target: | |||
raise ValueError("`hypothesis` and `target` cannot be empty.")
| _rouge = Rouge() | |||
| print("hypothesis:", hypothesis) | |||
| print("target:", target) | |||
| scores = _rouge.get_scores(hypothesis, target, avg=True) | |||
| print(" | ROUGE Score:") | |||
| print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}") | |||
| print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}") | |||
| print(f" | RG-L(F): {scores['rouge-l']['f'] * 100:8.2f}") | |||
| return scores | |||
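# Illustrative call (a sketch; it requires the third-party `rouge` package, and
# the example strings are assumptions):
#   scores = get_rouge_score(["the cat sat on the mat"], ["the cat is on the mat"])
#   rouge1_f = scores["rouge-1"]["f"]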
| @@ -0,0 +1,186 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| task utils | |||
| """ | |||
| import regex as re | |||
| from mindspore.ops import operations as P | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| # for lambada task | |||
| def extract_logits(logits=None, position=None): | |||
| """ | |||
Args:
logits (Tensor): Tensor(batch_size, seq_length, vocab_size), e.g. (8, 1024, 50257)
position (numpy.array): the array storing the final word positions, shape [batch_size, 2]
Return:
output_logits (Tensor): the logits extracted at the specified positions,
shape [batch_size, vocab_size]
| """ | |||
| batch_size = logits.shape[0] | |||
| for batch_idx in range(batch_size): | |||
| word_logits_pos = int(position[batch_idx, 0] - 1) | |||
| logit = logits[batch_idx:batch_idx+1:1, word_logits_pos, ::] # [1, vocab_size] | |||
| if batch_idx == 0: | |||
| output_logits = logit | |||
| else: | |||
| output_logits = P.Concat()((output_logits, logit)) # [batch_size, vocab_size] | |||
| return output_logits | |||
| def get_final_word_label(input_ids, input_length, tokenizer=None): | |||
| """ | |||
| get whole word label_str from input_ids | |||
Args:
input_ids: Tensor(batch_size, seq_length), indices of input text
input_length: Tensor(batch_size, 2), start and end positions of the final word
tokenizer: GPT2Tokenizer, used to decode the final word ids, optional
Returns:
batch_word_label (list[str]): the final word of each sample, used as the LAMBADA label
| """ | |||
| input_ids_np = input_ids.asnumpy() | |||
| input_length_np = input_length.asnumpy() | |||
| batch_word_label = [] | |||
| for batch_idx in range(len(input_ids_np)): | |||
| word_spos = input_length_np[batch_idx, 0] | |||
| word_epos = input_length_np[batch_idx, 1] | |||
| final_word_ids = input_ids_np[batch_idx, word_spos:word_epos] | |||
| final_word_str = tokenizer.decode(final_word_ids.tolist()) | |||
| batch_word_label.append(final_word_str) | |||
| return batch_word_label | |||
| def calculate_final_word_loss(logits, batch_size, input_ids, input_length, loss): | |||
| """ | |||
| Calculate the last word loss. | |||
| """ | |||
| logits = logits.asnumpy() | |||
| input_len_np = input_length.asnumpy() | |||
| input_ids_np = input_ids.asnumpy() | |||
| sum_batch_loss = 0.0 | |||
| for batch in range(batch_size): | |||
| lastword_spos = input_len_np[batch, 0] | |||
| lastword_epos = input_len_np[batch, 1] | |||
| last_word_logits = logits[batch, lastword_spos - 1:lastword_epos - 1:1, ::] | |||
| last_word_logits_tensor = Tensor(last_word_logits, mstype.float32) | |||
| last_word_label = input_ids_np[batch, lastword_spos:lastword_epos:1] | |||
| print("last word label: ", last_word_label) | |||
| last_word_label_tensor = Tensor(last_word_label, mstype.int32) | |||
| last_word_loss = loss(last_word_logits_tensor, last_word_label_tensor) | |||
| last_word_loss = float(last_word_loss.asnumpy()) | |||
| sum_batch_loss += last_word_loss | |||
| print(" | loss: ", last_word_loss) | |||
| avg_batch_loss = float(sum_batch_loss / batch_size) | |||
| return avg_batch_loss | |||
| # for cbt task | |||
| def calculate_choice_prob_for_cbt(logits, batch_size, input_length, input_ids): | |||
| """ | |||
| calculate choice prob for cbt | |||
Args:
logits (Tensor): shape [batch_size, seq_length, vocab_size]
batch_size (int): batch size
input_length (Tensor): start and end positions of the rest of the sentence, shape [batch_size, 2]
input_ids (Tensor): input token indices, shape [batch_size, seq_length]
Returns:
choice_prob (list[float]): accumulated log-probability of each choice
| choice_prob = [] # [batch_size] | |||
| logits = logits.asnumpy() | |||
| input_len_np = input_length.asnumpy() | |||
| input_ids_np = input_ids.asnumpy() | |||
| for batch in range(batch_size): | |||
| sum_ = 0.0 | |||
| rest_spos = input_len_np[batch, 0] | |||
| rest_epos = input_len_np[batch, 1] + 1 | |||
| for rest_pos in range(rest_spos - 1, rest_epos - 1): | |||
| rest_token_id = input_ids_np[batch, rest_pos + 1] | |||
| log_prob = logits[batch, rest_pos, rest_token_id] | |||
| sum_ = sum_ + log_prob | |||
| choice_prob.append(sum_) | |||
| print("rest sentence prob: ", sum_) | |||
| return choice_prob | |||
| # for summarization task | |||
| def modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2."): | |||
| """ | |||
| modify keys of param_dict to fit model. | |||
| Args: | |||
param_dict: dict, dictionary of parameters imported from a ckpt file
mode: str, "zero-shot" for a pretrained GPT-2 model;
"finetuned" for a model finetuned on a certain task.
| Return: | |||
| reorganized_param_dict: dict, new param_dict to fit in model for different tasks. | |||
| """ | |||
| final_param_dict = dict() | |||
| if mode == "zero-shot": | |||
| for name in param_dict: | |||
| final_param_dict[model_prefix + name] = param_dict[name] | |||
| final_param_dict['lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||
| elif mode == "finetuned": | |||
| embedding_name = "gpt2_embedding_lookup.embedding_table" | |||
| embedding_name_old = "" | |||
| for name in param_dict: | |||
| name_remove_prefix = name[len(model_prefix):] | |||
| name_prefix = name[:len(model_prefix)] | |||
| final_param_dict[name_remove_prefix] = param_dict[name] | |||
| if embedding_name in name and name_prefix == model_prefix: | |||
| embedding_name_old = name | |||
| final_param_dict[embedding_name] = param_dict[embedding_name_old] | |||
| else: | |||
| raise ValueError("mode should be [zero-shot, finetuned]") | |||
| return final_param_dict | |||
| def clean_hypo(text): | |||
| """ | |||
| to prevent generation of empty string, and lower text | |||
| Arg: | |||
| text: str, input str | |||
| Return: | |||
| text: str, cleaned input str | |||
| """ | |||
| text = text.lower() | |||
| eng_re = re.compile(r'[a-z]+', re.I) | |||
| length_con = len(eng_re.findall(text)) | |||
| if length_con == 0: | |||
| return '<EMPTY>' | |||
| return text | |||
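# Quick examples of clean_hypo's behaviour:
#   clean_hypo("Hello World") -> "hello world"
#   clean_hypo("!!! ...")     -> "<EMPTY>"   (no latin letters survive the filter)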
| @@ -0,0 +1,217 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| tensor manipulations | |||
| """ | |||
| import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore import dtype as mstype | |||
| from mindspore.ops import operations as P | |||
| def extract_string_from_tensor(input_ids, mode="single", config=None, tokenizer=None): | |||
| """ | |||
| Args: | |||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||
| mode (str): ["pair", "single"] | |||
| "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`, | |||
| such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`, | |||
| reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`. | |||
| "single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task. | |||
| config: the configuration of GPT-2 model. | |||
| tokenizer: the tokenizer of GPT-2 model. | |||
Return:
prompt_list (list): list of prompt_text
reference_list (list): list of reference_text (only returned in "pair" mode)
| """ | |||
| batch_size = config.batch_size | |||
| seq_length = config.seq_length | |||
| prompt_list = [""] * batch_size | |||
| reference_list = [""] * batch_size | |||
| eos_text = tokenizer.eos_token | |||
| len_eos_text = len(eos_text) | |||
| input_ids_np = input_ids.asnumpy() | |||
| input_ids_np = input_ids_np.reshape((batch_size, seq_length)) | |||
| if mode == "pair": | |||
| for batch_idx in range(batch_size): | |||
| sentence_tensor = input_ids_np[batch_idx] | |||
sentence_list = sentence_tensor.tolist()[1:]  # already a numpy row, no .asnumpy() needed
| sentence = tokenizer.decode(sentence_list) | |||
| prompt_start = 0 | |||
| prompt_end = sentence.find(eos_text, 0) | |||
| reference_start = prompt_end + len_eos_text | |||
| reference_end = sentence[reference_start:].find( | |||
| eos_text, 0) + reference_start | |||
| prompt_list[batch_idx] = sentence[prompt_start:prompt_end] | |||
| reference_list[batch_idx] = sentence[reference_start:reference_end] | |||
| return prompt_list, reference_list | |||
| # For single output datasets such as WikiText, etc. | |||
| if mode == "single": | |||
| for batch_idx in range(batch_size): | |||
| sentence_tensor = input_ids_np[batch_idx] | |||
sentence_list = sentence_tensor.tolist()[1:]  # already a numpy row, no .asnumpy() needed
| sentence = tokenizer.decode(sentence_list) | |||
| prompt_start = 0 | |||
| prompt_end = sentence.find(eos_text, 0) | |||
| prompt_list[batch_idx] = sentence[prompt_start:prompt_end] | |||
| else: | |||
| raise NotImplementedError('mode:{} not supported.'.format(mode)) | |||
| return prompt_list | |||
| def extract_single_token_logits(logits=None, seq_pos=None): | |||
| """ | |||
Args:
logits: (batch_size, seq_length, vocab_size), e.g. when batch size is 8,
sequence length is 1024 and vocab_size is 50257,
then logits is a Tensor with shape (8, 1024, 50257)
seq_pos (list): position of the last token for each batch, length batch_size
| Return: | |||
| output_logits: (batch_size,1,vocab_size) extract the logit to predict the last token. | |||
| """ | |||
| batch_size = logits.shape[0] | |||
| logits_np = logits.asnumpy() | |||
| logits_type = P.DType()(logits) | |||
| for i in range(batch_size): | |||
| logit_np = logits_np[i:i + 1:1, seq_pos[i]:seq_pos[i] + 1:1, ::] | |||
| if i == 0: | |||
| output_logits = logit_np | |||
| else: | |||
| output_logits = np.concatenate((output_logits, logit_np), axis=0) | |||
| output_logits = Tensor(output_logits, dtype=logits_type) | |||
| return output_logits | |||
| def get_last_one_pos(input_mask: Tensor): | |||
| """ | |||
| Arg: | |||
| input_mask (Tensor): (batch_size,seq_length) | |||
| Return: | |||
| pos (Tensor): (batch_size,) | |||
| """ | |||
| input_mask_ = P.Cast()(input_mask, mstype.float32) | |||
| pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,) | |||
| pos = P.Cast()(pos, mstype.int32) | |||
| pos = pos - 1 | |||
| return pos | |||
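# Worked example: for input_mask [[1, 1, 1, 0], [1, 1, 0, 0]] the row sums are
# [3, 2], so get_last_one_pos returns positions [2, 1] (index of the last 1).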
| def get_next_one_pos(input_mask: Tensor): | |||
| """ | |||
| Arg: | |||
| input_mask (Tensor): (batch_size,seq_length) | |||
| """ | |||
| input_mask_ = P.Cast()(input_mask, mstype.float32) | |||
| pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,) | |||
| pos = P.Cast()(pos, mstype.int32) | |||
| return pos | |||
| def add_last_token_mask(input_mask: Tensor, overflow_strategy: str = "shift"): | |||
| """ | |||
| add last token mask | |||
Args:
input_mask (Tensor): attention mask, shape (batch_size, seq_length)
overflow_strategy (str): "shift" or "truncate", how to handle an already-full row
Returns:
Tensor, the updated input_mask
| """ | |||
| pos = get_next_one_pos(input_mask).asnumpy() | |||
| input_mask_np = input_mask.asnumpy() | |||
| maximum_length = input_mask.shape[1] | |||
| batch_size = input_mask.shape[0] | |||
| for idx in range(batch_size): | |||
| # not overflow | |||
| if pos[idx] < maximum_length: | |||
| input_mask_np[idx][pos[idx]] = 1 | |||
| # overflow | |||
| else: | |||
| if overflow_strategy == "shift": | |||
| continue | |||
| if overflow_strategy == "truncate": | |||
| continue | |||
| else: | |||
| raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy)) | |||
| return Tensor(input_mask_np, dtype=mstype.int32) | |||
| def add_last_token(input_ids: Tensor, input_mask: Tensor, overflow_strategy: str = "shift", append_ids=None, | |||
| next_token_pos=None): | |||
| """ | |||
add last token
Args:
input_ids (Tensor): token ids, shape (batch_size, seq_length)
input_mask (Tensor): attention mask, shape (batch_size, seq_length)
overflow_strategy (str): "shift" or "truncate", how to handle an already-full row
append_ids (list): the token id to append for each batch
next_token_pos (list, optional): position of the next token for each batch
Returns:
(Tensor, Tensor), the updated input_ids and input_mask
| """ | |||
| # get positional list/numpy array | |||
| if next_token_pos is None: | |||
| pos = get_next_one_pos(input_mask).asnumpy() | |||
| else: | |||
| pos = next_token_pos | |||
| # get numpy of inputs | |||
| input_mask_np = input_mask.asnumpy() | |||
| input_ids_np = input_ids.asnumpy() | |||
| maximum_length = int(input_mask.shape[1]) | |||
| batch_size = int(input_mask.shape[0]) | |||
| for idx in range(batch_size): | |||
| # not overflow | |||
| if pos[idx] < maximum_length: | |||
| input_mask_np[idx][int(pos[idx])] = 1 | |||
| input_ids_np[idx][int(pos[idx])] = append_ids[idx] | |||
| # overflow | |||
| else: | |||
| if overflow_strategy == "shift": | |||
| # shift one token left | |||
| input_ids_np[idx][0:maximum_length - 1] = input_ids_np[idx][1:maximum_length] | |||
| input_ids_np[idx][maximum_length - 1] = append_ids[idx] | |||
| continue | |||
| if overflow_strategy == "truncate": | |||
| # do nothing | |||
| continue | |||
| else: | |||
| raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy)) | |||
| return Tensor(input_ids_np, dtype=mstype.int32), Tensor(input_mask_np, dtype=mstype.int32) | |||
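# Overflow sketch for add_last_token with overflow_strategy="shift": if a row of
# input_ids is already full (the next position equals seq_length), the row is
# shifted one token to the left and the appended id is written into the last
# slot, so the generation window keeps moving instead of raising an error.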
| @@ -0,0 +1,517 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| tokenization | |||
| """ | |||
| import json | |||
| from functools import lru_cache | |||
| from typing import List, Optional | |||
| import logging | |||
| import regex as re | |||
| logger = logging.getLogger(__name__) | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
Return a mapping from utf-8 bytes to printable unicode characters, so BPE can operate on text that contains no whitespace or control bytes.
| """ | |||
| bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2 ** 8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2 ** 8 + n) | |||
| n += 1 | |||
| cs = [chr(i) for i in cs] | |||
| return dict(zip(bs, cs)) | |||
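# Example of the byte-to-unicode mapping: printable ASCII maps to itself
# (bytes_to_unicode()[ord("A")] == "A"), while the space byte 0x20 is remapped
# to the printable placeholder "Ġ", which is why GPT-2 BPE tokens that start a
# new word begin with "Ġ".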
| def get_pairs(word): | |||
| """ | |||
| Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| class GPT2Tokenizer(): | |||
| """ | |||
| GPT2Tokenizer | |||
| """ | |||
| def __init__( | |||
| self, | |||
| vocab_file, | |||
| merge_file, | |||
| add_prefix_space=False, | |||
| ): | |||
| with open(vocab_file, 'r', encoding="utf-8") as vocab_handle: | |||
| self.encoder = json.load(vocab_handle) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.vocab_size = len(self.decoder) | |||
| with open(merge_file, 'r', encoding="utf-8") as merge_handle: | |||
| bpe_merges = merge_handle.read().split('\n')[1:-1] | |||
| bpe_merges = [tuple(merge.split()) for merge in bpe_merges] | |||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | |||
| self.add_prefix_space = add_prefix_space | |||
| self.cache = {} | |||
| self.unk_token = "<|endoftext|>" | |||
| self.unk_token_id = 50256 | |||
| self.bos_token = "<|endoftext|>" | |||
| self.bos_token_id = 50256 | |||
| self.eos_token = "<|endoftext|>" | |||
| self.eos_token_id = 50256 | |||
| self.pad_token = "<|endoftext|>" | |||
| self.pad_token_id = 50256 | |||
| def bpe(self, token): | |||
| """ | |||
| bpe encode | |||
| """ | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token) | |||
| pairs = get_pairs(token) | |||
| if not pairs: | |||
| return token | |||
| while True: | |||
| bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| except ValueError: | |||
| new_word.extend(word[i:]) | |||
| break | |||
| else: | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| if word[i] == first and i + 1 < len(word) and word[i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = " ".join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def _tokenize(self, text): | |||
| """ Tokenize a string using bpe encode. """ | |||
| text = self.prepare_for_tokenization(text, is_pretokenized=False) | |||
| bpe_tokens = [] | |||
| for token in re.findall(self.pat, text): | |||
| token = "".join( | |||
| self.byte_encoder[b] for b in token.encode("utf-8") | |||
| ) | |||
| bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) | |||
| return bpe_tokens | |||
| def _convert_token_to_id(self, token): | |||
| """ the index of the token in the vocabulary. """ | |||
| return self.encoder.get(token, self.encoder.get(self.unk_token)) | |||
| def _convert_id_to_token(self, _id): | |||
| """ return the origin bpe token according to id""" | |||
| return self.decoder.get(_id) | |||
| def _convert_tokens_to_string(self, tokens): | |||
| """ return a string according to the list of tokens""" | |||
| text = "".join(tokens) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors='ignore') | |||
| return text | |||
| def encode(self, text): | |||
| """ get the index list of text""" | |||
| text_id = [] | |||
| bpe_tokens = self._tokenize(text) | |||
| for token in bpe_tokens: | |||
| text_id.append(self._convert_token_to_id(token)) | |||
| return text_id | |||
| def decode(self, ids): | |||
| """ return a string according to the index list of tokens""" | |||
| tokens = [] | |||
| for id_ in ids: | |||
| tokens.append(self._convert_id_to_token(id_)) | |||
| return self._convert_tokens_to_string(tokens) | |||
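# Round-trip sketch (the ids shown come from the public GPT-2 vocab and are
# given for illustration; they depend on the loaded vocab_file/merge_file):
#   tokenizer.encode("Hello world")  -> [15496, 995]
#   tokenizer.decode([15496, 995])   -> "Hello world"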
| def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs): | |||
| """ whether to add a whitespace in the front of text """ | |||
| add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) | |||
| if is_pretokenized or add_prefix_space: | |||
| text = " " + text | |||
| return text | |||
| def add_special_tokens(self, special_tokens_dict): | |||
| """ | |||
| Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If | |||
| special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the | |||
| current vocabulary). | |||
| Args: | |||
| special_tokens_dict (dictionary `str` to `str`): | |||
| Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, | |||
| ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, | |||
| ``additional_special_tokens``]. | |||
| Returns: | |||
| added_tokens (int): Number of tokens added to the vocabulary | |||
| """ | |||
| # special_tokens_dict = {'cls_token': '<CLS>'} | |||
| if not special_tokens_dict: | |||
| return 0 | |||
| added_tokens = 0 | |||
| for key, value in special_tokens_dict.items(): | |||
| setattr(self, key, value) | |||
| assert isinstance(value, str), f"Token {value} for key {key} should be a str instance" | |||
| added_tokens += self.add_tokens([value], special_tokens=True) | |||
| return added_tokens | |||
| def add_tokens(self, new_tokens, special_tokens=False): | |||
| if not new_tokens: | |||
| return 0 | |||
| if not isinstance(new_tokens, (list, tuple)): | |||
| new_tokens = [new_tokens] | |||
| return self._add_tokens(new_tokens, special_tokens=special_tokens) | |||
| def _add_tokens(self, new_tokens, special_tokens=False): | |||
| """ | |||
| _add_tokens | |||
| Args: | |||
| new_tokens (list[str]): Token(s) to add in vocabulary. | |||
| special_tokens (bool): Whether or not the tokens should be added as special tokens. | |||
| Returns: | |||
| the number of the new added tokens. | |||
| """ | |||
| new_tokens = [str(token) for token in new_tokens] | |||
| tokens_to_add = [] | |||
| for token in new_tokens: | |||
| assert isinstance(token, str) | |||
| tokens_to_add.append(token) | |||
| logger.info("Adding %s to the vocabulary ! ", token) | |||
| added_tok_encoder = dict((tok, self.vocab_size + i) for i, tok in enumerate(tokens_to_add)) | |||
| added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} | |||
| self.encoder.update(added_tok_encoder) | |||
| self.decoder.update(added_tok_decoder) | |||
| return len(tokens_to_add) | |||
| def num_special_tokens_to_add(self, pair: bool = False): | |||
| token_ids_0 = [] | |||
| token_ids_1 = [] | |||
| return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) | |||
| def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None): | |||
| """ | |||
| Build model inputs from a sequence or a pair of sequence by concatenating and adding special tokens. | |||
| A GPT2 sequence has the following format: | |||
| - single sequence: ``<bos> X <eos>`` | |||
| - pair of sequences: ``<bos> A <eos> B <eos>`` | |||
| Args: | |||
| token_ids_0 (List[int]): List of IDs to which the special tokens will be added | |||
| token_ids_1 (List[int], `optional`, defaults to `None`): Optional second list of IDs for sequence pairs. | |||
| """ | |||
| bos = [self.bos_token_id] | |||
| eos = [self.eos_token_id] | |||
| if token_ids_1 is None: | |||
| return bos + token_ids_0 + eos | |||
| return bos + token_ids_0 + eos + token_ids_1 + eos | |||
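# Example: with bos_token_id == eos_token_id == 50256,
#   build_inputs_with_special_tokens([11, 22])   -> [50256, 11, 22, 50256]
#   build_inputs_with_special_tokens([11], [22]) -> [50256, 11, 50256, 22, 50256]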
| def truncate_sequences(self, ids, num_tokens_to_remove, truncation_strategy="ONLY_FIRST", direction="RIGHT"): | |||
| """ | |||
truncate sequences
Args:
ids (list[int]): token ids to truncate
num_tokens_to_remove (int): number of tokens to remove
truncation_strategy (str): currently only "ONLY_FIRST" is supported
direction (str): "RIGHT" or "LEFT", the end to truncate from
Returns:
(ids, overflowing_tokens): (list[int], list[int])
| """ | |||
| if num_tokens_to_remove <= 0: | |||
| return ids, [] | |||
| overflowing_tokens = [] | |||
| if truncation_strategy == "ONLY_FIRST": | |||
| if len(ids) > num_tokens_to_remove: | |||
| if direction == "RIGHT": | |||
| overflowing_tokens = ids[-num_tokens_to_remove:] | |||
| ids = ids[:-num_tokens_to_remove] | |||
| if direction == "LEFT": | |||
| overflowing_tokens = ids[:num_tokens_to_remove] | |||
| ids = ids[num_tokens_to_remove:] | |||
| else: | |||
| logger.error("The first sequence length is smaller than removed tokens. ") | |||
| else: | |||
| logger.error("Please select correct truncation strategy, for instance 'ONLY_FIRST'") | |||
| return (ids, overflowing_tokens) | |||
| def _pad(self, encoded_inputs, max_length=None, padding_strategy=None, | |||
| return_attention_mask: Optional[bool] = None): | |||
| """ | |||
_pad
Args:
encoded_inputs (dict): single (unbatched) encoding with key "input_ids"
max_length (int, optional): length to pad to
padding_strategy (str, optional): currently only "MAX_LENGTH" is supported
return_attention_mask (Optional[bool]): whether to build the attention mask
Returns:
encoded_inputs (dict): the padded encoding
| """ | |||
| needs_to_be_padded = (len(encoded_inputs["input_ids"]) != max_length) | |||
| if needs_to_be_padded: | |||
| if padding_strategy == "MAX_LENGTH": | |||
| difference = max_length - len(encoded_inputs["input_ids"]) | |||
| if return_attention_mask: | |||
| encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference | |||
| encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference | |||
| else: | |||
| raise ValueError("Invalid padding strategy") | |||
| else: | |||
| if return_attention_mask: | |||
| encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) | |||
| return encoded_inputs | |||
| def pad(self, encoded_inputs, max_length: Optional[int] = None, padding_strategy="MAX_LENGTH", | |||
| return_attention_mask=True): | |||
| """ | |||
pad
Args:
encoded_inputs (dict): single or batched encoding with key "input_ids"
max_length (Optional[int]): length to pad to
padding_strategy (str): "MAX_LENGTH" or "LONGEST"
return_attention_mask (bool): whether to build the attention mask
Returns:
batch_outputs (dict): the padded (batched) encodings
| """ | |||
# unbatched input: encoded_inputs["input_ids"] is a flat list of ids, e.g. [98, 67, 32388, 318, 1912, 287, 170, 8496, 318, 905, 2667, 32]
| if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)): | |||
| encoded_inputs = self._pad( | |||
| encoded_inputs, | |||
| max_length=max_length, | |||
| padding_strategy=padding_strategy, | |||
| return_attention_mask=return_attention_mask | |||
| ) | |||
| return encoded_inputs | |||
| # encoded_inputs with batch_size | |||
| batch_size = len(encoded_inputs["input_ids"]) | |||
| assert all( | |||
| len(v) == batch_size for v in encoded_inputs.values() | |||
| ), "Some items in the output dictionary have a different batch size than others." | |||
| if padding_strategy == "LONGEST": | |||
| max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"]) | |||
| padding_strategy = "MAX_LENGTH" | |||
| batch_outputs = {} | |||
| for i in range(batch_size): | |||
| inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) | |||
| outputs = self._pad( | |||
| encoded_inputs=inputs, | |||
| max_length=max_length, | |||
| padding_strategy=padding_strategy, | |||
| return_attention_mask=return_attention_mask | |||
| ) | |||
| for key, value in outputs.items(): | |||
| if key not in batch_outputs: | |||
| batch_outputs[key] = [] | |||
| batch_outputs[key].append(value) | |||
| return batch_outputs | |||
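# Padding sketch: with max_length=6 and pad_token_id=50256,
#   pad({"input_ids": [11, 22, 33]}, max_length=6)
# returns input_ids [11, 22, 33, 50256, 50256, 50256] and
# attention_mask [1, 1, 1, 0, 0, 0].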
| def prepare_for_model(self, | |||
| ids, | |||
| pair_ids=None, | |||
| add_special_tokens=True, | |||
| max_length=None, | |||
| padding=None, | |||
| truncate_direction="RIGHT", | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True): | |||
| """ | |||
prepare for model
Args:
ids (list[int]): token ids of the first sequence
pair_ids (list[int], optional): token ids of the second sequence
add_special_tokens (bool): whether to add bos/eos tokens
max_length (int, optional): maximum sequence length
padding (bool, optional): whether to pad to max_length
truncate_direction (str): "RIGHT" or "LEFT"
return_overflowing_tokens (bool): whether to return truncated tokens
return_attention_mask (bool): whether to return the attention mask
Returns:
encoded_inputs (dict): the model-ready encoding
| """ | |||
| pair = bool(pair_ids is not None) | |||
| len_ids = len(ids) | |||
| len_pair_ids = len(pair_ids) if pair else 0 | |||
| encoded_inputs = {} | |||
| # Compute the total size of the returned encodings | |||
| total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) | |||
| # Truncation: Handle max sequence length | |||
| if max_length and total_len > max_length: | |||
| ids, overflowing_tokens = self.truncate_sequences(ids=ids, | |||
| num_tokens_to_remove=total_len - max_length, | |||
| truncation_strategy="ONLY_FIRST", | |||
| direction=truncate_direction) | |||
| if return_overflowing_tokens: | |||
| encoded_inputs["overflowing_tokens"] = overflowing_tokens | |||
| encoded_inputs["num_truncated_tokens"] = total_len - max_length | |||
| if add_special_tokens: | |||
| sequence = self.build_inputs_with_special_tokens(ids, pair_ids) | |||
| else: | |||
| sequence = ids + pair_ids if pair else ids | |||
| # build output dictionary | |||
| encoded_inputs["input_ids"] = sequence | |||
| # check lengths | |||
if max_length is not None and len(encoded_inputs["input_ids"]) > max_length:
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model (%d > %d). Running this sequence through the model will result in "
"indexing errors", len(ids), max_length
)
| # padding | |||
| if padding or return_attention_mask: | |||
| encoded_inputs = self.pad(encoded_inputs=encoded_inputs, | |||
| max_length=max_length, | |||
| padding_strategy="MAX_LENGTH", | |||
| return_attention_mask=return_attention_mask) | |||
| return encoded_inputs | |||
| class CNN_DailyMail_tokenizer(GPT2Tokenizer): | |||
| """ | |||
| CNN DailyMail tokenizer | |||
| """ | |||
| def prepare_for_model(self, | |||
| ids, | |||
| pair_ids, | |||
| max_length=1024, | |||
| max_summary_length=150, | |||
| add_special_tokens=True, | |||
| padding=None, | |||
| return_overflowing_tokens=False, | |||
| return_attention_mask=True): | |||
| """ | |||
| Prepare an (article, summary) id pair for the summarization model. The summary | |||
| is capped at max_summary_length tokens and the article is truncated so that the | |||
| total length, including special tokens, does not exceed max_length. | |||
| """ | |||
| len_ids = len(ids) | |||
| len_pair_ids = len(pair_ids) | |||
| encoded_inputs = {} | |||
| # Compute the total size of the returned encodings | |||
| total_len = len_ids + len_pair_ids | |||
| ids_overflowing_tokens = [] | |||
| pair_overflowing_tokens = [] | |||
| # Truncation: handle max sequence length; 3 slots are reserved for the | |||
| # special tokens inserted by build_inputs_with_special_tokens | |||
| if total_len > max_length - 3: | |||
| if len_pair_ids > max_summary_length: | |||
| num_tokens_to_remove = len_pair_ids - max_summary_length | |||
| pair_ids, pair_overflowing_tokens = self.truncate_sequences(ids=pair_ids, | |||
| num_tokens_to_remove=num_tokens_to_remove, | |||
| truncation_strategy="ONLY_FIRST", | |||
| direction="RIGHT") | |||
| if len_ids + max_summary_length > max_length - 3: | |||
| num_tokens_to_remove = (len_ids + max_summary_length) - (max_length - 3) | |||
| ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids, | |||
| num_tokens_to_remove=num_tokens_to_remove, | |||
| truncation_strategy="ONLY_FIRST", | |||
| direction="RIGHT") | |||
| else: | |||
| ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids, | |||
| num_tokens_to_remove=total_len - (max_length - 3), | |||
| truncation_strategy="ONLY_FIRST", | |||
| direction="RIGHT") | |||
| if return_overflowing_tokens: | |||
| encoded_inputs["article_overflowing_tokens"] = ids_overflowing_tokens | |||
| encoded_inputs["highlights_overflowing_tokens"] = pair_overflowing_tokens | |||
| encoded_inputs["num_truncated_tokens"] = total_len - (max_length-3) | |||
| sequence = self.build_inputs_with_special_tokens(ids, pair_ids) | |||
| encoded_inputs["input_ids"] = sequence | |||
| # check lengths | |||
| if max_length is not None and len(encoded_inputs["input_ids"]) > max_length: | |||
| logger.warning( | |||
| "Token indices sequence length is longer than the specified maximum sequence length " | |||
| "for this model (%d > %d). Running this sequence through the model will result " | |||
| "in indexing errors", len(encoded_inputs["input_ids"]), max_length | |||
| ) | |||
| # padding | |||
| if padding or return_attention_mask: | |||
| encoded_inputs = self.pad(encoded_inputs=encoded_inputs, | |||
| max_length=max_length, | |||
| padding_strategy="MAX_LENGTH", | |||
| return_attention_mask=return_attention_mask) | |||
| return encoded_inputs | |||
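| # Worked example (added for illustration): with the defaults max_length=1024 and | |||
| # max_summary_length=150, the summary keeps at most 150 tokens and the article at | |||
| # most 1024 - 3 - 150 = 871 tokens, the 3 being reserved for special tokens. | |||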
| def Tokenizer(vocab_file="./pretrain-data/gpt2-vocab.json", | |||
| merge_file="./pretrain-data/gpt2-merges.txt", | |||
| mode="normal"): | |||
| """ use the GPT2Tokenizer""" | |||
| print(" | Tokenizer mode: {}".format(mode)) | |||
| if mode == "normal": | |||
| tokenizer = GPT2Tokenizer(vocab_file, merge_file, add_prefix_space=False) | |||
| elif mode == "cnn_dailymail": | |||
| tokenizer = CNN_DailyMail_tokenizer(vocab_file, merge_file, add_prefix_space=False) | |||
| else: | |||
| raise ValueError("No Such Mode for {} in src.utils.tokenization.Tokenizer()".format(mode)) | |||
| return tokenizer | |||
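| # --- Usage sketch (illustrative addition, not part of the original file). A | |||
| # minimal sketch assuming the default vocab/merge files exist; the token ids | |||
| # below are arbitrary placeholders and the "attention_mask" key name is an | |||
| # assumption based on the return_attention_mask flag. | |||
| def _tokenizer_demo(): | |||
|     """Pad/truncate a raw id sequence to a fixed length with prepare_for_model.""" | |||
|     tokenizer = Tokenizer(mode="normal") | |||
|     ids = [464, 3290, 318, 257]  # placeholder token ids | |||
|     encoded = tokenizer.prepare_for_model(ids=ids, max_length=8, padding=True) | |||
|     print(encoded["input_ids"])           # padded/truncated id sequence | |||
|     print(encoded.get("attention_mask"))  # 1 for real tokens, 0 for padding (assumed key) | |||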
| @@ -0,0 +1,55 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| init weight | |||
| """ | |||
| import math | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| def _average_units(shape): | |||
| """Return the average number of units of a 1-D or 2-D weight shape.""" | |||
| if not shape: | |||
| return 1 | |||
| if len(shape) == 1: | |||
| return float(shape[0]) | |||
| if len(shape) == 2: | |||
| return float(shape[0] + shape[1]) / 2. | |||
| raise RuntimeError("unsupported shape: only 1-D and 2-D shapes are supported.") | |||
| def weight_variable(shape): | |||
| """Scaled-uniform initializer: limit = sqrt(3 / max(1, average_units)).""" | |||
| scale_shape = shape | |||
| avg_units = _average_units(scale_shape) | |||
| scale = 1.0 / max(1., avg_units) | |||
| limit = math.sqrt(3.0 * scale) | |||
| values = np.random.uniform(-limit, limit, shape).astype(np.float32) | |||
| return Tensor(values) | |||
| def one_weight(shape): | |||
| """Return an all-ones float32 Tensor of the given shape.""" | |||
| ones = np.ones(shape).astype(np.float32) | |||
| return Tensor(ones) | |||
| def zero_weight(shape): | |||
| """Return an all-zeros float32 Tensor of the given shape.""" | |||
| zeros = np.zeros(shape).astype(np.float32) | |||
| return Tensor(zeros) | |||
| def normal_weight(shape, num_units): | |||
| """Return a float32 Tensor drawn from a normal with std num_units ** -0.5.""" | |||
| norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32) | |||
| return Tensor(norm) | |||
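| # --- Usage sketch (illustrative addition, not part of the original file); the | |||
| # shapes below are arbitrary examples. | |||
| if __name__ == "__main__": | |||
|     dense = weight_variable((768, 3072))      # scaled-uniform, fan-average scaling | |||
|     gamma = one_weight((768,))                # e.g. a LayerNorm gain | |||
|     beta = zero_weight((768,))                # e.g. a LayerNorm bias | |||
|     embed = normal_weight((50257, 768), 768)  # std = 768 ** -0.5 | |||
|     print(dense.shape, gamma.shape, beta.shape, embed.shape) | |||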
| @@ -0,0 +1,86 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """dataset preprocess""" | |||
| import argparse | |||
| from src.utils.data_preprocess import lambada_dataset_preprocess | |||
| from src.utils.data_preprocess import cbt_dataset_preprocess | |||
| from src.utils.data_preprocess import wikitext_dataset_preprocess | |||
| from src.utils.data_preprocess import ptb_dataset_preprocess | |||
| from src.utils.data_preprocess import onebw_dataset_preprocess | |||
| from src.utils.data_preprocess import coqa_dataset_preprocess | |||
| from src.utils.data_preprocess import wmt14_en_fr_preprocess | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="All Task dataset preprocessing") | |||
| parser.add_argument("--task", type=str, default="translation", | |||
| help="The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada" | |||
| "Summarization, ReadingComprehension]") | |||
| parser.add_argument("--input_file", type=str, default="", | |||
| help="The raw dataset path. ") | |||
| parser.add_argument("--dataset", type=str, default="onebw", | |||
| help="The name of dataset which should be processed, only for LanguageModeling task.") | |||
| parser.add_argument("--output_file", type=str, default="", | |||
| help="The output dataset path after preprocessing.") | |||
| parser.add_argument("--condition", type=str, default="test", | |||
| help="Process train or test dataset, including [train, test], only for 1BW and " | |||
| "CNN & DailyMail dataset.") | |||
| args_opt = parser.parse_args() | |||
| task = args_opt.task | |||
| condition = args_opt.condition | |||
| dataset = args_opt.dataset | |||
| input_file = args_opt.input_file | |||
| output_file = args_opt.output_file | |||
| if task.lower() == "languagemodeling": | |||
| print("Start processing Language Modeling dataset ...") | |||
| if dataset.lower() == "wikitext2" or dataset.lower() == "wikitext103": | |||
| wikitext_dataset_preprocess(input_file=input_file, output_file=output_file) | |||
| elif dataset.lower() == "ptb": | |||
| ptb_dataset_preprocess(input_file=input_file, output_file=output_file) | |||
| elif dataset.lower() == "onebw": | |||
| onebw_dataset_preprocess(condition, input_file=input_file, output_file=output_file) | |||
| else: | |||
| raise ValueError("Only support wikitext2, wikitext103, ptb, onebw dataset") | |||
| elif task.lower() == "lambada": | |||
| print("Start processing Lambada dataset ...") | |||
| lambada_dataset_preprocess(input_file=input_file, output_file=output_file) | |||
| elif task.lower() == "cbt": | |||
| print("Start processing CBT dataset ...") | |||
| cbt_dataset_preprocess(input_file=input_file, output_file=output_file) | |||
| elif task.lower() == "readingcomprehension": | |||
| print("Start processing ReadingComprehension dataset ...") | |||
| coqa_dataset_preprocess(input_file=input_file, output_file=output_file) | |||
| elif task.lower() == "summarization": | |||
| print("Start processing Summarization dataset ...") | |||
| elif task.lower() == "translation": | |||
| print("Start processing Translation dataset ...") | |||
| wmt14_en_fr_preprocess(input_file=input_file, output_file=output_file) | |||
| else: | |||
| raise ValueError("Only support Language Modeling, CBT, Translation, Lambada, " | |||
| "Summarization, Reading Comprehension task.") | |||
| if __name__ == "__main__": | |||
| main() | |||
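| # --- Usage sketch (illustrative addition, not part of the original file); the | |||
| # script name and file paths below are placeholders. | |||
| # python preprocess.py --task=LanguageModeling --dataset=ptb \ | |||
| #     --input_file=./data/ptb.test.txt --output_file=./data/ptb.test.processed.txt | |||
| # python preprocess.py --task=Translation \ | |||
| #     --input_file=./data/wmt14.en-fr.test --output_file=./data/wmt14.test.processed.txt | |||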
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2017 Google Inc. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Python implementation of BLEU and smooth-BLEU. | |||
| This module provides a Python implementation of BLEU and smooth-BLEU. | |||
| Smooth BLEU is computed following the method outlined in the paper: | |||
| Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic | |||
| evaluation metrics for machine translation. COLING 2004. | |||
| """ | |||
| import collections | |||
| import math | |||
| def _get_ngrams(segment, max_order): | |||
| """ | |||
| Extracts all n-grams upto a given maximum order from an input segment. | |||
| Args: | |||
| segment: text segment from which n-grams will be extracted. | |||
| max_order: maximum length in tokens of the n-grams returned by this | |||
| methods. | |||
| Returns: | |||
| The Counter containing all n-grams upto max_order in segment | |||
| with a count of how many times each n-gram occurred. | |||
| """ | |||
| ngram_counts = collections.Counter() | |||
| for order in range(1, max_order + 1): | |||
| for i in range(0, len(segment) - order + 1): | |||
| ngram = tuple(segment[i:i + order]) | |||
| ngram_counts[ngram] += 1 | |||
| return ngram_counts | |||
| def compute_bleu(reference_corpus, translation_corpus, max_order=4, | |||
| smooth=False): | |||
| """Computes BLEU score of translated segments against one or more references. | |||
| Args: | |||
| reference_corpus: list of lists of references for each translation. Each | |||
| reference should be tokenized into a list of tokens. | |||
| translation_corpus: list of translations to score. Each translation | |||
| should be tokenized into a list of tokens. | |||
| max_order: Maximum n-gram order to use when computing BLEU score. | |||
| smooth: Whether or not to apply Lin et al. 2004 smoothing. | |||
| Returns: | |||
| 6-tuple of (bleu, precisions, bp, ratio, translation_length, reference_length): | |||
| the BLEU score, the n-gram precisions, the brevity penalty, the length ratio, | |||
| and the translation and reference lengths. | |||
| """ | |||
| matches_by_order = [0] * max_order | |||
| possible_matches_by_order = [0] * max_order | |||
| reference_length = 0 | |||
| translation_length = 0 | |||
| for (references, translation) in zip(reference_corpus, translation_corpus): | |||
| reference_length += min(len(r) for r in references) | |||
| translation_length += len(translation) | |||
| merged_ref_ngram_counts = collections.Counter() | |||
| for reference in references: | |||
| merged_ref_ngram_counts |= _get_ngrams(reference, max_order) | |||
| translation_ngram_counts = _get_ngrams(translation, max_order) | |||
| overlap = translation_ngram_counts & merged_ref_ngram_counts | |||
| for ngram in overlap: | |||
| matches_by_order[len(ngram) - 1] += overlap[ngram] | |||
| for order in range(1, max_order + 1): | |||
| possible_matches = len(translation) - order + 1 | |||
| if possible_matches > 0: | |||
| possible_matches_by_order[order - 1] += possible_matches | |||
| precisions = [0] * max_order | |||
| for i in range(0, max_order): | |||
| if smooth: | |||
| precisions[i] = ((matches_by_order[i] + 1.) / | |||
| (possible_matches_by_order[i] + 1.)) | |||
| else: | |||
| if possible_matches_by_order[i] > 0: | |||
| precisions[i] = (float(matches_by_order[i]) / | |||
| possible_matches_by_order[i]) | |||
| else: | |||
| precisions[i] = 0.0 | |||
| if min(precisions) > 0: | |||
| p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) | |||
| geo_mean = math.exp(p_log_sum) | |||
| else: | |||
| geo_mean = 0 | |||
| ratio = float(translation_length) / reference_length | |||
| if ratio > 1.0: | |||
| bp = 1. | |||
| else: | |||
| bp = math.exp(1 - 1. / ratio) | |||
| bleu = geo_mean * bp | |||
| return (bleu, precisions, bp, ratio, translation_length, reference_length) | |||
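| # --- Usage sketch (illustrative addition, not part of the original file). | |||
| if __name__ == "__main__": | |||
|     references = [[["the", "cat", "sat", "on", "the", "mat"]]]  # one segment, one reference | |||
|     hypothesis = [["the", "cat", "sat", "on", "the", "mat"]] | |||
|     bleu, precisions, bp, ratio, trans_len, ref_len = compute_bleu(references, hypothesis) | |||
|     print(bleu)  # 1.0 for an exact match | |||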