# 目录

<!-- TOC -->

- [目录](#目录)
- [GPT-2模型](#gpt-2模型)
- [模型架构](#模型架构)
- [下游任务](#下游任务)
    - [脚本说明](#脚本说明)
    - [模型转换](#模型转换)
    - [准备数据集](#准备数据集)
        - [Language Modeling 语言建模任务](#language-modeling-语言建模任务)
        - [Children's Book Test 任务](#childrens-book-test-任务)
        - [LAMBADA 任务](#lambada-任务)
        - [Reading Comprehension 任务](#reading-comprehension-任务)
        - [Summarization 任务](#summarization-任务)
        - [Translation 任务](#translation-任务)
    - [配置](#配置)
    - [微调&评估过程](#微调评估过程)
        - [Language Modeling 语言建模任务](#language-modeling-语言建模任务-1)
            - 微调
            - 评估
        - [Children's Book Test 任务](#childrens-book-test-任务-1)
            - 评估
        - [LAMBADA 任务](#lambada-任务-1)
            - 评估
        - [Reading Comprehension 任务](#reading-comprehension-任务-1)
            - 评估
        - [Summarization 任务](#summarization-任务-1)
            - 评估
        - [Translation 任务](#translation-任务-1)
            - 评估
- [环境要求](#环境要求)
    - [平台](#平台)
    - [其他要求](#其他要求)
- [性能](#性能)
    - [推理性能](#推理性能)
        - [Language Modeling 任务](#language-modeling-任务)
        - [Children's Book Test 任务](#childrens-book-test-任务-2)
        - [LAMBADA 任务](#lambada-任务-2)
        - [Reading Comprehension 任务](#reading-comprehension-任务-2)
        - [Summarization 任务](#summarization-任务-2)
        - [Translation 任务](#translation-任务-2)
- [其他](#其他)
- [ModelZoo主页](#modelzoo主页)

<!-- /TOC -->
# GPT-2模型

GPT-2模型由OpenAI于2019年发布。GPT-2继承自GPT模型,是一个非常庞大的语言模型,其核心任务是预测下一个单词。按照参数量的大小,GPT-2模型可分为small(117M)、medium(345M)、large(762M)、xlarge(1542M)四种规模。

[GPT-2介绍](https://openai.com/blog/better-language-models/)

[GPT-2论文](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf): Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.

# 模型架构

GPT-2模型由Transformer的解码器实现。Transformer包括多个编码器层和多个解码器层,但GPT-2模型仅使用了Transformer的解码器部分。

微调时,根据不同的任务,采用不同的数据集对预训练模型进行微调。

测试过程中,使用微调后的模型进行预测;对于某些任务,也可以直接使用预训练模型进行zero-shot评估。
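解码器的核心是masked self-attention:每个位置只能关注其左侧的token。下面用NumPy给出因果掩码构造的最小化示意(仅作说明,函数与维度均为示意用的假设,并非仓库中的实现):

```python
import numpy as np

def causal_mask(seq_length):
    """构造下三角因果掩码:位置i只能关注位置<=i,保证模型只依赖左侧上下文预测下一个词。"""
    return np.tril(np.ones((seq_length, seq_length), dtype=np.float32))

def masked_attention_scores(scores, mask):
    """被掩掉的位置置为极小值,softmax后其注意力权重趋近于0。"""
    return np.where(mask == 1, scores, np.full_like(scores, -1e9))

mask = causal_mask(4)
print(mask)  # 下三角矩阵,对应GPT-2解码器的masked self-attention
```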
# 下游任务

本文主要涉及6个下游任务,包括:

- Language Modeling 任务
- Children's Book Test 任务
- LAMBADA 任务
- Reading Comprehension 任务
- Summarization 任务
- Translation 任务

数据集相关信息,参见[GPT-2论文](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)。
## 脚本说明

GPT-2脚本及代码结构如下:

```text
├── GPT-2
  ├── README.md                          // GPT-2模型介绍
  ├── scripts
  │   ├── run_cbt.sh                     // CBT任务的微调&评估脚本
  │   ├── run_lambada.sh                 // LAMBADA任务的微调&评估脚本
  │   ├── run_language_model.sh          // 语言建模任务的微调&评估脚本
  │   ├── run_read_comprehension.sh      // 阅读理解任务的微调&评估脚本
  │   ├── run_summarization.sh           // 摘要生成任务的微调&评估脚本
  │   ├── run_translation.sh             // 翻译任务的微调&评估脚本
  ├── src
  │   ├── clip_grad_utils.py             // 用于梯度裁剪
  │   ├── dataset.py                     // 数据集加载,用于微调或推理
  │   ├── finetune_eval_config.py        // 微调和推理配置文件
  │   ├── gpt2_for_finetune.py           // 用于GPT-2模型微调
  │   ├── GPT2_generation.py             // 生成模块
  │   ├── GPT2_model.py                  // GPT-2模型脚本
  │   ├── GPT2ForCBT.py                  // CBT任务的模型脚本
  │   ├── GPT2ForLanguageModel.py        // 语言建模任务的模型脚本
  │   ├── GPT2ForReadComprehension.py    // 阅读理解任务的模型脚本
  │   ├── GPT2ForSummarization.py        // 摘要生成任务的模型脚本
  │   ├── GPT2ForTranslation.py          // 翻译任务的模型脚本
  │   ├── weight_init.py                 // 初始化权重
  │   ├── utils
  │   │   ├── bleu_score.py              // 用于计算BLEU分数
  │   │   ├── rouge_score.py             // 用于计算ROUGE分数
  │   │   ├── CrossEntropy.py            // 交叉熵损失
  │   │   ├── data_preprocess.py         // 数据集预处理脚本
  │   │   ├── generation_utils.py        // 用于辅助生成的模块,包含采样等方法
  │   │   ├── get_config_setting.py      // 获取配置信息
  │   │   ├── task_utils.py              // 辅助下游任务的功能脚本
  │   │   ├── lr_schedule.py             // 学习率策略脚本
  │   │   ├── metric_method.py           // 下游任务的评价指标
  │   │   ├── tensor_manipulations.py    // 涉及张量操作
  │   │   ├── tokenization.py            // 标记化,包含BPE编码和解码
  │   ├── pretrain-data
  │   │   ├── stopwords.txt              // 用于LAMBADA任务的stopword filter
  ├── create_cbt_data.py                 // 用于CBT任务创建mindrecord
  ├── create_lambada_data.py             // 用于LAMBADA任务创建mindrecord
  ├── create_lm_data.py                  // 用于其他任务创建mindrecord
  ├── create_summary_data.py             // 用于Summarization任务创建mindrecord
  ├── download_cnn_dailymail.py          // 下载CNN & DailyMail数据集
  ├── cnn_dataset_sampler.py             // CNN & DailyMail训练集采样器
  ├── eval_rc_addition_answer.py         // 使用addition_answer评估阅读理解任务
  ├── run_CBT_task.py                    // CBT任务微调&推理API入口
  ├── run_lambada.py                     // LAMBADA任务微调&推理API入口
  ├── run_language_model.py              // 语言建模任务微调&推理API入口
  ├── run_ReadComprehension.py           // 阅读理解任务微调&推理API入口
  ├── run_summarization.py               // 摘要生成任务微调&推理API入口
  ├── run_translation.py                 // 翻译任务微调&推理API入口
  ├── task_dataset_preprocess.py         // 各个任务的数据集处理入口
  ├── convert_tf_ckpt
  │   ├── read_weight_tf.py              // 读取tensorflow下的预训练模型
  │   ├── trans_dict.py                  // 模型参数名称字典
  │   ├── save_weight_ms.py              // 生成mindspore ckpt
  ├── third_party
  │   ├── gpt2-merges.txt                // GPT-2预训练BPE merge文件
  │   ├── gpt2-vocab.json                // GPT-2预训练词表
  │   ├── bleu.py                        // 辅助BLEU值计算的第三方代码
```
## 模型转换

- 下载GPT-2的预训练模型:[GPT-2预训练模型下载](https://github.com/openai/gpt-2/blob/master/download_model.py)
- 在tensorflow环境下运行`read_weight_tf.py`,示例代码如下:

  `python read_weight_tf.py --ckpt_file_path=/{path}/model.ckpt`

- 在mindspore环境下运行`save_weight_ms.py`,示例代码如下:

  `python save_weight_ms.py --output_file_name="mindspore_gpt2_small.ckpt"`
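转换完成后,可以用如下最小化示例检查生成的checkpoint(仅作示意,参数名以`trans_dict.py`中的映射为准):

```python
from mindspore import load_checkpoint

# 读取转换得到的checkpoint,打印部分参数名和形状,确认权重已正确保存
param_dict = load_checkpoint("mindspore_gpt2_small.ckpt")
for name in list(param_dict)[:5]:
    print(name, param_dict[name].shape)
```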
## 准备数据集

### Language Modeling 语言建模任务

#### WikiText2、WikiText103、PTB、1BW 数据集

- [WikiText2数据集下载](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip) 解压后使用`wikitext-2/wiki.test.tokens`作为测试集
- [WikiText103数据集下载](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) 解压后使用`wikitext-103/wiki.test.tokens`作为测试集
- [PTB数据集下载](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz) 解压后使用`/simple-examples/data/ptb.test.txt`作为测试集,使用`/simple-examples/data/ptb.train.txt`作为训练集
- [1BW数据集下载](http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz) 解压后使用`1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050`作为测试集,使用`1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/news.en-00001-of-00100`作为原始训练集,随机采样后得到30000条训练样本

使用`task_dataset_preprocess.py`可以对以上数据集进行清洗。

`task_dataset_preprocess.py`的主要参数如下:

```bash
--task: The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada, Summarization, ReadingComprehension].
--input_file: The raw dataset path.
--dataset: The name of the dataset to be processed, only for the LanguageModeling task.
--output_file: The output dataset path after preprocessing.
--condition: Process the train or test dataset, including [train, test], only for the 1BW and CNN & DailyMail datasets.
```

示例代码如下,清洗PTB测试集(清洗训练集时,将`--input_file`改为训练集路径并设置`--condition "train"`即可):

```bash
python task_dataset_preprocess.py --task "LanguageModeling" --input_file /{path}/ptb.test.txt --dataset "ptb" --output_file /{path}/ptb_clean_test.txt --condition "test"
```

使用`create_lm_data.py`可以将以上数据集转换为mindrecord格式。

`create_lm_data.py`的主要参数如下:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_seq_length: Maximum sequence length.
--vocab_file: Path of gpt2-vocab.json.
--merge_file: Path of gpt2-merges.txt.
```

示例代码如下:

```bash
python create_lm_data.py --input_file /{path}/ptb_clean_test.txt --output_file /{path}/ptb-test-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path}
```
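生成mindrecord后,可用如下最小化示例验证文件内容(仅作示意;样本字段名取决于数据生成脚本的实际定义):

```python
from mindspore.mindrecord import FileReader

# 读取生成的mindrecord文件,打印第一条样本的字段名和形状,确认数据格式符合预期
reader = FileReader("ptb-test-mindrecord")
for sample in reader.get_next():
    print({key: getattr(value, "shape", value) for key, value in sample.items()})
    break
reader.close()
```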
### Children's Book Test 任务

#### CBT-CN / CBT-NE 数据集

- [CBT数据集下载](http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz) 解压后,使用`/data`目录下的`cbtest_CN_valid_2000ex.txt`和`cbtest_NE_valid_2000ex.txt`作为该任务的评估集。清洗该数据集的示例代码如下:

```bash
python task_dataset_preprocess.py --task "CBT" --input_file /{path}/cbtest_CN_valid_2000ex.txt --dataset "cbt" --output_file /{path}/cbt_cn_valid.txt
```

使用`create_cbt_data.py`可以将以上数据集转换为mindrecord格式。

`create_cbt_data.py`的主要参数如下:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_seq_length: Maximum sequence length.
--num_choice: Number of choices.
--vocab_file: Path of gpt2-vocab.json.
--merge_file: Path of gpt2-merges.txt.
```

示例代码如下:

```bash
python create_cbt_data.py --input_file /{path}/cbt_cn_valid.txt --output_file /{path}/cbt-cn-valid-mindrecord --num_splits 1 --max_seq_length 1024 --num_choice 10 --vocab_file={path} --merge_file={path}
```
### LAMBADA 任务

#### LAMBADA 数据集

- [LAMBADA数据集下载](https://zenodo.org/record/2630551#.X-yCSTTithH) 使用`lambada_test_plain_text.txt`作为该任务的评估集。清洗该数据集的示例代码如下:

```bash
python task_dataset_preprocess.py --task "LAMBADA" --input_file /{path}/lambada_test_plain_text.txt --dataset "LAMBADA" --output_file /{path}/lambada_test_clean.txt
```

使用`create_lambada_data.py`可以将以上数据集转换为mindrecord格式。

`create_lambada_data.py`的主要参数如下:

```bash
--input_file: Input raw text file.
--output_file: Output MindRecord file.
--num_splits: The number of partitions the MindRecord file will be split into.
--max_seq_length: Maximum sequence length.
--vocab_file: Path of gpt2-vocab.json.
--merge_file: Path of gpt2-merges.txt.
```

示例代码如下:

```bash
python create_lambada_data.py --input_file /{path}/lambada_test_clean.txt --output_file /{path}/lambada-test-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path}
```
### Reading Comprehension 任务

#### CoQA数据集

- [CoQA数据集下载](http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json) 使用`coqa-dev-v1.0.json`作为该任务的评估集。清洗该数据集的示例代码如下:

```bash
python task_dataset_preprocess.py --task "ReadingComprehension" --input_file /{path}/coqa-dev-v1.0.json --dataset "coqa" --output_file /{path}/coqa_dev.txt
```

使用`create_lm_data.py`可以将以上数据集转换为mindrecord格式,示例代码如下:

```bash
python create_lm_data.py --input_file /{path}/coqa_dev.txt --output_file /{path}/coqa-dev-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path}
```
### Summarization 任务

#### CNN & DailyMail数据集

- 使用`download_cnn_dailymail.py`脚本下载该数据集,示例代码如下:

```bash
# 下载测试集
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split test
# 下载训练集
python download_cnn_dailymail.py --dir ./cnn_dailymail/ --split train
```

从训练集中随机采样10000条样本作为最终微调用的训练集。使用`cnn_dataset_sampler.py`脚本进行训练集的采样操作,生成新的训练集,示例代码如下:

```bash
# GPT-2 small和GPT-2 medium模型的训练集中seq_length=1024,因此该脚本中设置max_length=1022
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
    --output_path="/{path}/cnn_train_hint_small.txt" \
    --replace_hint="true" \
    --sample="true" \
    --max_length=1022 \
    --prob=0.25 \
    --max_items=10000 \
    --hint="TL;DR:"

# GPT-2 large模型的训练集中seq_length=768,因此该脚本中设置max_length=766
python cnn_dataset_sampler.py --input_path="/{path}/cnn_train.txt" \
    --output_path="/{path}/cnn_train_hint_large.txt" \
    --replace_hint="true" \
    --sample="true" \
    --max_length=766 \
    --prob=0.25 \
    --max_items=10000 \
    --hint="TL;DR:"
```

使用`create_summary_data.py`可以将以上数据集转换为mindrecord格式,示例代码如下:

```bash
python create_summary_data.py --input_file /{path}/cnn_dailymail_test.txt --output_file /{path}/cnn_dailymail-test-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path} --mode 'cnn_dailymail'
```
### Translation 任务

#### WMT14 En-Fr数据集

- [WMT14 En-Fr数据集下载](http://statmt.org/wmt14/test-full.tgz) 使用`newstest2014-fren-ref.en.sgm`和`newstest2014-fren-ref.fr.sgm`作为该任务的评估集。合并并清洗该数据集的示例代码如下:

```bash
python task_dataset_preprocess.py --task "Translation" --input_file /{path}/test-full --dataset "wmt14" --output_file /{path}/wmt14
```

在`output_file`路径下会生成两个文件`wmt14.en_fr.txt`和`wmt14.fr_en.txt`,分别用于评估`En-Fr`和`Fr-En`。

使用`create_lm_data.py`可以将以上数据集转换为mindrecord格式,示例代码如下:

```bash
python create_lm_data.py --input_file /{path}/wmt14.en_fr.txt --output_file /{path}/en-fr-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path}
python create_lm_data.py --input_file /{path}/wmt14.fr_en.txt --output_file /{path}/fr-en-mindrecord --num_splits 1 --max_seq_length 1024 --vocab_file={path} --merge_file={path}
```
## 配置

`src/finetune_eval_config.py`为GPT-2模型训练和推理的配置文件,便于为大多数选项及参数赋值,包括GPT-2模型规模、模型的配置、优化器参数等。

有关属性的详细信息,参见`src/finetune_eval_config.py`文件。
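下面是一个仅作示意的配置结构草图(基于`easydict`的假设写法;具体字段名与取值以`src/finetune_eval_config.py`的实际内容为准):

```python
from easydict import EasyDict as edict

# 示意:选择模型规模与优化器;字段名为假设,实际以finetune_eval_config.py为准
cfg = edict({
    'gpt2_network': 'large',   # 可选 [small, medium, large]
    'optimizer': 'Lamb',       # 可选 [momentum, adam, lamb]
})

# 示意:与所选规模对应的模型参数,如batch_size与seq_length
gpt2_net_cfg = edict({
    'batch_size': 1,
    'seq_length': 1024,
})
```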
## 微调&评估过程

### Language Modeling 语言建模任务

#### 微调

- PTB数据集

GPT-2 small / GPT-2 medium / GPT-2 large模型需要在PTB训练集上进行微调。微调模型时,只需要使用shell脚本`scripts/run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_language_model.py`脚本。

微调模型时,首先配置`src/finetune_eval_config.py`中的选项:

- 将`cfg`下的`gpt2_network`设置为相应的GPT-2模型大小`[small/medium/large]`。
- 将`cfg`下的`optimizer`设置为`Lamb`,进行优化器的选择(可采用`momentum/adam/lamb`)。
- 选定了GPT-2模型后,设置模型的参数,包括`batch_size`和`seq_length`。

而后执行`scripts/run_language_model.sh`脚本:

```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
    --do_train="true" \
    --do_eval="false" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path={save_finetune_ckpt_path} \
    --load_pretrain_ckpt_path={load_pretrain_ckpt_path} \
    --train_data_file_path={train_data_file_path}
```
日志和输出文件可以在`./ms_log/`路径下获取。

```bash
sh scripts/run_language_model.sh [--options]
```

`run_language_model.sh`的用法如下:

```text
usage: run_language_model.sh   [--device_target DEVICE_TARGET] [--device_id N]
                               [--metric_method METRIC_METHOD]
                               [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                               [--eval_type EVAL_TYPE] [--epoch_num N]
                               [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                               [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                               [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                               [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                               [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                               [--train_data_file_path TRAIN_DATA_FILE_PATH]
                               [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target              Device type. Default: "Ascend"
    --device_id                  ID of target device
    --metric_method              The eval method including [PPL]. Default: "PPL"
    --do_train                   Enable train. Default: "false"
    --do_eval                    Enable evaluation. Default: "true"
    --eval_type                  The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                  Epoch number. Default: 1
    --train_data_shuffle         Enable train data shuffle. Default: "true"
    --eval_data_shuffle          Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path    Save the finetuned checkpoint path
    --load_pretrain_ckpt_path    Load the checkpoint file path for train
    --load_finetune_ckpt_path    Load the checkpoint file path for evaluation
    --train_data_file_path       Data path, it is better to use absolute path
    --eval_data_file_path        Data path, it is better to use absolute path
```
- 1BW数据集

GPT-2 large模型需要在1BW训练集上进行微调。微调模型时,只需要使用shell脚本`run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_language_model.py`脚本。微调方法与PTB数据集一致。

#### 评估

GPT-2模型可以在`WikiText2/WikiText103/PTB/1BW`测试集上进行对应的评估。针对以上数据集的评估,评估方法采用PPL(困惑度),即设置`--metric_method="PPL"`。
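PPL的含义可以用如下最小化示例说明(仅作示意):困惑度等于每个token平均负对数似然的指数,数值越低表示模型对测试文本的拟合越好。

```python
import math

def perplexity(token_nlls):
    """由每个token的负对数似然(NLL)计算困惑度:PPL = exp(平均NLL)。"""
    return math.exp(sum(token_nlls) / len(token_nlls))

# 示例:3个token的NLL,平均值为2.0
print(perplexity([2.3, 1.7, 2.0]))  # 约为 7.39
```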
评估模型时,只需要使用shell脚本`run_language_model.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_language_model.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_language_model.sh`。若模型在某个数据集上经过了微调,则在对应测试集上评估时需要设置`--eval_type="finetuned"`,否则设置`--eval_type="zero-shot"`;此外,`--load_finetune_ckpt_path`为待评估的checkpoint文件路径。

```bash
sh scripts/run_language_model.sh --device_target="Ascend" \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="finetuned" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```

日志和输出文件可以在`./ms_log/`路径下获取。
### Children's Book Test 任务

#### 评估

GPT-2模型可以在`CBT-CN/CBT-NE`验证集上进行对应的评估。针对以上数据集的评估,评估方法采用Accuracy,即设置`--metric_method="Accuracy"`。
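CBT是完形填空式的多选任务,其打分思路可用如下最小化示例说明(仅作示意,并非仓库实现):对每个候选词分别计算将其填入空缺后整句的语言模型对数概率,选取得分最高者作为答案。

```python
import numpy as np

def choose_answer(choice_log_probs):
    """choice_log_probs: 每个候选词填入空缺后,整句在语言模型下的对数概率列表。
    返回得分最高的候选词下标,即模型认为最可能的答案。"""
    return int(np.argmax(choice_log_probs))

# 示例:10个候选词的对数概率(CBT每题有10个候选词)
scores = [-42.1, -39.8, -45.3, -40.2, -44.0, -41.7, -43.9, -38.6, -46.2, -44.8]
print(choose_answer(scores))  # 7
```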
评估模型时,只需要使用shell脚本`run_cbt.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_CBT_task.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_cbt.sh`,并设置`--eval_type="zero-shot"`;此时`--load_finetune_ckpt_path`只需加载预训练好的checkpoint文件即可。

```bash
sh scripts/run_cbt.sh --device_target="Ascend" \
    --num_choice=10 \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```

日志和输出文件可以在`./ms_log/`路径下获取。

```bash
sh scripts/run_cbt.sh [--options]
```
`run_cbt.sh`的用法如下:

```text
usage: run_cbt.sh   [--device_target DEVICE_TARGET] [--device_id N] [--num_choice N]
                    [--metric_method METRIC_METHOD]
                    [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                    [--eval_type EVAL_TYPE] [--epoch_num N]
                    [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                    [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                    [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                    [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                    [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                    [--train_data_file_path TRAIN_DATA_FILE_PATH]
                    [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target              Device type. Default: "Ascend"
    --device_id                  ID of target device
    --num_choice                 The number of choices in the CBT task
    --metric_method              The eval method including [Accuracy]. Default: "Accuracy"
    --do_train                   Enable train. Default: "false"
    --do_eval                    Enable evaluation. Default: "true"
    --eval_type                  The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                  Epoch number. Default: 1
    --train_data_shuffle         Enable train data shuffle. Default: "true"
    --eval_data_shuffle          Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path    Save the finetuned checkpoint path
    --load_pretrain_ckpt_path    Load the checkpoint file path for train
    --load_finetune_ckpt_path    Load the checkpoint file path for evaluation
    --train_data_file_path       Data path, it is better to use absolute path
    --eval_data_file_path        Data path, it is better to use absolute path
```
### LAMBADA 任务

#### 评估

GPT-2模型可以在`LAMBADA`测试集上进行对应的评估。针对以上数据集的评估,评估方法采用Accuracy和PPL,即设置`--metric_method="Accuracy"`或`--metric_method="PPL"`。

评估模型时,只需要使用shell脚本`run_lambada.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_lambada.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_lambada.sh`,并设置`--eval_type="zero-shot"`;此时`--load_finetune_ckpt_path`只需加载预训练好的checkpoint文件即可。
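LAMBADA的Accuracy考察模型能否预测出段落的最后一个词,`stopwords.txt`用于stopword filter。下面是该思路的一个最小化示意(仅作示意,并非仓库实现):

```python
def last_word_correct(generated_words, target_word, stop_words):
    """从模型生成的候选词中滤掉停用词,再将剩余的第一个词与目标词比较,命中记为正确。"""
    candidates = [w for w in generated_words if w.lower() not in stop_words]
    return bool(candidates) and candidates[0] == target_word

stop_words = {"the", "a", "of"}
print(last_word_correct(["the", "ocean"], "ocean", stop_words))  # True
```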
评估Accuracy:

```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length_dynamically="true" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --stop_word_file_path={stop_word_file_path}
```

评估PPL:

```bash
sh scripts/run_lambada.sh --device_target="Ascend" \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path}
```
日志和输出文件可以在`./ms_log/`路径下获取。

```bash
sh scripts/run_lambada.sh [--options]
```

`run_lambada.sh`的用法如下:

```text
usage: run_lambada.sh   [--device_target DEVICE_TARGET] [--device_id N]
                        [--metric_method METRIC_METHOD]
                        [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                        [--eval_type EVAL_TYPE] [--epoch_num N]
                        [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                        [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                        [--generate_length_dynamically GENERATE_LENGTH_DYNAMICALLY]
                        [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                        [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                        [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                        [--train_data_file_path TRAIN_DATA_FILE_PATH]
                        [--eval_data_file_path EVAL_DATA_FILE_PATH]
                        [--tokenizer_file_path TOKENIZER_FILE_PATH]
                        [--stop_word_file_path STOP_WORD_FILE_PATH]
options:
    --device_target                  Device type. Default: "Ascend"
    --device_id                      ID of target device
    --metric_method                  The eval method including [Accuracy, PPL]. Default: "Accuracy"
    --do_train                       Enable train. Default: "false"
    --do_eval                        Enable evaluation. Default: "true"
    --eval_type                      The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                      Epoch number. Default: 1
    --train_data_shuffle             Enable train data shuffle. Default: "true"
    --eval_data_shuffle              Enable eval data shuffle. Default: "false"
    --generate_length_dynamically    Enable generating the length dynamically. Default: "true"
    --save_finetune_ckpt_path        Save the checkpoint path
    --load_pretrain_ckpt_path        Load the checkpoint file path
    --load_finetune_ckpt_path        Load the checkpoint file path
    --train_data_file_path           Data path, it is better to use absolute path
    --eval_data_file_path            Data path, it is better to use absolute path
    --tokenizer_file_path            Pretrained vocab and merge file path
    --stop_word_file_path            The stop word file path
```
### Reading Comprehension 任务

#### 评估

GPT-2模型可以在`CoQA`开发集上进行对应的评估。针对以上数据集的评估,评估方法采用F1,即设置`--metric_method="F1"`。

评估模型时,只需要使用shell脚本`run_read_comprehension.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_ReadComprehension.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_read_comprehension.sh`,并设置`--eval_type="zero-shot"`;此时`--load_finetune_ckpt_path`只需加载预训练好的checkpoint文件即可。

```bash
sh scripts/run_read_comprehension.sh --device_target="Ascend" \
    --metric_method="F1" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --generate_length=55 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0"
```
日志和输出文件可以在`./ms_log/`路径下获取。而后将得到的日志文件作为`eval_rc_addition_answer.py`脚本的`input_file`,同时将原CoQA开发集`coqa-dev-v1.0.json`作为`addition_file`。

执行`python eval_rc_addition_answer.py --input_file={path} --addition_file={path}`得到最终的F1值。
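阅读理解使用的F1基于预测答案与参考答案的词级重叠,其计算方式可用如下最小化示例说明(仅作示意):

```python
from collections import Counter

def f1_score(prediction, reference):
    """SQuAD/CoQA风格的词级F1:precision与recall基于两个答案的公共词数计算。"""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

print(f1_score("in the garden", "the garden"))  # 0.8
```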
```bash
sh scripts/run_read_comprehension.sh [--options]
```

`run_read_comprehension.sh`的用法如下:

```text
usage: run_read_comprehension.sh   [--device_target DEVICE_TARGET] [--device_id N]
                                   [--metric_method METRIC_METHOD]
                                   [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                                   [--eval_type EVAL_TYPE] [--epoch_num N]
                                   [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                                   [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                                   [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                                   [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                                   [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                                   [--train_data_file_path TRAIN_DATA_FILE_PATH]
                                   [--eval_data_file_path EVAL_DATA_FILE_PATH]
                                   [--tokenizer_file_path TOKENIZER_FILE_PATH]
                                   [--generate_length N] [--top_k N] [--top_p TOP_P]
                                   [--temperature TEMPERATURE]
options:
    --device_target              Device type. Default: "Ascend"
    --device_id                  ID of target device
    --metric_method              The eval method including [F1]. Default: "F1"
    --do_train                   Enable train. Default: "false"
    --do_eval                    Enable evaluation. Default: "false"
    --eval_type                  The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                  Epoch number. Default: 1
    --train_data_shuffle         Enable train data shuffle. Default: "true"
    --eval_data_shuffle          Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path    Save the checkpoint path
    --load_pretrain_ckpt_path    Load the checkpoint file path
    --load_finetune_ckpt_path    Load the checkpoint file path
    --train_data_file_path       Data path, it is better to use absolute path
    --eval_data_file_path        Data path, it is better to use absolute path
    --tokenizer_file_path        Pretrained vocab and merge file path
    --generate_length            The generation length of the answer sentence
    --top_k                      Parameter for Top-K sampling
    --top_p                      Parameter for Top-P sampling
    --temperature                Parameter for generation; larger values make the generation more diverse
```
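`top_k`、`top_p`与`temperature`共同控制采样行为,其作用可用如下最小化示例说明(仅作示意,并非仓库中`generation_utils.py`的实现):

```python
import numpy as np

def sample_next_token(logits, top_k=0, top_p=1.0, temperature=1.0):
    """温度缩放后,先做Top-K截断,再做Top-P(核采样)截断,最后按归一化概率采样下一个token。"""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]      # 按概率从大到小排序
    if top_k > 0:
        order = order[:top_k]            # Top-K:只保留概率最高的k个token
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, top_p) + 1
    order = order[:cutoff]               # Top-P:累积概率达到p即截断
    kept = probs[order] / probs[order].sum()
    return int(np.random.choice(order, p=kept))

print(sample_next_token([2.0, 1.0, 0.5, -1.0], top_k=3, top_p=0.9))
```

`top_k=1`等价于贪心解码;`temperature`越大,分布越平缓,生成越多样。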
### Summarization 任务

#### 评估

GPT-2模型可以在`CNN_Dailymail`测试集上进行对应的评估。针对以上数据集的评估,评估方法采用ROUGE,即设置`--metric_method="Rouge"`。

评估模型时,只需要使用shell脚本`run_summarization.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_summarization.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_summarization.sh`。对于`hint`(带"TL;DR:"提示)的情况,设置`--eval_type="finetuned"`,此时`--load_finetune_ckpt_path`需要加载微调好的checkpoint文件;对于`no hint`的情况,设置`--eval_type="zero-shot"`,此时`--load_finetune_ckpt_path`只需加载预训练好的checkpoint文件即可。

```bash
sh scripts/run_summarization.sh --device_target="Ascend" \
    --do_train="false" \
    --do_eval="true" \
    --metric_method="Rouge" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length=100 \
    --top_k=2 \
    --top_p="1.0" \
    --temperature="1.0" \
    --eval_type="finetuned" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path}
```
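ROUGE分数也可以用`rouge`包(见「其他要求」)离线核验,下面是一个最小化示例(仅作示意;仓库内实际使用`src/utils/rouge_score.py`计算):

```python
from rouge import Rouge

# 计算生成摘要与参考摘要之间的ROUGE-1/ROUGE-2/ROUGE-L分数
hypothesis = "the cat sat on the mat"
reference = "a cat was sitting on the mat"
scores = Rouge().get_scores(hypothesis, reference, avg=True)
print(scores["rouge-l"]["f"])
```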
日志和输出文件可以在`./ms_log/`路径下获取。

```bash
sh scripts/run_summarization.sh [--options]
```

`run_summarization.sh`的用法如下:

```text
usage: run_summarization.sh   [--device_target DEVICE_TARGET] [--device_id N]
                              [--metric_method METRIC_METHOD]
                              [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                              [--eval_type EVAL_TYPE] [--epoch_num N]
                              [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                              [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                              [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                              [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                              [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                              [--train_data_file_path TRAIN_DATA_FILE_PATH]
                              [--eval_data_file_path EVAL_DATA_FILE_PATH]
options:
    --device_target              Device type. Default: "Ascend"
    --device_id                  ID of target device
    --do_train                   Enable train. Default: "false"
    --do_eval                    Enable evaluation. Default: "false"
    --metric_method              The eval method including [Rouge(Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: "Rouge"
    --epoch_num                  Epoch number. Default: 2
    --train_data_shuffle         Enable train data shuffle. Default: "true"
    --eval_data_shuffle          Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path    Save the checkpoint path
    --load_pretrain_ckpt_path    Load the checkpoint file path
    --load_finetune_ckpt_path    Load the checkpoint file path
    --train_data_file_path       Data path, it is better to use absolute path
    --eval_data_file_path        Data path, it is better to use absolute path
    --eval_type                  The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --top_k                      Top k tokens chosen for sampling
    --top_p                      Top p accumulated probability threshold for logit to be counted
    --generate_length            The number of generated tokens
    --temperature                Temperature on logits for sampling
    --tokenizer_file_path        Vocab & merge file path
```
### Translation 任务

#### 评估

GPT-2模型可以在`WMT14 En-Fr`和`WMT14 Fr-En`测试集上进行对应的评估。针对以上数据集的评估,评估方法采用BLEU,即设置`--metric_method="BLEU"`。

注:需要自行下载`bleu.py`脚本([脚本链接](https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py)),并将该脚本放置于`third_party/`目录下(见上文代码结构),调用示例见下文。

评估模型时,只需要使用shell脚本`run_translation.sh`即可,脚本中可以设置环境变量,执行`GPT-2`下的`run_translation.py`脚本。

评估模型时,首先配置`src/finetune_eval_config.py`,而后执行`scripts/run_translation.sh`,并设置`--eval_type="zero-shot"`;此时`--load_finetune_ckpt_path`只需加载预训练好的checkpoint文件即可。

```bash
sh scripts/run_translation.sh --device_target="Ascend" \
    --metric_method="BLEU" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --load_finetune_ckpt_path={load_eval_ckpt_path} \
    --eval_data_file_path={eval_data_file_path} \
    --tokenizer_file_path={tokenizer_file_path} \
    --generate_length=100 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0"
```
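下载的`bleu.py`提供`compute_bleu`函数,其典型用法如下(仅作示意;假设从仓库根目录运行且`bleu.py`位于`third_party/`,语料需先切分为token列表,参考译文一侧为嵌套列表,每个样本可有多条参考):

```python
from third_party.bleu import compute_bleu  # 即上文下载的tensorflow/nmt脚本

# references: 每个样本的参考译文列表;translations: 模型译文
references = [[["the", "cat", "is", "on", "the", "mat"]]]
translations = [["the", "cat", "is", "on", "the", "mat"]]
bleu, *_ = compute_bleu(references, translations)
print(bleu)  # 1.0,完全匹配
```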
```bash
sh scripts/run_translation.sh [--options]
```

`run_translation.sh`的用法如下:

```text
usage: run_translation.sh   [--device_target DEVICE_TARGET] [--device_id N]
                            [--metric_method METRIC_METHOD]
                            [--do_train DO_TRAIN] [--do_eval DO_EVAL]
                            [--eval_type EVAL_TYPE] [--epoch_num N]
                            [--train_data_shuffle TRAIN_DATA_SHUFFLE]
                            [--eval_data_shuffle EVAL_DATA_SHUFFLE]
                            [--save_finetune_ckpt_path SAVE_FINETUNE_CKPT_PATH]
                            [--load_pretrain_ckpt_path LOAD_PRETRAIN_CKPT_PATH]
                            [--load_finetune_ckpt_path LOAD_FINETUNE_CKPT_PATH]
                            [--train_data_file_path TRAIN_DATA_FILE_PATH]
                            [--eval_data_file_path EVAL_DATA_FILE_PATH]
                            [--tokenizer_file_path TOKENIZER_FILE_PATH]
                            [--generate_length N] [--top_k N] [--top_p TOP_P]
                            [--temperature TEMPERATURE]
options:
    --device_target              Device type. Default: "Ascend"
    --device_id                  ID of target device
    --metric_method              The eval method including [BLEU]. Default: "BLEU"
    --do_train                   Enable train. Default: "false"
    --do_eval                    Enable evaluation. Default: "true"
    --eval_type                  The type of evaluation including [zero-shot, finetuned]. Default: "zero-shot"
    --epoch_num                  Epoch number. Default: 1
    --train_data_shuffle         Enable train data shuffle. Default: "true"
    --eval_data_shuffle          Enable eval data shuffle. Default: "false"
    --save_finetune_ckpt_path    Save the checkpoint path
    --load_pretrain_ckpt_path    Load the checkpoint file path
    --load_finetune_ckpt_path    Load the checkpoint file path
    --train_data_file_path       Data path, it is better to use absolute path
    --eval_data_file_path        Data path, it is better to use absolute path
    --tokenizer_file_path        Pretrained vocab and merge file path
    --generate_length            The generation length of the translated sentence
    --top_k                      Parameter for Top-K sampling
    --top_p                      Parameter for Top-P sampling
    --temperature                Parameter for generation; larger values make the generation more diverse
```
# 环境要求

## 平台

- 硬件(Ascend)
    - 使用Ascend处理器准备硬件环境。
    - 如需试用昇腾处理器,请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com,申请通过即可获得资源。
- 框架
    - [MindSpore](https://www.mindspore.cn/install)
- 更多关于MindSpore的信息,请查看以下资源:
    - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

## 其他要求

```text
math
numpy
copy
collections
re
rouge 1.0.0
datasets >= 0.4.0
json
tensorflow
```
# 性能

## 推理性能

### Language Modeling 任务

下表展示了GPT-2 small、medium、large三种规模的模型在Language Modeling任务中的PPL得分情况,其中OpenAI列为OpenAI公布的对应结果。

| 模型 | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WikiText2 | Ascend | zero-shot | 24.5 | 29.41 |
| GPT-2 medium | WikiText2 | Ascend | zero-shot | 19.41 | 22.76 |
| GPT-2 large | WikiText2 | Ascend | zero-shot | 17.08 | 19.93 |
| GPT-2 small | WikiText103 | Ascend | zero-shot | 26.89 | 37.5 |
| GPT-2 medium | WikiText103 | Ascend | zero-shot | 20.23 | 26.37 |
| GPT-2 large | WikiText103 | Ascend | zero-shot | 17.48 | 22.05 |
| GPT-2 small | PTB | Ascend | finetune | 23.91 | 65.85 |
| GPT-2 medium | PTB | Ascend | finetune | 20.06 | 47.33 |
| GPT-2 large | PTB | Ascend | finetune | 18.84 | 40.31 |
| GPT-2 small | 1BW | Ascend | zero-shot | 63.13 | 75.2 |
| GPT-2 medium | 1BW | Ascend | zero-shot | 50.98 | 55.72 |
| GPT-2 large | 1BW | Ascend | finetune | 29.28 | 44.575 |

### Children's Book Test 任务

下表展示了GPT-2 small、medium、large三种规模的模型在Children's Book Test任务中的Accuracy得分情况。

| 模型 | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CBT-CN valid | Ascend | zero-shot | 87.85 | 87.65 |
| GPT-2 medium | CBT-CN valid | Ascend | zero-shot | 92.1 | 92.35 |
| GPT-2 large | CBT-CN valid | Ascend | zero-shot | 93.7 | 93.45 |
| GPT-2 small | CBT-NE valid | Ascend | zero-shot | 85.1 | 83.4 |
| GPT-2 medium | CBT-NE valid | Ascend | zero-shot | 87.55 | 87.1 |
| GPT-2 large | CBT-NE valid | Ascend | zero-shot | 89.1 | 88 |

### LAMBADA 任务

下表展示了GPT-2 small、medium、large三种规模的模型在LAMBADA任务中的Accuracy和PPL得分情况。

| 模型 | dataset | device | eval_type | ACC | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 45.99 | 45.99 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 58.59 | 55.48 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 62.74 | 60.12 |

| 模型 | dataset | device | eval_type | PPL | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | Lambada-test | Ascend | zero-shot | 22.95 | 35.13 |
| GPT-2 medium | Lambada-test | Ascend | zero-shot | 10.69 | 15.6 |
| GPT-2 large | Lambada-test | Ascend | zero-shot | 8.64 | 10.87 |

### Reading Comprehension 任务

下表展示了GPT-2 small、medium、large三种规模的模型在Reading Comprehension任务中的F1得分情况。

| 模型 | dataset | device | eval_type | F1 | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CoQA | Ascend | zero-shot | 25.94 | 25~26 |
| GPT-2 medium | CoQA | Ascend | zero-shot | 43.69 | 42~43 |
| GPT-2 large | CoQA | Ascend | zero-shot | 49.39 | 49~51 |

### Summarization 任务

下表展示了GPT-2 small、medium、large三种规模的模型在Summarization任务中的ROUGE得分情况。

| 模型 | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(TL;DR) | Ascend | finetune | 21.4 | 16.8~17 |
| GPT-2 medium | CNN_Dailymail(TL;DR) | Ascend | finetune | 25.94 | 20.6~20.9 |
| GPT-2 large | CNN_Dailymail(TL;DR) | Ascend | finetune | 26.73 | 21.5~21.6 |

| 模型 | dataset | device | eval_type | ROUGE | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.08 | 15.03(xlarge) |
| GPT-2 medium | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.16 | 15.03(xlarge) |
| GPT-2 large | CNN_Dailymail(no hint) | Ascend | zero-shot | 12.29 | 15.03(xlarge) |

### Translation 任务

下表展示了GPT-2 small、medium、large三种规模的模型在Translation任务中的BLEU得分情况。

| 模型 | dataset | device | eval_type | BLEU | OpenAI |
| :--- | :------ | :------ | :------ | :------ | :------ |
| GPT-2 small | WMT-14 Fr-En | Ascend | zero-shot | 4.49 | 0.7~0.8 |
| GPT-2 medium | WMT-14 Fr-En | Ascend | zero-shot | 7.09 | 2.0~3.0 |
| GPT-2 large | WMT-14 Fr-En | Ascend | zero-shot | 7.97 | 6.5~7.0 |
| GPT-2 small | WMT-14 En-Fr | Ascend | zero-shot | 2.81 | 5(xlarge) |
| GPT-2 medium | WMT-14 En-Fr | Ascend | zero-shot | 3.2 | 5(xlarge) |
| GPT-2 large | WMT-14 En-Fr | Ascend | zero-shot | 3.06 | 5(xlarge) |
# 其他

该模型已在Ascend环境下得到验证。

# ModelZoo主页

[链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CNN & DailyMail train dataset sampler
"""
import os
import sys
import shutil
import argparse
from random import random

from src.utils.tokenization import Tokenizer


def replace_split_word(read_path, output_path, tldr_str="TL;DR:", original_split='\t'):
    """
    replace the original split word between article and summary with the tldr hint
    """
    with open(read_path, "r") as r, open(output_path, "a") as w:
        line = r.readline()
        while line:
            # insert the hint between the article and the reference summary
            article = line[:line.find(original_split)] + ' ' + tldr_str + ' '
            ref = line[line.rfind(original_split) + 1:]
            w.write(article + ref)
            line = r.readline()


def sample(read_path, out_path, threshold=1.0, max_items=0xFFFFFFF):
    """
    randomly keep each line with probability `threshold`, up to `max_items` lines
    """
    cnt = 0
    total_cnt = 0
    with open(read_path, "r") as r, open(out_path, "a") as w:
        line = r.readline()
        while line:
            total_cnt += 1
            if cnt >= max_items:
                break
            if random() > threshold:
                line = r.readline()
                continue
            w.write(line)
            if (cnt + 1) % 3000 == 0:
                print("Now Processed Samples: {}, total: {}".format(cnt, total_cnt))
            cnt += 1
            line = r.readline()


def clip_article(input_path, out_path, hint, max_length):
    """
    clip the article when the sample (article + summary) exceeds max_length tokens
    """
    tokenizer = Tokenizer()
    cnt = 0
    with open(input_path, "r") as r, open(out_path, "a+") as a:
        line = r.readline()
        while line:
            pos = line.rfind(hint)
            article = line[:pos]
            summary = line[pos:]
            if len(tokenizer.encode(line)) > max_length:
                # truncate the article so that article + summary fits into max_length tokens
                l_article = tokenizer.encode(article)[:max_length - len(tokenizer.encode(summary))]
                article = tokenizer.decode(l_article) + " "
            if cnt % 1000 == 0:
                print(article + summary)
                print("==============================")
            cnt += 1
            a.write(article + summary)
            line = r.readline()


def sampler_dataset():
    """
    run the CNN & DailyMail train dataset sampler
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, default="",
                        help="input file path")
    parser.add_argument("--output_path", type=str, default="",
                        help="output file path")
    parser.add_argument("--replace_hint", type=str, default="true")
    parser.add_argument("--sample", type=str, default="true",
                        help="do sample? true or false")
    parser.add_argument("--max_length", type=int, default=1022,
                        help="max seq_length of input_raw_dataset")
    parser.add_argument("--prob", type=float, default=0.25,
                        help="sample rate")
    parser.add_argument("--max_items", type=int, default=10000,
                        help="max number of documents")
    parser.add_argument("--hint", type=str, default="TL;DR:",
                        help="hint text")
    args = parser.parse_args()

    # temp files: one stores the input of every stage, the other stores the intermediate result.
    temp_file_input = sys.path[0] + '/temp_file1_by_sampler_py.txt'
    temp_file_proc = sys.path[0] + '/temp_file2_by_sampler_py.txt'

    read_path = args.input_path
    output_path = args.output_path
    prob = args.prob
    max_items = args.max_items
    hint = args.hint
    max_length = args.max_length
    split_str = '\t'

    shutil.copyfile(read_path, temp_file_input)
    clip_article(temp_file_input, temp_file_proc, hint=split_str, max_length=max_length)
    shutil.copyfile(temp_file_proc, temp_file_input)
    os.remove(temp_file_proc)

    if args.replace_hint.lower() == "true":
        replace_split_word(temp_file_input, temp_file_proc, hint, split_str)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    if args.sample.lower() == "true":
        sample(temp_file_input, temp_file_proc, prob, max_items)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    shutil.copyfile(temp_file_input, output_path)
    os.remove(temp_file_input)


if __name__ == "__main__":
    sampler_dataset()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Read weight using tensorflow.
Read the parameters of the gpt-2 pretrained model from a tensorflow checkpoint
and save them into npy files for mindspore to load.
*this script is based on the gpt-2 model downloaded from openai.*
"""
import argparse

import tensorflow as tf
import numpy as np

from trans_dict import trans_dict_tf


def read_weight(ckpt_path):
    """
    read weight

    Args:
        ckpt_path: the path of the tensorflow checkpoint
    """
    # list the (name, shape) pairs of all variables in the checkpoint
    init_vars = tf.train.list_variables(ckpt_path)
    save_param_num = 0

    for name, _ in init_vars:
        array = tf.train.load_variable(ckpt_path, name)
        # skip the leading 'model/' prefix and replace '/' with '.' so the
        # variable name matches the keys of trans_dict_tf
        name = name[6:].replace(r"/", ".")
        if name not in trans_dict_tf.keys():
            print(name + " is not in this model")
        else:
            # save each parameter as an npy file named after its mindspore parameter name
            np.save(trans_dict_tf[name] + ".npy", array)
            save_param_num = save_param_num + 1

    print("finished!")
    print("save {num} parameters.".format(num=save_param_num))


def main():
    parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
    parser.add_argument("--ckpt_file_path", type=str, default="",
                        help="The tensorflow GPT-2 model checkpoint file path")
    args_opt = parser.parse_args()
    ckpt_path = args_opt.ckpt_file_path
    read_weight(ckpt_path=ckpt_path)


if __name__ == "__main__":
    main()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Save weight using mindspore, to load the parameters of the gpt-2 model from npy files.
npy files should be in the same path as this script; otherwise you should change the paths used below.
"""
import os
import argparse

import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

from trans_dict import trans_dict_tf


def trans_model_parameter(ckpt_name):
    """
    transform model parameters

    Args:
        ckpt_name (str): the name of the transformed checkpoint.
    """
    # find all file names with suffix '.npy' in the current path
    file_names = [name for name in os.listdir() if name.endswith(".npy")]
    new_params_list = []
    for file_name in file_names:
        var_name = file_name[:-4]
        param_dict = {"name": var_name, "data": Tensor(np.load(file_name))}
        if var_name in trans_dict_tf.values():
            new_params_list.append(param_dict)
            print(var_name + " has been saved")
    # load the parameters from npy files and save them as a mindspore checkpoint
    save_checkpoint(new_params_list, ckpt_name)
    print("Finished: the parameters have been saved into a mindspore checkpoint.")


def main():
    parser = argparse.ArgumentParser(description="Save GPT-2 model checkpoint weight")
    parser.add_argument("--output_file_name", type=str, default="",
                        help="The name of the output checkpoint file")
    args_opt = parser.parse_args()
    ckpt_name = args_opt.output_file_name
    trans_model_parameter(ckpt_name=ckpt_name)


if __name__ == "__main__":
    main()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""transform dictionary: maps tensorflow GPT-2 variable names to mindspore parameter names"""
trans_dict_tf = {
    'h0.attn.c_attn.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h0.attn.c_attn.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h0.attn.c_proj.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h0.attn.c_proj.w': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h0.ln_1.b': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h0.ln_1.g': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h0.ln_2.b': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
    'h0.ln_2.g': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
    'h0.mlp.c_fc.b': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
    'h0.mlp.c_fc.w': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
    'h0.mlp.c_proj.b': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
    'h0.mlp.c_proj.w': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
    'h1.attn.c_attn.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h1.attn.c_attn.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h1.attn.c_proj.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h1.attn.c_proj.w': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h1.ln_1.b': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h1.ln_1.g': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h1.ln_2.b': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
    'h1.ln_2.g': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
    'h1.mlp.c_fc.b': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
    'h1.mlp.c_fc.w': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
    'h1.mlp.c_proj.b': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
    'h1.mlp.c_proj.w': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
    'h2.attn.c_attn.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h2.attn.c_attn.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h2.attn.c_proj.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h2.attn.c_proj.w': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h2.ln_1.b': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h2.ln_1.g': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h2.ln_2.b': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
    'h2.ln_2.g': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
    'h2.mlp.c_fc.b': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
    'h2.mlp.c_fc.w': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
    'h2.mlp.c_proj.b': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
    'h2.mlp.c_proj.w': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
    'h3.attn.c_attn.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h3.attn.c_attn.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h3.attn.c_proj.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h3.attn.c_proj.w': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h3.ln_1.b': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h3.ln_1.g': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h3.ln_2.b': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
    'h3.ln_2.g': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
    'h3.mlp.c_fc.b': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
    'h3.mlp.c_fc.w': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
    'h3.mlp.c_proj.b': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
    'h3.mlp.c_proj.w': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
    'h4.attn.c_attn.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h4.attn.c_attn.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h4.attn.c_proj.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h4.attn.c_proj.w': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h4.ln_1.b': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h4.ln_1.g': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h4.ln_2.b': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
    'h4.ln_2.g': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
    'h4.mlp.c_fc.b': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
    'h4.mlp.c_fc.w': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
    'h4.mlp.c_proj.b': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
    'h4.mlp.c_proj.w': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
    'h5.attn.c_attn.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h5.attn.c_attn.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h5.attn.c_proj.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h5.attn.c_proj.w': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h5.ln_1.b': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h5.ln_1.g': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h5.ln_2.b': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
    'h5.ln_2.g': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
    'h5.mlp.c_fc.b': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
    'h5.mlp.c_fc.w': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
    'h5.mlp.c_proj.b': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
    'h5.mlp.c_proj.w': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
    'h6.attn.c_attn.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h6.attn.c_attn.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h6.attn.c_proj.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h6.attn.c_proj.w': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h6.ln_1.b': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h6.ln_1.g': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h6.ln_2.b': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
    'h6.ln_2.g': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
    'h6.mlp.c_fc.b': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
    'h6.mlp.c_fc.w': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
    'h6.mlp.c_proj.b': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
    'h6.mlp.c_proj.w': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
    'h7.attn.c_attn.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
    'h7.attn.c_attn.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
    'h7.attn.c_proj.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
    'h7.attn.c_proj.w': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
    'h7.ln_1.b': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
    'h7.ln_1.g': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
    'h7.ln_2.b': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
| 'h7.ln_2.g': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h7.mlp.c_fc.b': 'gpt2_decoder.layers.7.feedforward.c_fc.bias', | |||||
| 'h7.mlp.c_fc.w': 'gpt2_decoder.layers.7.feedforward.c_fc.weight', | |||||
| 'h7.mlp.c_proj.b': 'gpt2_decoder.layers.7.feedforward.c_proj.bias', | |||||
| 'h7.mlp.c_proj.w': 'gpt2_decoder.layers.7.feedforward.c_proj.weight', | |||||
| 'h8.attn.c_attn.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h8.attn.c_attn.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h8.attn.c_proj.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h8.attn.c_proj.w': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h8.ln_1.b': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h8.ln_1.g': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h8.ln_2.b': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta', | |||||
| 'h8.ln_2.g': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h8.mlp.c_fc.b': 'gpt2_decoder.layers.8.feedforward.c_fc.bias', | |||||
| 'h8.mlp.c_fc.w': 'gpt2_decoder.layers.8.feedforward.c_fc.weight', | |||||
| 'h8.mlp.c_proj.b': 'gpt2_decoder.layers.8.feedforward.c_proj.bias', | |||||
| 'h8.mlp.c_proj.w': 'gpt2_decoder.layers.8.feedforward.c_proj.weight', | |||||
| 'h9.attn.c_attn.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h9.attn.c_attn.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h9.attn.c_proj.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h9.attn.c_proj.w': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h9.ln_1.b': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h9.ln_1.g': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h9.ln_2.b': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta', | |||||
| 'h9.ln_2.g': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h9.mlp.c_fc.b': 'gpt2_decoder.layers.9.feedforward.c_fc.bias', | |||||
| 'h9.mlp.c_fc.w': 'gpt2_decoder.layers.9.feedforward.c_fc.weight', | |||||
| 'h9.mlp.c_proj.b': 'gpt2_decoder.layers.9.feedforward.c_proj.bias', | |||||
| 'h9.mlp.c_proj.w': 'gpt2_decoder.layers.9.feedforward.c_proj.weight', | |||||
| 'h10.attn.c_attn.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h10.attn.c_attn.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h10.attn.c_proj.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h10.attn.c_proj.w': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h10.ln_1.b': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h10.ln_1.g': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h10.ln_2.b': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta', | |||||
| 'h10.ln_2.g': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h10.mlp.c_fc.b': 'gpt2_decoder.layers.10.feedforward.c_fc.bias', | |||||
| 'h10.mlp.c_fc.w': 'gpt2_decoder.layers.10.feedforward.c_fc.weight', | |||||
| 'h10.mlp.c_proj.b': 'gpt2_decoder.layers.10.feedforward.c_proj.bias', | |||||
| 'h10.mlp.c_proj.w': 'gpt2_decoder.layers.10.feedforward.c_proj.weight', | |||||
| 'h11.attn.c_attn.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h11.attn.c_attn.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h11.attn.c_proj.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h11.attn.c_proj.w': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h11.ln_1.b': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h11.ln_1.g': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h11.ln_2.b': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta', | |||||
| 'h11.ln_2.g': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h11.mlp.c_fc.b': 'gpt2_decoder.layers.11.feedforward.c_fc.bias', | |||||
| 'h11.mlp.c_fc.w': 'gpt2_decoder.layers.11.feedforward.c_fc.weight', | |||||
| 'h11.mlp.c_proj.b': 'gpt2_decoder.layers.11.feedforward.c_proj.bias', | |||||
| 'h11.mlp.c_proj.w': 'gpt2_decoder.layers.11.feedforward.c_proj.weight', | |||||
| 'h12.attn.c_attn.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h12.attn.c_attn.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h12.attn.c_proj.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h12.attn.c_proj.w': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h12.ln_1.b': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h12.ln_1.g': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h12.ln_2.b': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta', | |||||
| 'h12.ln_2.g': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h12.mlp.c_fc.b': 'gpt2_decoder.layers.12.feedforward.c_fc.bias', | |||||
| 'h12.mlp.c_fc.w': 'gpt2_decoder.layers.12.feedforward.c_fc.weight', | |||||
| 'h12.mlp.c_proj.b': 'gpt2_decoder.layers.12.feedforward.c_proj.bias', | |||||
| 'h12.mlp.c_proj.w': 'gpt2_decoder.layers.12.feedforward.c_proj.weight', | |||||
| 'h13.attn.c_attn.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h13.attn.c_attn.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h13.attn.c_proj.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h13.attn.c_proj.w': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h13.ln_1.b': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h13.ln_1.g': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h13.ln_2.b': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta', | |||||
| 'h13.ln_2.g': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h13.mlp.c_fc.b': 'gpt2_decoder.layers.13.feedforward.c_fc.bias', | |||||
| 'h13.mlp.c_fc.w': 'gpt2_decoder.layers.13.feedforward.c_fc.weight', | |||||
| 'h13.mlp.c_proj.b': 'gpt2_decoder.layers.13.feedforward.c_proj.bias', | |||||
| 'h13.mlp.c_proj.w': 'gpt2_decoder.layers.13.feedforward.c_proj.weight', | |||||
| 'h14.attn.c_attn.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h14.attn.c_attn.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h14.attn.c_proj.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h14.attn.c_proj.w': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h14.ln_1.b': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h14.ln_1.g': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h14.ln_2.b': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta', | |||||
| 'h14.ln_2.g': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h14.mlp.c_fc.b': 'gpt2_decoder.layers.14.feedforward.c_fc.bias', | |||||
| 'h14.mlp.c_fc.w': 'gpt2_decoder.layers.14.feedforward.c_fc.weight', | |||||
| 'h14.mlp.c_proj.b': 'gpt2_decoder.layers.14.feedforward.c_proj.bias', | |||||
| 'h14.mlp.c_proj.w': 'gpt2_decoder.layers.14.feedforward.c_proj.weight', | |||||
| 'h15.attn.c_attn.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h15.attn.c_attn.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h15.attn.c_proj.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h15.attn.c_proj.w': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h15.ln_1.b': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h15.ln_1.g': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h15.ln_2.b': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta', | |||||
| 'h15.ln_2.g': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h15.mlp.c_fc.b': 'gpt2_decoder.layers.15.feedforward.c_fc.bias', | |||||
| 'h15.mlp.c_fc.w': 'gpt2_decoder.layers.15.feedforward.c_fc.weight', | |||||
| 'h15.mlp.c_proj.b': 'gpt2_decoder.layers.15.feedforward.c_proj.bias', | |||||
| 'h15.mlp.c_proj.w': 'gpt2_decoder.layers.15.feedforward.c_proj.weight', | |||||
| 'h16.attn.c_attn.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h16.attn.c_attn.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h16.attn.c_proj.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h16.attn.c_proj.w': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h16.ln_1.b': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h16.ln_1.g': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h16.ln_2.b': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta', | |||||
| 'h16.ln_2.g': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h16.mlp.c_fc.b': 'gpt2_decoder.layers.16.feedforward.c_fc.bias', | |||||
| 'h16.mlp.c_fc.w': 'gpt2_decoder.layers.16.feedforward.c_fc.weight', | |||||
| 'h16.mlp.c_proj.b': 'gpt2_decoder.layers.16.feedforward.c_proj.bias', | |||||
| 'h16.mlp.c_proj.w': 'gpt2_decoder.layers.16.feedforward.c_proj.weight', | |||||
| 'h17.attn.c_attn.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h17.attn.c_attn.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h17.attn.c_proj.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h17.attn.c_proj.w': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h17.ln_1.b': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h17.ln_1.g': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h17.ln_2.b': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta', | |||||
| 'h17.ln_2.g': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h17.mlp.c_fc.b': 'gpt2_decoder.layers.17.feedforward.c_fc.bias', | |||||
| 'h17.mlp.c_fc.w': 'gpt2_decoder.layers.17.feedforward.c_fc.weight', | |||||
| 'h17.mlp.c_proj.b': 'gpt2_decoder.layers.17.feedforward.c_proj.bias', | |||||
| 'h17.mlp.c_proj.w': 'gpt2_decoder.layers.17.feedforward.c_proj.weight', | |||||
| 'h18.attn.c_attn.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h18.attn.c_attn.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h18.attn.c_proj.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h18.attn.c_proj.w': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h18.ln_1.b': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h18.ln_1.g': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h18.ln_2.b': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta', | |||||
| 'h18.ln_2.g': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h18.mlp.c_fc.b': 'gpt2_decoder.layers.18.feedforward.c_fc.bias', | |||||
| 'h18.mlp.c_fc.w': 'gpt2_decoder.layers.18.feedforward.c_fc.weight', | |||||
| 'h18.mlp.c_proj.b': 'gpt2_decoder.layers.18.feedforward.c_proj.bias', | |||||
| 'h18.mlp.c_proj.w': 'gpt2_decoder.layers.18.feedforward.c_proj.weight', | |||||
| 'h19.attn.c_attn.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h19.attn.c_attn.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h19.attn.c_proj.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h19.attn.c_proj.w': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h19.ln_1.b': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h19.ln_1.g': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h19.ln_2.b': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta', | |||||
| 'h19.ln_2.g': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h19.mlp.c_fc.b': 'gpt2_decoder.layers.19.feedforward.c_fc.bias', | |||||
| 'h19.mlp.c_fc.w': 'gpt2_decoder.layers.19.feedforward.c_fc.weight', | |||||
| 'h19.mlp.c_proj.b': 'gpt2_decoder.layers.19.feedforward.c_proj.bias', | |||||
| 'h19.mlp.c_proj.w': 'gpt2_decoder.layers.19.feedforward.c_proj.weight', | |||||
| 'h20.attn.c_attn.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h20.attn.c_attn.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h20.attn.c_proj.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h20.attn.c_proj.w': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h20.ln_1.b': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h20.ln_1.g': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h20.ln_2.b': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta', | |||||
| 'h20.ln_2.g': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h20.mlp.c_fc.b': 'gpt2_decoder.layers.20.feedforward.c_fc.bias', | |||||
| 'h20.mlp.c_fc.w': 'gpt2_decoder.layers.20.feedforward.c_fc.weight', | |||||
| 'h20.mlp.c_proj.b': 'gpt2_decoder.layers.20.feedforward.c_proj.bias', | |||||
| 'h20.mlp.c_proj.w': 'gpt2_decoder.layers.20.feedforward.c_proj.weight', | |||||
| 'h21.attn.c_attn.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h21.attn.c_attn.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h21.attn.c_proj.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h21.attn.c_proj.w': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h21.ln_1.b': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h21.ln_1.g': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h21.ln_2.b': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta', | |||||
| 'h21.ln_2.g': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h21.mlp.c_fc.b': 'gpt2_decoder.layers.21.feedforward.c_fc.bias', | |||||
| 'h21.mlp.c_fc.w': 'gpt2_decoder.layers.21.feedforward.c_fc.weight', | |||||
| 'h21.mlp.c_proj.b': 'gpt2_decoder.layers.21.feedforward.c_proj.bias', | |||||
| 'h21.mlp.c_proj.w': 'gpt2_decoder.layers.21.feedforward.c_proj.weight', | |||||
| 'h22.attn.c_attn.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h22.attn.c_attn.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h22.attn.c_proj.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h22.attn.c_proj.w': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h22.ln_1.b': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h22.ln_1.g': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h22.ln_2.b': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta', | |||||
| 'h22.ln_2.g': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h22.mlp.c_fc.b': 'gpt2_decoder.layers.22.feedforward.c_fc.bias', | |||||
| 'h22.mlp.c_fc.w': 'gpt2_decoder.layers.22.feedforward.c_fc.weight', | |||||
| 'h22.mlp.c_proj.b': 'gpt2_decoder.layers.22.feedforward.c_proj.bias', | |||||
| 'h22.mlp.c_proj.w': 'gpt2_decoder.layers.22.feedforward.c_proj.weight', | |||||
| 'h23.attn.c_attn.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h23.attn.c_attn.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h23.attn.c_proj.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h23.attn.c_proj.w': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h23.ln_1.b': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h23.ln_1.g': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h23.ln_2.b': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta', | |||||
| 'h23.ln_2.g': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h23.mlp.c_fc.b': 'gpt2_decoder.layers.23.feedforward.c_fc.bias', | |||||
| 'h23.mlp.c_fc.w': 'gpt2_decoder.layers.23.feedforward.c_fc.weight', | |||||
| 'h23.mlp.c_proj.b': 'gpt2_decoder.layers.23.feedforward.c_proj.bias', | |||||
| 'h23.mlp.c_proj.w': 'gpt2_decoder.layers.23.feedforward.c_proj.weight', | |||||
| 'h24.attn.c_attn.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h24.attn.c_attn.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h24.attn.c_proj.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h24.attn.c_proj.w': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h24.ln_1.b': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h24.ln_1.g': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h24.ln_2.b': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta', | |||||
| 'h24.ln_2.g': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h24.mlp.c_fc.b': 'gpt2_decoder.layers.24.feedforward.c_fc.bias', | |||||
| 'h24.mlp.c_fc.w': 'gpt2_decoder.layers.24.feedforward.c_fc.weight', | |||||
| 'h24.mlp.c_proj.b': 'gpt2_decoder.layers.24.feedforward.c_proj.bias', | |||||
| 'h24.mlp.c_proj.w': 'gpt2_decoder.layers.24.feedforward.c_proj.weight', | |||||
| 'h25.attn.c_attn.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h25.attn.c_attn.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h25.attn.c_proj.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h25.attn.c_proj.w': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h25.ln_1.b': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h25.ln_1.g': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h25.ln_2.b': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta', | |||||
| 'h25.ln_2.g': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h25.mlp.c_fc.b': 'gpt2_decoder.layers.25.feedforward.c_fc.bias', | |||||
| 'h25.mlp.c_fc.w': 'gpt2_decoder.layers.25.feedforward.c_fc.weight', | |||||
| 'h25.mlp.c_proj.b': 'gpt2_decoder.layers.25.feedforward.c_proj.bias', | |||||
| 'h25.mlp.c_proj.w': 'gpt2_decoder.layers.25.feedforward.c_proj.weight', | |||||
| 'h26.attn.c_attn.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h26.attn.c_attn.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h26.attn.c_proj.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h26.attn.c_proj.w': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h26.ln_1.b': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h26.ln_1.g': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h26.ln_2.b': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta', | |||||
| 'h26.ln_2.g': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h26.mlp.c_fc.b': 'gpt2_decoder.layers.26.feedforward.c_fc.bias', | |||||
| 'h26.mlp.c_fc.w': 'gpt2_decoder.layers.26.feedforward.c_fc.weight', | |||||
| 'h26.mlp.c_proj.b': 'gpt2_decoder.layers.26.feedforward.c_proj.bias', | |||||
| 'h26.mlp.c_proj.w': 'gpt2_decoder.layers.26.feedforward.c_proj.weight', | |||||
| 'h27.attn.c_attn.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h27.attn.c_attn.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h27.attn.c_proj.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h27.attn.c_proj.w': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h27.ln_1.b': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h27.ln_1.g': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h27.ln_2.b': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta', | |||||
| 'h27.ln_2.g': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h27.mlp.c_fc.b': 'gpt2_decoder.layers.27.feedforward.c_fc.bias', | |||||
| 'h27.mlp.c_fc.w': 'gpt2_decoder.layers.27.feedforward.c_fc.weight', | |||||
| 'h27.mlp.c_proj.b': 'gpt2_decoder.layers.27.feedforward.c_proj.bias', | |||||
| 'h27.mlp.c_proj.w': 'gpt2_decoder.layers.27.feedforward.c_proj.weight', | |||||
| 'h28.attn.c_attn.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h28.attn.c_attn.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h28.attn.c_proj.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h28.attn.c_proj.w': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h28.ln_1.b': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h28.ln_1.g': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h28.ln_2.b': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta', | |||||
| 'h28.ln_2.g': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h28.mlp.c_fc.b': 'gpt2_decoder.layers.28.feedforward.c_fc.bias', | |||||
| 'h28.mlp.c_fc.w': 'gpt2_decoder.layers.28.feedforward.c_fc.weight', | |||||
| 'h28.mlp.c_proj.b': 'gpt2_decoder.layers.28.feedforward.c_proj.bias', | |||||
| 'h28.mlp.c_proj.w': 'gpt2_decoder.layers.28.feedforward.c_proj.weight', | |||||
| 'h29.attn.c_attn.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h29.attn.c_attn.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h29.attn.c_proj.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h29.attn.c_proj.w': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h29.ln_1.b': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h29.ln_1.g': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h29.ln_2.b': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta', | |||||
| 'h29.ln_2.g': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h29.mlp.c_fc.b': 'gpt2_decoder.layers.29.feedforward.c_fc.bias', | |||||
| 'h29.mlp.c_fc.w': 'gpt2_decoder.layers.29.feedforward.c_fc.weight', | |||||
| 'h29.mlp.c_proj.b': 'gpt2_decoder.layers.29.feedforward.c_proj.bias', | |||||
| 'h29.mlp.c_proj.w': 'gpt2_decoder.layers.29.feedforward.c_proj.weight', | |||||
| 'h30.attn.c_attn.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h30.attn.c_attn.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h30.attn.c_proj.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h30.attn.c_proj.w': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h30.ln_1.b': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h30.ln_1.g': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h30.ln_2.b': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta', | |||||
| 'h30.ln_2.g': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h30.mlp.c_fc.b': 'gpt2_decoder.layers.30.feedforward.c_fc.bias', | |||||
| 'h30.mlp.c_fc.w': 'gpt2_decoder.layers.30.feedforward.c_fc.weight', | |||||
| 'h30.mlp.c_proj.b': 'gpt2_decoder.layers.30.feedforward.c_proj.bias', | |||||
| 'h30.mlp.c_proj.w': 'gpt2_decoder.layers.30.feedforward.c_proj.weight', | |||||
| 'h31.attn.c_attn.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h31.attn.c_attn.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h31.attn.c_proj.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h31.attn.c_proj.w': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h31.ln_1.b': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h31.ln_1.g': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h31.ln_2.b': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta', | |||||
| 'h31.ln_2.g': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h31.mlp.c_fc.b': 'gpt2_decoder.layers.31.feedforward.c_fc.bias', | |||||
| 'h31.mlp.c_fc.w': 'gpt2_decoder.layers.31.feedforward.c_fc.weight', | |||||
| 'h31.mlp.c_proj.b': 'gpt2_decoder.layers.31.feedforward.c_proj.bias', | |||||
| 'h31.mlp.c_proj.w': 'gpt2_decoder.layers.31.feedforward.c_proj.weight', | |||||
| 'h32.attn.c_attn.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h32.attn.c_attn.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h32.attn.c_proj.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h32.attn.c_proj.w': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h32.ln_1.b': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h32.ln_1.g': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h32.ln_2.b': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta', | |||||
| 'h32.ln_2.g': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h32.mlp.c_fc.b': 'gpt2_decoder.layers.32.feedforward.c_fc.bias', | |||||
| 'h32.mlp.c_fc.w': 'gpt2_decoder.layers.32.feedforward.c_fc.weight', | |||||
| 'h32.mlp.c_proj.b': 'gpt2_decoder.layers.32.feedforward.c_proj.bias', | |||||
| 'h32.mlp.c_proj.w': 'gpt2_decoder.layers.32.feedforward.c_proj.weight', | |||||
| 'h33.attn.c_attn.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h33.attn.c_attn.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h33.attn.c_proj.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h33.attn.c_proj.w': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h33.ln_1.b': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h33.ln_1.g': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h33.ln_2.b': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta', | |||||
| 'h33.ln_2.g': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h33.mlp.c_fc.b': 'gpt2_decoder.layers.33.feedforward.c_fc.bias', | |||||
| 'h33.mlp.c_fc.w': 'gpt2_decoder.layers.33.feedforward.c_fc.weight', | |||||
| 'h33.mlp.c_proj.b': 'gpt2_decoder.layers.33.feedforward.c_proj.bias', | |||||
| 'h33.mlp.c_proj.w': 'gpt2_decoder.layers.33.feedforward.c_proj.weight', | |||||
| 'h34.attn.c_attn.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h34.attn.c_attn.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h34.attn.c_proj.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h34.attn.c_proj.w': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h34.ln_1.b': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h34.ln_1.g': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h34.ln_2.b': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta', | |||||
| 'h34.ln_2.g': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h34.mlp.c_fc.b': 'gpt2_decoder.layers.34.feedforward.c_fc.bias', | |||||
| 'h34.mlp.c_fc.w': 'gpt2_decoder.layers.34.feedforward.c_fc.weight', | |||||
| 'h34.mlp.c_proj.b': 'gpt2_decoder.layers.34.feedforward.c_proj.bias', | |||||
| 'h34.mlp.c_proj.w': 'gpt2_decoder.layers.34.feedforward.c_proj.weight', | |||||
| 'h35.attn.c_attn.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h35.attn.c_attn.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h35.attn.c_proj.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h35.attn.c_proj.w': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h35.ln_1.b': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h35.ln_1.g': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h35.ln_2.b': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta', | |||||
| 'h35.ln_2.g': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h35.mlp.c_fc.b': 'gpt2_decoder.layers.35.feedforward.c_fc.bias', | |||||
| 'h35.mlp.c_fc.w': 'gpt2_decoder.layers.35.feedforward.c_fc.weight', | |||||
| 'h35.mlp.c_proj.b': 'gpt2_decoder.layers.35.feedforward.c_proj.bias', | |||||
| 'h35.mlp.c_proj.w': 'gpt2_decoder.layers.35.feedforward.c_proj.weight', | |||||
'ln_f.b': 'layer_norm.layer_norm.beta',
'ln_f.g': 'layer_norm.layer_norm.gamma',
'wpe': 'gpt2_embedding_postprocess.position_embedding_table',
'wte': 'gpt2_embedding_lookup.embedding_table'
}  # maps TF checkpoint parameter names to MindSpore parameter names
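
# Both mapping tables follow one fixed pattern per decoder layer: trans_dict
# above uses the TF checkpoint names ('h0.ln_1.b', ...), while trans_dict_py
# below uses the PyTorch-style names ('h.0.ln_1.bias', ...). As a minimal
# sketch (added here for illustration only -- this helper is not part of the
# original conversion script), the per-layer entries can be generated in a loop:
def layer_entries(i, tf_style=True):
    """Return the checkpoint-name -> MindSpore-name pairs for decoder layer i."""
    prefix = f'h{i}.' if tf_style else f'h.{i}.'
    bias, weight = ('b', 'w') if tf_style else ('bias', 'weight')
    gamma = 'g' if tf_style else 'weight'  # suffix of the layer-norm gain
    ms = f'gpt2_decoder.layers.{i}.'
    attn = ms + 'masked_multi_head_attention.masked_self_attention.'
    attn_ln = ms + 'masked_multi_head_attention.layer_norm.layer_norm.'
    ffn = ms + 'feedforward.'
    ffn_ln = ffn + 'layernorm.layer_norm.'
    return {
        prefix + 'attn.c_attn.' + bias: attn + 'c_attn.bias',
        prefix + 'attn.c_attn.' + weight: attn + 'c_attn.weight',
        prefix + 'attn.c_proj.' + bias: attn + 'c_proj.bias',
        prefix + 'attn.c_proj.' + weight: attn + 'c_proj.weight',
        prefix + 'ln_1.' + bias: attn_ln + 'beta',
        prefix + 'ln_1.' + gamma: attn_ln + 'gamma',
        prefix + 'ln_2.' + bias: ffn_ln + 'beta',
        prefix + 'ln_2.' + gamma: ffn_ln + 'gamma',
        prefix + 'mlp.c_fc.' + bias: ffn + 'c_fc.bias',
        prefix + 'mlp.c_fc.' + weight: ffn + 'c_fc.weight',
        prefix + 'mlp.c_proj.' + bias: ffn + 'c_proj.bias',
        prefix + 'mlp.c_proj.' + weight: ffn + 'c_proj.weight',
    }
# Example: {k: v for i in range(36) for k, v in layer_entries(i).items()}
# reproduces the per-layer part of trans_dict for the 36-layer checkpoint.
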
trans_dict_py = {
'h.0.attn.c_attn.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.0.attn.c_attn.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.0.attn.c_proj.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.0.attn.c_proj.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.0.ln_1.bias': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.0.ln_1.weight': 'gpt2_decoder.layers.0.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.0.ln_2.bias': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.beta',
'h.0.ln_2.weight': 'gpt2_decoder.layers.0.feedforward.layernorm.layer_norm.gamma',
'h.0.mlp.c_fc.bias': 'gpt2_decoder.layers.0.feedforward.c_fc.bias',
'h.0.mlp.c_fc.weight': 'gpt2_decoder.layers.0.feedforward.c_fc.weight',
'h.0.mlp.c_proj.bias': 'gpt2_decoder.layers.0.feedforward.c_proj.bias',
'h.0.mlp.c_proj.weight': 'gpt2_decoder.layers.0.feedforward.c_proj.weight',
'h.1.attn.c_attn.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.1.attn.c_attn.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.1.attn.c_proj.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.1.attn.c_proj.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.1.ln_1.bias': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.1.ln_1.weight': 'gpt2_decoder.layers.1.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.1.ln_2.bias': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.beta',
'h.1.ln_2.weight': 'gpt2_decoder.layers.1.feedforward.layernorm.layer_norm.gamma',
'h.1.mlp.c_fc.bias': 'gpt2_decoder.layers.1.feedforward.c_fc.bias',
'h.1.mlp.c_fc.weight': 'gpt2_decoder.layers.1.feedforward.c_fc.weight',
'h.1.mlp.c_proj.bias': 'gpt2_decoder.layers.1.feedforward.c_proj.bias',
'h.1.mlp.c_proj.weight': 'gpt2_decoder.layers.1.feedforward.c_proj.weight',
'h.2.attn.c_attn.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.2.attn.c_attn.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.2.attn.c_proj.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.2.attn.c_proj.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.2.ln_1.bias': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.2.ln_1.weight': 'gpt2_decoder.layers.2.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.2.ln_2.bias': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.beta',
'h.2.ln_2.weight': 'gpt2_decoder.layers.2.feedforward.layernorm.layer_norm.gamma',
'h.2.mlp.c_fc.bias': 'gpt2_decoder.layers.2.feedforward.c_fc.bias',
'h.2.mlp.c_fc.weight': 'gpt2_decoder.layers.2.feedforward.c_fc.weight',
'h.2.mlp.c_proj.bias': 'gpt2_decoder.layers.2.feedforward.c_proj.bias',
'h.2.mlp.c_proj.weight': 'gpt2_decoder.layers.2.feedforward.c_proj.weight',
'h.3.attn.c_attn.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.3.attn.c_attn.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.3.attn.c_proj.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.3.attn.c_proj.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.3.ln_1.bias': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.3.ln_1.weight': 'gpt2_decoder.layers.3.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.3.ln_2.bias': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.beta',
'h.3.ln_2.weight': 'gpt2_decoder.layers.3.feedforward.layernorm.layer_norm.gamma',
'h.3.mlp.c_fc.bias': 'gpt2_decoder.layers.3.feedforward.c_fc.bias',
'h.3.mlp.c_fc.weight': 'gpt2_decoder.layers.3.feedforward.c_fc.weight',
'h.3.mlp.c_proj.bias': 'gpt2_decoder.layers.3.feedforward.c_proj.bias',
'h.3.mlp.c_proj.weight': 'gpt2_decoder.layers.3.feedforward.c_proj.weight',
'h.4.attn.c_attn.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.4.attn.c_attn.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.4.attn.c_proj.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.4.attn.c_proj.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.4.ln_1.bias': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.4.ln_1.weight': 'gpt2_decoder.layers.4.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.4.ln_2.bias': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.beta',
'h.4.ln_2.weight': 'gpt2_decoder.layers.4.feedforward.layernorm.layer_norm.gamma',
'h.4.mlp.c_fc.bias': 'gpt2_decoder.layers.4.feedforward.c_fc.bias',
'h.4.mlp.c_fc.weight': 'gpt2_decoder.layers.4.feedforward.c_fc.weight',
'h.4.mlp.c_proj.bias': 'gpt2_decoder.layers.4.feedforward.c_proj.bias',
'h.4.mlp.c_proj.weight': 'gpt2_decoder.layers.4.feedforward.c_proj.weight',
'h.5.attn.c_attn.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.5.attn.c_attn.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.5.attn.c_proj.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.5.attn.c_proj.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.5.ln_1.bias': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.5.ln_1.weight': 'gpt2_decoder.layers.5.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.5.ln_2.bias': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.beta',
'h.5.ln_2.weight': 'gpt2_decoder.layers.5.feedforward.layernorm.layer_norm.gamma',
'h.5.mlp.c_fc.bias': 'gpt2_decoder.layers.5.feedforward.c_fc.bias',
'h.5.mlp.c_fc.weight': 'gpt2_decoder.layers.5.feedforward.c_fc.weight',
'h.5.mlp.c_proj.bias': 'gpt2_decoder.layers.5.feedforward.c_proj.bias',
'h.5.mlp.c_proj.weight': 'gpt2_decoder.layers.5.feedforward.c_proj.weight',
'h.6.attn.c_attn.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.6.attn.c_attn.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.6.attn.c_proj.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.6.attn.c_proj.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.6.ln_1.bias': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.6.ln_1.weight': 'gpt2_decoder.layers.6.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.6.ln_2.bias': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.beta',
'h.6.ln_2.weight': 'gpt2_decoder.layers.6.feedforward.layernorm.layer_norm.gamma',
'h.6.mlp.c_fc.bias': 'gpt2_decoder.layers.6.feedforward.c_fc.bias',
'h.6.mlp.c_fc.weight': 'gpt2_decoder.layers.6.feedforward.c_fc.weight',
'h.6.mlp.c_proj.bias': 'gpt2_decoder.layers.6.feedforward.c_proj.bias',
'h.6.mlp.c_proj.weight': 'gpt2_decoder.layers.6.feedforward.c_proj.weight',
'h.7.attn.c_attn.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.7.attn.c_attn.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.7.attn.c_proj.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.7.attn.c_proj.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.7.ln_1.bias': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.7.ln_1.weight': 'gpt2_decoder.layers.7.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.7.ln_2.bias': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.beta',
'h.7.ln_2.weight': 'gpt2_decoder.layers.7.feedforward.layernorm.layer_norm.gamma',
'h.7.mlp.c_fc.bias': 'gpt2_decoder.layers.7.feedforward.c_fc.bias',
'h.7.mlp.c_fc.weight': 'gpt2_decoder.layers.7.feedforward.c_fc.weight',
'h.7.mlp.c_proj.bias': 'gpt2_decoder.layers.7.feedforward.c_proj.bias',
'h.7.mlp.c_proj.weight': 'gpt2_decoder.layers.7.feedforward.c_proj.weight',
'h.8.attn.c_attn.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.8.attn.c_attn.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.8.attn.c_proj.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.8.attn.c_proj.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.8.ln_1.bias': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.8.ln_1.weight': 'gpt2_decoder.layers.8.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.8.ln_2.bias': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.beta',
'h.8.ln_2.weight': 'gpt2_decoder.layers.8.feedforward.layernorm.layer_norm.gamma',
'h.8.mlp.c_fc.bias': 'gpt2_decoder.layers.8.feedforward.c_fc.bias',
'h.8.mlp.c_fc.weight': 'gpt2_decoder.layers.8.feedforward.c_fc.weight',
'h.8.mlp.c_proj.bias': 'gpt2_decoder.layers.8.feedforward.c_proj.bias',
'h.8.mlp.c_proj.weight': 'gpt2_decoder.layers.8.feedforward.c_proj.weight',
'h.9.attn.c_attn.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.9.attn.c_attn.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.9.attn.c_proj.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.9.attn.c_proj.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.9.ln_1.bias': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.9.ln_1.weight': 'gpt2_decoder.layers.9.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.9.ln_2.bias': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.beta',
'h.9.ln_2.weight': 'gpt2_decoder.layers.9.feedforward.layernorm.layer_norm.gamma',
'h.9.mlp.c_fc.bias': 'gpt2_decoder.layers.9.feedforward.c_fc.bias',
'h.9.mlp.c_fc.weight': 'gpt2_decoder.layers.9.feedforward.c_fc.weight',
'h.9.mlp.c_proj.bias': 'gpt2_decoder.layers.9.feedforward.c_proj.bias',
'h.9.mlp.c_proj.weight': 'gpt2_decoder.layers.9.feedforward.c_proj.weight',
'h.10.attn.c_attn.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.10.attn.c_attn.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.10.attn.c_proj.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.10.attn.c_proj.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.10.ln_1.bias': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.10.ln_1.weight': 'gpt2_decoder.layers.10.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.10.ln_2.bias': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.beta',
'h.10.ln_2.weight': 'gpt2_decoder.layers.10.feedforward.layernorm.layer_norm.gamma',
'h.10.mlp.c_fc.bias': 'gpt2_decoder.layers.10.feedforward.c_fc.bias',
'h.10.mlp.c_fc.weight': 'gpt2_decoder.layers.10.feedforward.c_fc.weight',
'h.10.mlp.c_proj.bias': 'gpt2_decoder.layers.10.feedforward.c_proj.bias',
'h.10.mlp.c_proj.weight': 'gpt2_decoder.layers.10.feedforward.c_proj.weight',
'h.11.attn.c_attn.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.11.attn.c_attn.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.11.attn.c_proj.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.11.attn.c_proj.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.11.ln_1.bias': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.11.ln_1.weight': 'gpt2_decoder.layers.11.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.11.ln_2.bias': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.beta',
'h.11.ln_2.weight': 'gpt2_decoder.layers.11.feedforward.layernorm.layer_norm.gamma',
'h.11.mlp.c_fc.bias': 'gpt2_decoder.layers.11.feedforward.c_fc.bias',
'h.11.mlp.c_fc.weight': 'gpt2_decoder.layers.11.feedforward.c_fc.weight',
'h.11.mlp.c_proj.bias': 'gpt2_decoder.layers.11.feedforward.c_proj.bias',
'h.11.mlp.c_proj.weight': 'gpt2_decoder.layers.11.feedforward.c_proj.weight',
'h.12.attn.c_attn.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.12.attn.c_attn.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.12.attn.c_proj.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.12.attn.c_proj.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.12.ln_1.bias': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.12.ln_1.weight': 'gpt2_decoder.layers.12.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.12.ln_2.bias': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.beta',
'h.12.ln_2.weight': 'gpt2_decoder.layers.12.feedforward.layernorm.layer_norm.gamma',
'h.12.mlp.c_fc.bias': 'gpt2_decoder.layers.12.feedforward.c_fc.bias',
'h.12.mlp.c_fc.weight': 'gpt2_decoder.layers.12.feedforward.c_fc.weight',
'h.12.mlp.c_proj.bias': 'gpt2_decoder.layers.12.feedforward.c_proj.bias',
'h.12.mlp.c_proj.weight': 'gpt2_decoder.layers.12.feedforward.c_proj.weight',
'h.13.attn.c_attn.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.13.attn.c_attn.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.13.attn.c_proj.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.13.attn.c_proj.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.13.ln_1.bias': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.13.ln_1.weight': 'gpt2_decoder.layers.13.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.13.ln_2.bias': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.beta',
'h.13.ln_2.weight': 'gpt2_decoder.layers.13.feedforward.layernorm.layer_norm.gamma',
'h.13.mlp.c_fc.bias': 'gpt2_decoder.layers.13.feedforward.c_fc.bias',
'h.13.mlp.c_fc.weight': 'gpt2_decoder.layers.13.feedforward.c_fc.weight',
'h.13.mlp.c_proj.bias': 'gpt2_decoder.layers.13.feedforward.c_proj.bias',
'h.13.mlp.c_proj.weight': 'gpt2_decoder.layers.13.feedforward.c_proj.weight',
'h.14.attn.c_attn.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.14.attn.c_attn.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.14.attn.c_proj.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.14.attn.c_proj.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.14.ln_1.bias': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.14.ln_1.weight': 'gpt2_decoder.layers.14.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.14.ln_2.bias': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.beta',
'h.14.ln_2.weight': 'gpt2_decoder.layers.14.feedforward.layernorm.layer_norm.gamma',
'h.14.mlp.c_fc.bias': 'gpt2_decoder.layers.14.feedforward.c_fc.bias',
'h.14.mlp.c_fc.weight': 'gpt2_decoder.layers.14.feedforward.c_fc.weight',
'h.14.mlp.c_proj.bias': 'gpt2_decoder.layers.14.feedforward.c_proj.bias',
'h.14.mlp.c_proj.weight': 'gpt2_decoder.layers.14.feedforward.c_proj.weight',
'h.15.attn.c_attn.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.15.attn.c_attn.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.15.attn.c_proj.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.15.attn.c_proj.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.15.ln_1.bias': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.15.ln_1.weight': 'gpt2_decoder.layers.15.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.15.ln_2.bias': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.beta',
'h.15.ln_2.weight': 'gpt2_decoder.layers.15.feedforward.layernorm.layer_norm.gamma',
'h.15.mlp.c_fc.bias': 'gpt2_decoder.layers.15.feedforward.c_fc.bias',
'h.15.mlp.c_fc.weight': 'gpt2_decoder.layers.15.feedforward.c_fc.weight',
'h.15.mlp.c_proj.bias': 'gpt2_decoder.layers.15.feedforward.c_proj.bias',
'h.15.mlp.c_proj.weight': 'gpt2_decoder.layers.15.feedforward.c_proj.weight',
'h.16.attn.c_attn.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.16.attn.c_attn.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.16.attn.c_proj.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.16.attn.c_proj.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.16.ln_1.bias': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.16.ln_1.weight': 'gpt2_decoder.layers.16.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.16.ln_2.bias': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.beta',
'h.16.ln_2.weight': 'gpt2_decoder.layers.16.feedforward.layernorm.layer_norm.gamma',
'h.16.mlp.c_fc.bias': 'gpt2_decoder.layers.16.feedforward.c_fc.bias',
'h.16.mlp.c_fc.weight': 'gpt2_decoder.layers.16.feedforward.c_fc.weight',
'h.16.mlp.c_proj.bias': 'gpt2_decoder.layers.16.feedforward.c_proj.bias',
'h.16.mlp.c_proj.weight': 'gpt2_decoder.layers.16.feedforward.c_proj.weight',
'h.17.attn.c_attn.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.17.attn.c_attn.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.17.attn.c_proj.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.17.attn.c_proj.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.17.ln_1.bias': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.17.ln_1.weight': 'gpt2_decoder.layers.17.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.17.ln_2.bias': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.beta',
'h.17.ln_2.weight': 'gpt2_decoder.layers.17.feedforward.layernorm.layer_norm.gamma',
'h.17.mlp.c_fc.bias': 'gpt2_decoder.layers.17.feedforward.c_fc.bias',
'h.17.mlp.c_fc.weight': 'gpt2_decoder.layers.17.feedforward.c_fc.weight',
'h.17.mlp.c_proj.bias': 'gpt2_decoder.layers.17.feedforward.c_proj.bias',
'h.17.mlp.c_proj.weight': 'gpt2_decoder.layers.17.feedforward.c_proj.weight',
'h.18.attn.c_attn.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.bias',
'h.18.attn.c_attn.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_attn.weight',
'h.18.attn.c_proj.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.bias',
'h.18.attn.c_proj.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.masked_self_attention.c_proj.weight',
'h.18.ln_1.bias': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.beta',
'h.18.ln_1.weight': 'gpt2_decoder.layers.18.masked_multi_head_attention.layer_norm.layer_norm.gamma',
'h.18.ln_2.bias': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.beta',
'h.18.ln_2.weight': 'gpt2_decoder.layers.18.feedforward.layernorm.layer_norm.gamma',
'h.18.mlp.c_fc.bias': 'gpt2_decoder.layers.18.feedforward.c_fc.bias',
| 'h.18.mlp.c_fc.weight': 'gpt2_decoder.layers.18.feedforward.c_fc.weight', | |||||
| 'h.18.mlp.c_proj.bias': 'gpt2_decoder.layers.18.feedforward.c_proj.bias', | |||||
| 'h.18.mlp.c_proj.weight': 'gpt2_decoder.layers.18.feedforward.c_proj.weight', | |||||
| 'h.19.attn.c_attn.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.19.attn.c_attn.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.19.attn.c_proj.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.19.attn.c_proj.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.19.ln_1.bias': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.19.ln_1.weight': 'gpt2_decoder.layers.19.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.19.ln_2.bias': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.19.ln_2.weight': 'gpt2_decoder.layers.19.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.19.mlp.c_fc.bias': 'gpt2_decoder.layers.19.feedforward.c_fc.bias', | |||||
| 'h.19.mlp.c_fc.weight': 'gpt2_decoder.layers.19.feedforward.c_fc.weight', | |||||
| 'h.19.mlp.c_proj.bias': 'gpt2_decoder.layers.19.feedforward.c_proj.bias', | |||||
| 'h.19.mlp.c_proj.weight': 'gpt2_decoder.layers.19.feedforward.c_proj.weight', | |||||
| 'h.20.attn.c_attn.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.20.attn.c_attn.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.20.attn.c_proj.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.20.attn.c_proj.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.20.ln_1.bias': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.20.ln_1.weight': 'gpt2_decoder.layers.20.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.20.ln_2.bias': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.20.ln_2.weight': 'gpt2_decoder.layers.20.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.20.mlp.c_fc.bias': 'gpt2_decoder.layers.20.feedforward.c_fc.bias', | |||||
| 'h.20.mlp.c_fc.weight': 'gpt2_decoder.layers.20.feedforward.c_fc.weight', | |||||
| 'h.20.mlp.c_proj.bias': 'gpt2_decoder.layers.20.feedforward.c_proj.bias', | |||||
| 'h.20.mlp.c_proj.weight': 'gpt2_decoder.layers.20.feedforward.c_proj.weight', | |||||
| 'h.21.attn.c_attn.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.21.attn.c_attn.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.21.attn.c_proj.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.21.attn.c_proj.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.21.ln_1.bias': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.21.ln_1.weight': 'gpt2_decoder.layers.21.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.21.ln_2.bias': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.21.ln_2.weight': 'gpt2_decoder.layers.21.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.21.mlp.c_fc.bias': 'gpt2_decoder.layers.21.feedforward.c_fc.bias', | |||||
| 'h.21.mlp.c_fc.weight': 'gpt2_decoder.layers.21.feedforward.c_fc.weight', | |||||
| 'h.21.mlp.c_proj.bias': 'gpt2_decoder.layers.21.feedforward.c_proj.bias', | |||||
| 'h.21.mlp.c_proj.weight': 'gpt2_decoder.layers.21.feedforward.c_proj.weight', | |||||
| 'h.22.attn.c_attn.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.22.attn.c_attn.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.22.attn.c_proj.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.22.attn.c_proj.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.22.ln_1.bias': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.22.ln_1.weight': 'gpt2_decoder.layers.22.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.22.ln_2.bias': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.22.ln_2.weight': 'gpt2_decoder.layers.22.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.22.mlp.c_fc.bias': 'gpt2_decoder.layers.22.feedforward.c_fc.bias', | |||||
| 'h.22.mlp.c_fc.weight': 'gpt2_decoder.layers.22.feedforward.c_fc.weight', | |||||
| 'h.22.mlp.c_proj.bias': 'gpt2_decoder.layers.22.feedforward.c_proj.bias', | |||||
| 'h.22.mlp.c_proj.weight': 'gpt2_decoder.layers.22.feedforward.c_proj.weight', | |||||
| 'h.23.attn.c_attn.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.23.attn.c_attn.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.23.attn.c_proj.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.23.attn.c_proj.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.23.ln_1.bias': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.23.ln_1.weight': 'gpt2_decoder.layers.23.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.23.ln_2.bias': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.23.ln_2.weight': 'gpt2_decoder.layers.23.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.23.mlp.c_fc.bias': 'gpt2_decoder.layers.23.feedforward.c_fc.bias', | |||||
| 'h.23.mlp.c_fc.weight': 'gpt2_decoder.layers.23.feedforward.c_fc.weight', | |||||
| 'h.23.mlp.c_proj.bias': 'gpt2_decoder.layers.23.feedforward.c_proj.bias', | |||||
| 'h.23.mlp.c_proj.weight': 'gpt2_decoder.layers.23.feedforward.c_proj.weight', | |||||
| 'h.24.attn.c_attn.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.24.attn.c_attn.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.24.attn.c_proj.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.24.attn.c_proj.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.24.ln_1.bias': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.24.ln_1.weight': 'gpt2_decoder.layers.24.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.24.ln_2.bias': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.24.ln_2.weight': 'gpt2_decoder.layers.24.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.24.mlp.c_fc.bias': 'gpt2_decoder.layers.24.feedforward.c_fc.bias', | |||||
| 'h.24.mlp.c_fc.weight': 'gpt2_decoder.layers.24.feedforward.c_fc.weight', | |||||
| 'h.24.mlp.c_proj.bias': 'gpt2_decoder.layers.24.feedforward.c_proj.bias', | |||||
| 'h.24.mlp.c_proj.weight': 'gpt2_decoder.layers.24.feedforward.c_proj.weight', | |||||
| 'h.25.attn.c_attn.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.25.attn.c_attn.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.25.attn.c_proj.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.25.attn.c_proj.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.25.ln_1.bias': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.25.ln_1.weight': 'gpt2_decoder.layers.25.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.25.ln_2.bias': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.25.ln_2.weight': 'gpt2_decoder.layers.25.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.25.mlp.c_fc.bias': 'gpt2_decoder.layers.25.feedforward.c_fc.bias', | |||||
| 'h.25.mlp.c_fc.weight': 'gpt2_decoder.layers.25.feedforward.c_fc.weight', | |||||
| 'h.25.mlp.c_proj.bias': 'gpt2_decoder.layers.25.feedforward.c_proj.bias', | |||||
| 'h.25.mlp.c_proj.weight': 'gpt2_decoder.layers.25.feedforward.c_proj.weight', | |||||
| 'h.26.attn.c_attn.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.26.attn.c_attn.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.26.attn.c_proj.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.26.attn.c_proj.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.26.ln_1.bias': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.26.ln_1.weight': 'gpt2_decoder.layers.26.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.26.ln_2.bias': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.26.ln_2.weight': 'gpt2_decoder.layers.26.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.26.mlp.c_fc.bias': 'gpt2_decoder.layers.26.feedforward.c_fc.bias', | |||||
| 'h.26.mlp.c_fc.weight': 'gpt2_decoder.layers.26.feedforward.c_fc.weight', | |||||
| 'h.26.mlp.c_proj.bias': 'gpt2_decoder.layers.26.feedforward.c_proj.bias', | |||||
| 'h.26.mlp.c_proj.weight': 'gpt2_decoder.layers.26.feedforward.c_proj.weight', | |||||
| 'h.27.attn.c_attn.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.27.attn.c_attn.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.27.attn.c_proj.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.27.attn.c_proj.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.27.ln_1.bias': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.27.ln_1.weight': 'gpt2_decoder.layers.27.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.27.ln_2.bias': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.27.ln_2.weight': 'gpt2_decoder.layers.27.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.27.mlp.c_fc.bias': 'gpt2_decoder.layers.27.feedforward.c_fc.bias', | |||||
| 'h.27.mlp.c_fc.weight': 'gpt2_decoder.layers.27.feedforward.c_fc.weight', | |||||
| 'h.27.mlp.c_proj.bias': 'gpt2_decoder.layers.27.feedforward.c_proj.bias', | |||||
| 'h.27.mlp.c_proj.weight': 'gpt2_decoder.layers.27.feedforward.c_proj.weight', | |||||
| 'h.28.attn.c_attn.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.28.attn.c_attn.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.28.attn.c_proj.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.28.attn.c_proj.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.28.ln_1.bias': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.28.ln_1.weight': 'gpt2_decoder.layers.28.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.28.ln_2.bias': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.28.ln_2.weight': 'gpt2_decoder.layers.28.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.28.mlp.c_fc.bias': 'gpt2_decoder.layers.28.feedforward.c_fc.bias', | |||||
| 'h.28.mlp.c_fc.weight': 'gpt2_decoder.layers.28.feedforward.c_fc.weight', | |||||
| 'h.28.mlp.c_proj.bias': 'gpt2_decoder.layers.28.feedforward.c_proj.bias', | |||||
| 'h.28.mlp.c_proj.weight': 'gpt2_decoder.layers.28.feedforward.c_proj.weight', | |||||
| 'h.29.attn.c_attn.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.29.attn.c_attn.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.29.attn.c_proj.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.29.attn.c_proj.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.29.ln_1.bias': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.29.ln_1.weight': 'gpt2_decoder.layers.29.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.29.ln_2.bias': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.29.ln_2.weight': 'gpt2_decoder.layers.29.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.29.mlp.c_fc.bias': 'gpt2_decoder.layers.29.feedforward.c_fc.bias', | |||||
| 'h.29.mlp.c_fc.weight': 'gpt2_decoder.layers.29.feedforward.c_fc.weight', | |||||
| 'h.29.mlp.c_proj.bias': 'gpt2_decoder.layers.29.feedforward.c_proj.bias', | |||||
| 'h.29.mlp.c_proj.weight': 'gpt2_decoder.layers.29.feedforward.c_proj.weight', | |||||
| 'h.30.attn.c_attn.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.30.attn.c_attn.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.30.attn.c_proj.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.30.attn.c_proj.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.30.ln_1.bias': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.30.ln_1.weight': 'gpt2_decoder.layers.30.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.30.ln_2.bias': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.30.ln_2.weight': 'gpt2_decoder.layers.30.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.30.mlp.c_fc.bias': 'gpt2_decoder.layers.30.feedforward.c_fc.bias', | |||||
| 'h.30.mlp.c_fc.weight': 'gpt2_decoder.layers.30.feedforward.c_fc.weight', | |||||
| 'h.30.mlp.c_proj.bias': 'gpt2_decoder.layers.30.feedforward.c_proj.bias', | |||||
| 'h.30.mlp.c_proj.weight': 'gpt2_decoder.layers.30.feedforward.c_proj.weight', | |||||
| 'h.31.attn.c_attn.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.31.attn.c_attn.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.31.attn.c_proj.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.31.attn.c_proj.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.31.ln_1.bias': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.31.ln_1.weight': 'gpt2_decoder.layers.31.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.31.ln_2.bias': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.31.ln_2.weight': 'gpt2_decoder.layers.31.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.31.mlp.c_fc.bias': 'gpt2_decoder.layers.31.feedforward.c_fc.bias', | |||||
| 'h.31.mlp.c_fc.weight': 'gpt2_decoder.layers.31.feedforward.c_fc.weight', | |||||
| 'h.31.mlp.c_proj.bias': 'gpt2_decoder.layers.31.feedforward.c_proj.bias', | |||||
| 'h.31.mlp.c_proj.weight': 'gpt2_decoder.layers.31.feedforward.c_proj.weight', | |||||
| 'h.32.attn.c_attn.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.32.attn.c_attn.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.32.attn.c_proj.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.32.attn.c_proj.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.32.ln_1.bias': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.32.ln_1.weight': 'gpt2_decoder.layers.32.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.32.ln_2.bias': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.32.ln_2.weight': 'gpt2_decoder.layers.32.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.32.mlp.c_fc.bias': 'gpt2_decoder.layers.32.feedforward.c_fc.bias', | |||||
| 'h.32.mlp.c_fc.weight': 'gpt2_decoder.layers.32.feedforward.c_fc.weight', | |||||
| 'h.32.mlp.c_proj.bias': 'gpt2_decoder.layers.32.feedforward.c_proj.bias', | |||||
| 'h.32.mlp.c_proj.weight': 'gpt2_decoder.layers.32.feedforward.c_proj.weight', | |||||
| 'h.33.attn.c_attn.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.33.attn.c_attn.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.33.attn.c_proj.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.33.attn.c_proj.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.33.ln_1.bias': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.33.ln_1.weight': 'gpt2_decoder.layers.33.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.33.ln_2.bias': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.33.ln_2.weight': 'gpt2_decoder.layers.33.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.33.mlp.c_fc.bias': 'gpt2_decoder.layers.33.feedforward.c_fc.bias', | |||||
| 'h.33.mlp.c_fc.weight': 'gpt2_decoder.layers.33.feedforward.c_fc.weight', | |||||
| 'h.33.mlp.c_proj.bias': 'gpt2_decoder.layers.33.feedforward.c_proj.bias', | |||||
| 'h.33.mlp.c_proj.weight': 'gpt2_decoder.layers.33.feedforward.c_proj.weight', | |||||
| 'h.34.attn.c_attn.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.34.attn.c_attn.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.34.attn.c_proj.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.34.attn.c_proj.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.34.ln_1.bias': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.34.ln_1.weight': 'gpt2_decoder.layers.34.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.34.ln_2.bias': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.34.ln_2.weight': 'gpt2_decoder.layers.34.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.34.mlp.c_fc.bias': 'gpt2_decoder.layers.34.feedforward.c_fc.bias', | |||||
| 'h.34.mlp.c_fc.weight': 'gpt2_decoder.layers.34.feedforward.c_fc.weight', | |||||
| 'h.34.mlp.c_proj.bias': 'gpt2_decoder.layers.34.feedforward.c_proj.bias', | |||||
| 'h.34.mlp.c_proj.weight': 'gpt2_decoder.layers.34.feedforward.c_proj.weight', | |||||
| 'h.35.attn.c_attn.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.bias', | |||||
| 'h.35.attn.c_attn.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_attn.weight', | |||||
| 'h.35.attn.c_proj.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.bias', | |||||
| 'h.35.attn.c_proj.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.masked_self_attention.c_proj.weight', | |||||
| 'h.35.ln_1.bias': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.beta', | |||||
| 'h.35.ln_1.weight': 'gpt2_decoder.layers.35.masked_multi_head_attention.layer_norm.layer_norm.gamma', | |||||
| 'h.35.ln_2.bias': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.beta', | |||||
| 'h.35.ln_2.weight': 'gpt2_decoder.layers.35.feedforward.layernorm.layer_norm.gamma', | |||||
| 'h.35.mlp.c_fc.bias': 'gpt2_decoder.layers.35.feedforward.c_fc.bias', | |||||
| 'h.35.mlp.c_fc.weight': 'gpt2_decoder.layers.35.feedforward.c_fc.weight', | |||||
| 'h.35.mlp.c_proj.bias': 'gpt2_decoder.layers.35.feedforward.c_proj.bias', | |||||
| 'h.35.mlp.c_proj.weight': 'gpt2_decoder.layers.35.feedforward.c_proj.weight', | |||||
    'ln_f.bias': 'layer_norm.layer_norm.beta',
    'ln_f.weight': 'layer_norm.layer_norm.gamma',
    'wpe.weight': 'gpt2_embedding_postprocess.position_embedding_table',
    'wte.weight': 'gpt2_embedding_lookup.embedding_table'
}
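
# Note on the mapping convention above: MindSpore LayerNorm stores its scale as
# `gamma` and its shift as `beta`, so every LayerNorm `weight` maps to `gamma`
# and every LayerNorm `bias` maps to `beta`, while the Conv1D-style
# c_attn/c_fc/c_proj parameters keep their weight/bias names.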
@@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for Children's Book Test task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging

import numpy as np
from mindspore.mindrecord import FileWriter

from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None, num_choice=None):
    """A single sample instance for the CBT task.

    Each input line is expected to be tab-separated as
    "context sentences<TAB>query sentence<TAB>gold answer index".
    """
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))
    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    # length = [context length + 1, full input length + 1]; the +1 counts <bos>
    output["length"] = [context_length + 1] + [input_length + 1]

    gold_answer_id = int(sentence[2])
    # check the gold answer index against the configured number of choices
    assert gold_answer_id < num_choice
    output["mc_labels"] = gold_answer_id

    for name, value in output.items():
        print(name)
        print(value)
        print("==================================")
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list
    mc_labels = instance["mc_labels"]

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)
    features["mc_labels"] = mc_labels

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='Number of partitions the MindRecord file will be split into.')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--num_choice", type=int, required=True, help='Number of answer choices.')
    parser.add_argument("--vocab_file", type=str, required=True, help='Path of gpt2-vocab.json.')
    parser.add_argument("--merge_file", type=str, required=True, help='Path of gpt2-merges.txt.')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
    num_choice = args.num_choice

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   "mc_labels": {"type": "int64"}
                   }
    writer.add_schema(data_schema, "cbt-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length, num_choice)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
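
# Illustrative invocation (the script and data file names are hypothetical;
# the flags are the ones defined in main() above):
#   python create_cbt_data.py --input_file=cbt_test.txt --output_file=cbt_test.mindrecord \
#       --max_seq_length=1024 --num_choice=10 --vocab_file=gpt2-vocab.json --merge_file=gpt2-merges.txt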
@@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for LAMBADA task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging

import numpy as np
from mindspore.mindrecord import FileWriter

from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the LAMBADA task."""
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))
    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    # input_length counts <bos> plus the text tokens; <eos> is not included
    output["length"] = [context_length + 1] + [input_length + 1]

    for k, v in output.items():
        print(k)
        print(v)
        print("==================================")
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='Number of partitions the MindRecord file will be split into.')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--vocab_file", type=str, required=True, help='Path of gpt2-vocab.json.')
    parser.add_argument("--merge_file", type=str, required=True, help='Path of gpt2-merges.txt.')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   }
    writer.add_schema(data_schema, "lambada-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
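
# Input format note (inferred from create_instance above): each line is expected
# to hold "context<TAB>target word"; "length" then records
# [context length + 1, full length + 1], where the +1 counts the prepended <bos>.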
@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for LM task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging

import numpy as np
from mindspore.mindrecord import FileWriter

from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for LM task."""
    sentence = text.strip().split("\t")
    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="LEFT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='Number of partitions the MindRecord file will be split into.')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--vocab_file", type=str, required=True, help='Path of gpt2-vocab.json.')
    parser.add_argument("--merge_file", type=str, required=True, help='Path of gpt2-merges.txt.')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "lm-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
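
# Input format note: each line holds one raw-text passage; if a second
# tab-separated field is present it is tokenized as pair_ids, and over-long
# inputs are truncated from the LEFT (see create_instance above).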
@@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create mindrecord data for Summarization task"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging

import numpy as np
from mindspore.mindrecord import FileWriter

from src.utils import tokenization


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for Summarization task."""
    sentence = text.strip().split("\t")
    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])
    if len(sentence) >= 3:
        # if the article itself contained tab characters, re-join every field
        # except the last one (the summary) back into a single article string
        article = sentence[0]
        for i in range(1, len(sentence) - 1):
            article += sentence[i]
        ids = tokenizer.encode(article)
        pair_ids = tokenizer.encode(sentence[-1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='Number of partitions the MindRecord file will be split into.')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--vocab_file", type=str, required=True, help='Path of gpt2-vocab.json.')
    parser.add_argument("--merge_file", type=str, required=True, help='Path of gpt2-merges.txt.')
    parser.add_argument("--mode", type=str, required=True,
                        help='Dataset mode passed to the tokenizer, e.g. cnn_dailymail.')
    args = parser.parse_args()

    tokenizer = tokenization.Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file, mode=args.mode)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "summarization-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))
                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
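
# The expected input matches what the CNN & DailyMail download script below
# writes out: one "article<TAB>highlights" pair per line.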
@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""download the CNN & DailyMail dataset for the Summarization task"""
import argparse

from datasets import load_dataset


def generate_txt(url, split_, number=None, version="3.0.0"):
    """
    generate a txt file of the cnn_dailymail dataset

    Args:
        url (str): directory of the dataset txt file.
        split_ (str): test or train.
        number (int): top-n number of samples taken from the dataset; -1 means all.
        version (str): "3.0.0" by default.
    """
    cnn = load_dataset("cnn_dailymail", version, split=split_)
    if number == -1:
        number = len(cnn)

    with open(url + split_ + '.txt', 'w') as f:
        for idx in range(number):
            article = cnn[idx]['article'].replace('\n', ' ')
            highlights = cnn[idx]['highlights'].replace('\n', ' ')
            f.write(article + "\t" + highlights + '\n')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download CNN_Dailymail 3.0.0 using datasets by Huggingface')
    parser.add_argument('--dir', type=str, default="", help="Directory for the generated dataset file.")
    parser.add_argument('--split', type=str, default='test', help="Dataset split, one of [test, train].")
    parser.add_argument('--num', type=int, default=-1,
                        help="Number of samples, taken in the default order. "
                             "If num is -1, the whole split is downloaded. Default: -1")
    args = parser.parse_args()

    data_directory = args.dir
    split = args.split
    num = args.num

    generate_txt(url=data_directory, split_=split, number=num)
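
# Illustrative invocation (the script name is hypothetical; --dir is
# concatenated directly with the split name, so keep the trailing slash):
#   python download_cnn_dailymail.py --dir=./data/ --split=test --num=-1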
@@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluate reading comprehension results against the additional answers."""
import json
import re
import string
import argparse
from collections import Counter


def get_normalize_answer_token(string_):
    """normalize the answer tokens: lower the text and remove punctuation, articles and extra whitespace"""

    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(char for char in text if char not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()


def calculate_f1(pred_answer, gold_answer):
    """
    calculate the final F1 score, allowing for additional answers
    """
    pred_answer = get_normalize_answer_token(pred_answer)
    gold_answer = get_normalize_answer_token(gold_answer)
    common = Counter(pred_answer) & Counter(gold_answer)
    # the number of tokens shared by pred_answer and gold_answer
    num_same = sum(common.values())
    # after normalization both answers are token lists, so test emptiness
    # on the lists themselves rather than on strings
    precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0
    recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0
    if not pred_answer and not gold_answer:
        f1_score = 1
    else:
        f1_score = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0
    return f1_score
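

# Worked example: pred "the cat sat" and gold "a cat sat down" normalize to
# ['cat', 'sat'] and ['cat', 'sat', 'down']; num_same = 2, so
# precision = 2/2, recall = 2/3 and F1 = 2 * 1 * (2/3) / (1 + 2/3) = 0.8.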


def main():
    parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
    parser.add_argument("--input_file", type=str, default="",
                        help="The log file path of evaluation in Reading Comprehension.")
    parser.add_argument("--addition_file", type=str, default="", help="Path of coqa-dev-v1.0.json.")
    args_opt = parser.parse_args()

    input_file = args_opt.input_file
    addition_file = args_opt.addition_file

    find_word = 'Pred_answer:'
    find_word_length = len(find_word)
    pred_answer_list = []
    with open(input_file, 'r', encoding='utf-8') as f:
        while True:
            line = f.readline()
            if not line:
                break
            index = line.find(find_word)
            if index != -1:
                pred_answer = line[index + find_word_length:].strip()
                pred_answer_list.append(pred_answer)

    with open(addition_file, 'r', encoding='utf-8') as g:
        dataset = json.load(g)

    pred_answer_num = 0
    total_f1score = 0
    average_f1score = 0
    data_num = len(pred_answer_list)

    for story in dataset['data']:
        questions = story['questions']
        multiple_answers = [story['answers']]
        multiple_answers += story['additional_answers'].values()
        for question in questions:
            pred_a = pred_answer_list[pred_answer_num]
            turn_id = question['turn_id']
            max_score = 0
            max_group = 0
            flag = 0
            for i, answer in enumerate(multiple_answers):
                gold_a = answer[turn_id - 1]['input_text']
                score = calculate_f1(pred_a, gold_a)
                if score > max_score:
                    max_score = score
                    max_group = i
            # take the best score over the multiple answer groups and record its index
            gold_a = multiple_answers[max_group][turn_id - 1]['input_text']
            pred_answer_num += 1
            total_f1score += max_score
            average_f1score = total_f1score / pred_answer_num

            print('==================== data {} ===================='.format(pred_answer_num))
            print('| Gold_answer:{}'.format(gold_a))
            print('| Pred_answer:{}'.format(pred_a))
            print('| F1_Score:{:.8f}'.format(average_f1score))
            print('=====================================================\n')

            if pred_answer_num >= data_num:
                flag = 1
                break

        # stop flag
        if flag:
            print('Finished evaluation with additional answers!\n')
            print("********************** Testing Finished **********************")
            print('| Test file name: {}'.format(input_file))
            print('| Final F1 score: {:.8f}'.format(average_f1score))
            print('| Total data num: {}'.format(pred_answer_num))
            print("**************************************************************")
            break


if __name__ == "__main__":
    main()
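
# Illustrative invocation (the script name is hypothetical): --input_file is the
# log produced by the Reading Comprehension evaluation, --addition_file is the
# official coqa-dev-v1.0.json:
#   python eval_rc_addition_answer.py --input_file=rc_eval.log --addition_file=coqa-dev-v1.0.json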
@@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetune and evaluation script for Children's Book Test task.
"""
import argparse
import time

import numpy as np
from mindspore import context
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CBT
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Accuracy
from src.dataset import create_cbt_dataset, create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.task_utils import calculate_choice_prob_for_cbt
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
| dataset: the train dataset. | |||||
| network: the network with loss | |||||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||||
| epoch_num: the number of epoch. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
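| # The learning rate warms up over the first 10% of total steps, then follows a polynomial decay (exponent `power`) down to the end learning rate. | |||||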
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_" + "cbt_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" + str(epoch_num) +\ | |||||
| "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
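| # Remap the pretrained parameter names into the finetune network namespace ('gpt2.gpt2.' prefix) | |||||
| # and tie the LM head weight to the token embedding table (weight sharing, as in the original GPT-2). | |||||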
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load pretrained parameter successfully!\n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("==================== Starting Finetuning ====================") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("==================== Finetuning Success ====================") | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, num_choice=None): | |||||
| """ | |||||
| Do evaluation for CBT task. | |||||
| Args: | |||||
| dataset: the eval dataset. | |||||
| network: the network with loss. | |||||
| metric: the evaluation method. | |||||
| load_checkpoint_path: the file path which saved finetuned model checkpoint. | |||||
| eval_type: the evaluation type, i.e. zero-shot or finetuned. | |||||
| num_choice: the number of candidate choices per CBT question. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| if metric.lower() == "accuracy": | |||||
| print("Prepare to calculate the accuracy score ...") | |||||
| gpt2_cbt = network(config=gpt2_net_cfg, | |||||
| is_training=False, | |||||
| use_one_hot_embeddings=False | |||||
| ) | |||||
| gpt2_cbt.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| if eval_type == "zero-shot": | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(gpt2_cbt, final_param_dict) | |||||
| print("load pretrained parameter successfully!\n") | |||||
| elif eval_type == "finetuned": | |||||
| load_param_into_net(gpt2_cbt, param_dict) | |||||
| print("load finetuned parameter successfully!\n") | |||||
| else: | |||||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||||
| model = Model(gpt2_cbt) | |||||
| callback = Accuracy() | |||||
| columns_list = ["input_ids", "input_mask", "input_length", "mc_labels"] | |||||
| print("==================== [ACC] Testing ====================") | |||||
| num_data = 1 | |||||
| all_choice_prob = [] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, input_mask, input_length, mc_labels = input_data | |||||
| print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||||
| # print("mc_labels: {}".format(mc_labels)) # [batch_size] | |||||
| logits = model.predict(input_ids, input_mask) | |||||
| # choice_prob_list [batch_size] | |||||
| choice_prob_list = calculate_choice_prob_for_cbt(logits=logits, | |||||
| batch_size=gpt2_net_cfg.batch_size, | |||||
| input_length=input_length, | |||||
| input_ids=input_ids) | |||||
| all_choice_prob.append(choice_prob_list) | |||||
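| # Once the probabilities of a full group of num_choice candidates have been collected, | |||||
| # reshape them to (num_question, num_choice) so the accuracy metric can pick the most probable candidate. | |||||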
| if (num_data * gpt2_net_cfg.batch_size) % num_choice == 0: | |||||
| all_choice_prob_np = np.array(all_choice_prob) | |||||
| all_choice_prob_np = all_choice_prob_np.reshape((-1, num_choice)) | |||||
| print("| all_choice_prob_np: ", all_choice_prob_np) | |||||
| print("| all_choice_prob_np shape: ", all_choice_prob_np.shape) | |||||
| mc_labels = np.array([mc_labels.asnumpy()[0]]) | |||||
| callback.update(all_choice_prob_np, mc_labels) | |||||
| all_choice_prob = [] | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**************************************************************") | |||||
| print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | |||||
| callback.acc_num / callback.total_num)) | |||||
| print("********************** Testing Finished **********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported, support: [Accuracy]") | |||||
| def run_cbt_task(): | |||||
| """ | |||||
| run Children's Book Test (CBT) task | |||||
| """ | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate CBT task") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=1, | |||||
| help="ID of target device. ") | |||||
| parser.add_argument("--num_choice", type=int, default=10, | |||||
| help="The number of choice in CBT task. ") | |||||
| parser.add_argument("--metric_method", type=str, default="Accuracy", | |||||
| help="The eval method including [Accuracy]. Default: Accuracy.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: true.") | |||||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--epoch_num", type=int, default=1, | |||||
| help="Epoch number. Default: 1.") | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the finetuned checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path for train.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path for evaluation.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||||
| device_target = args_opt.device_target | |||||
| if device_target == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=device_target, | |||||
| device_id=args_opt.device_id, | |||||
| max_call_depth=3000) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| gpt2_loss = GPT2CBT(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| print("============== Start Loading Train Dataset ============") | |||||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| print("============== Start Loading Evaluation Dataset ============") | |||||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||||
| eval_dataset = create_cbt_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.eval_data_file_path) | |||||
| do_eval(eval_dataset, GPT2CBT, metric, load_finetune_ckpt_path, args_opt.eval_type, args_opt.num_choice) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_cbt_task() | |||||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,293 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 finetune and evaluation script for Reading Comprehension task. | |||||
| """ | |||||
| import argparse | |||||
| import time | |||||
| from mindspore import context | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2CoQA | |||||
| from src.GPT2ForReadComprehension import GPT2CoQAModel | |||||
| from src.utils.metric_method import F1 | |||||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||||
| from src.dataset import create_language_model_dataset | |||||
| from src.utils.lr_schedule import GPT2LearningRate | |||||
| from src.utils.tokenization import Tokenizer | |||||
| from src.GPT2_generation import GenerateForReadComprehension | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
| dataset: the train dataset. | |||||
| network: the network with loss | |||||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||||
| epoch_num: the number of epoch. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_rc_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load the pretrained parameter successfully! \n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("=================== Starting Training For Translation Task ====================") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("=================== Translation Training Success ====================") | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="", | |||||
| generate_length=1, top_k=1, top_p=1.0, temperature=1.0): | |||||
| """ | |||||
| Do evaluation for Reading Comprehension task. | |||||
| Args: | |||||
| dataset: the eval dataset. | |||||
| network: the network with loss. | |||||
| metric: the evaluation method. | |||||
| load_checkpoint_path: the file path which saved finetuned model checkpoint. | |||||
| eval_type: the evaluation type, i.e. zero-shot or finetuned. | |||||
| tokenizer_file_path: the path of the vocab and merge files. | |||||
| generate_length: the number of tokens to generate for each answer. | |||||
| top_k, top_p, temperature: the sampling parameters used for generation. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| if metric.lower() == "f1": | |||||
| print("Prepare to calculate the BLEU score ...") | |||||
| gpt2_rc = network(config=gpt2_net_cfg, | |||||
| is_training=False, | |||||
| use_one_hot_embeddings=False) | |||||
| gpt2_rc.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| if eval_type == "zero-shot": | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(gpt2_rc, final_param_dict) | |||||
| print("load pretrained parameter successfully!\n") | |||||
| elif eval_type == "finetuned": | |||||
| load_param_into_net(gpt2_rc, param_dict) | |||||
| print("load finetuned parameter successfully!\n") | |||||
| else: | |||||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||||
| model = Model(gpt2_rc) | |||||
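| # Note: tokenizer_file_path is concatenated directly with the file names below, | |||||
| # so it should end with a path separator, e.g. '/path/to/vocab/'. | |||||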
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||||
| callback = F1() | |||||
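| # Build the answer generator; with the default sampling parameters (top_k=1, top_p=1.0, temperature=1.0) | |||||
| # decoding is effectively greedy. | |||||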
| rc_generator = GenerateForReadComprehension(decoder=model, | |||||
| config=gpt2_net_cfg, | |||||
| tokenizer=tokenizer, | |||||
| generate_length=generate_length, | |||||
| topk_num=top_k, | |||||
| topp_prob=float(top_p), | |||||
| temperature=float(temperature) | |||||
| ) | |||||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||||
| print("==================== [F1] Testing ====================") | |||||
| num_data = 0 | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, _, label_ids = input_data | |||||
| print("input_ids shape: {}".format(input_ids.shape)) | |||||
| print("label_ids shape: {}".format(label_ids.shape)) | |||||
| passage, pred_answer, gold_answer = rc_generator.generate_for_read_comprehension(input_ids) | |||||
| for batch_id in range(gpt2_net_cfg.batch_size): | |||||
| print("============== [F1] {} ================".format(num_data + 1)) | |||||
| print(" | Passage:{}".format(passage[batch_id])) | |||||
| print(" | Gold_answer:{}".format(gold_answer[batch_id])) | |||||
| print(" | Pred_answer:{}".format(pred_answer[batch_id])) | |||||
| pred = callback.get_normalize_answer_token(pred_answer[batch_id]) | |||||
| gold = callback.get_normalize_answer_token(gold_answer[batch_id]) | |||||
| callback.update(pred, gold) | |||||
| num_data += 1 | |||||
| average_f1_score = callback.f1_score / num_data | |||||
| print("============== Evaluation =================") | |||||
| print("| Avg F1 Score:{:.8f}".format(average_f1_score)) | |||||
| print("=============================================\n\n") | |||||
| print("********************** Testing Finished **********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported in Reading Comprehension task, support: [F1]") | |||||
| def run_read_comprehension(): | |||||
| """ | |||||
| run Reading Comprehension task | |||||
| """ | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate Reading Comprehension task") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=0, | |||||
| help="ID of target device. ") | |||||
| parser.add_argument("--metric_method", type=str, default="F1", | |||||
| help="The eval method including [F1]. Default: F1.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: false.") | |||||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--epoch_num", type=int, default=1, | |||||
| help="Epoch number. Default: 1.") | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||||
| help="pretrained vocab and merge file path.") | |||||
| parser.add_argument("--generate_length", type=int, default=55, | |||||
| help="The generation length of translation sentence.") | |||||
| parser.add_argument("--top_k", type=int, default=1, | |||||
| help="Parameter for Top-K sampling.") | |||||
| parser.add_argument("--top_p", type=str, default="1.0", | |||||
| help="parameter for Top-P sampling.") | |||||
| parser.add_argument("--temperature", type=str, default="1.0", | |||||
| help="Parameter for generation, greater if generation more diverse. ") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||||
| device_target = args_opt.device_target | |||||
| if device_target == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=device_target, | |||||
| device_id=args_opt.device_id, | |||||
| max_call_depth=3000) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| gpt2_loss = GPT2CoQA(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| print("============== Start Loading Translation Train Dataset ==============") | |||||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| print("============ Start Loading Translation Evaluation Dataset ============") | |||||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||||
| eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.eval_data_file_path) | |||||
| do_eval(eval_dataset, GPT2CoQAModel, metric, load_finetune_ckpt_path, args_opt.eval_type, | |||||
| args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p, | |||||
| args_opt.temperature) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_read_comprehension() | |||||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,328 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 finetune and evaluation script for LAMBADA task. | |||||
| """ | |||||
| import argparse | |||||
| import math | |||||
| import time | |||||
| from mindspore import context | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Lambada | |||||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||||
| from src.utils.metric_method import LastWordAccuracy | |||||
| from src.dataset import create_language_model_dataset, create_lambada_control_dataset | |||||
| from src.utils.lr_schedule import GPT2LearningRate | |||||
| from src.utils.task_utils import get_final_word_label | |||||
| from src.utils.tokenization import Tokenizer | |||||
| from src.GPT2_generation import GenerateForLambada | |||||
| from src.utils.CrossEntropy import CrossEntropyCalculationWithMask | |||||
| from src.utils.get_config_setting import get_train_setting, get_model_setting | |||||
| from src.utils.task_utils import calculate_final_word_loss | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
| dataset: the train dataset. | |||||
| network: the network with loss | |||||
| load_checkpoint_path: the file path which saved pretrain model checkpoint. | |||||
| save_checkpoint_path: the file path which will save finetune model checkpoint. | |||||
| epoch_num: the number of epoch | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_" + "lambada_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load pretrained parameter successfully!\n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("==================== Starting Finetuning ====================") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("==================== Finetuning Success ====================") | |||||
| def eval_result_print(metric="accuracy", callback=None): | |||||
| """ | |||||
| Print eval result. | |||||
| """ | |||||
| if metric.lower() == "accuracy": | |||||
| print("acc_num {}, total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, | |||||
| callback.acc_num / callback.total_num)) | |||||
| else: | |||||
| raise ValueError("metric method not supported, support: [accuracy]") | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, stop_word_file="", | |||||
| generate_length_dynamic=True, tokenizer_file_path=""): | |||||
| """ | |||||
| Do eval | |||||
| Args: | |||||
| dataset: the eval dataset. | |||||
| network: the network with loss. | |||||
| metric: the evaluation method. | |||||
| load_checkpoint_path: the file path which saved finetune model checkpoint. | |||||
| eval_type: the eval type, i.e. zero-shot, finetuned. | |||||
| generate_length_dynamic (bool): True for the generate length is dynamic, False for fixed. Default: True. | |||||
| tokenizer_file_path: the tokenizer file path for vocab file and merge file. | |||||
| stop_word_file: stop word file for calculating Accuracy. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||||
| gpt2_lambada = network(config=gpt2_net_cfg, | |||||
| is_training=False, | |||||
| use_one_hot_embeddings=False) | |||||
| gpt2_lambada.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| if eval_type == "zero-shot": | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(gpt2_lambada, final_param_dict) | |||||
| print("load pretrained parameter successfully!\n") | |||||
| elif eval_type == "finetuned": | |||||
| load_param_into_net(gpt2_lambada, param_dict) | |||||
| print("load finetuned parameter successfully!\n") | |||||
| model = Model(gpt2_lambada) | |||||
| if metric.lower() == "accuracy": | |||||
| print("Prepare to calculate the accuracy score ...") | |||||
| callback = LastWordAccuracy() | |||||
| columns_list = ["input_ids", "input_mask", "input_length"] | |||||
| print("==================== [ACC] Testing ====================") | |||||
| lambada_generator = GenerateForLambada(decoder=model, | |||||
| config=gpt2_net_cfg, | |||||
| tokenizer=tokenizer, | |||||
| generate_length_dynamic=generate_length_dynamic, | |||||
| max_iterations=200, | |||||
| stop_word_file=stop_word_file) | |||||
| num_data = 1 | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, input_mask, input_length = input_data | |||||
| print("| [ACC] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||||
| logits = model.predict(input_ids, input_mask) | |||||
| predict_str = lambada_generator.generate_for_lambada(input_ids=input_ids, | |||||
| logits=logits, | |||||
| input_length=input_length) | |||||
| label_str = get_final_word_label(input_ids=input_ids, input_length=input_length, tokenizer=tokenizer) | |||||
| callback.update(predict_str, label_str) | |||||
| eval_result_print(metric, callback) | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**********************************************************") | |||||
| eval_result_print(metric, callback) | |||||
| print("******************** Testing Finished ********************") | |||||
| elif metric.lower() == "ppl": | |||||
| print("Prepare to calculate the ppl score ...") | |||||
| cross_entropy = CrossEntropyCalculationWithMask(is_training=True, | |||||
| num_labels=gpt2_net_cfg.vocab_size, | |||||
| config=gpt2_net_cfg) | |||||
| columns_list = ["input_ids", "input_mask", "input_length"] | |||||
| num_data = 1 | |||||
| total_loss = 0.0 | |||||
| print("==================== [PPL] Testing ====================") | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, input_mask, input_length = input_data | |||||
| print("| [PPL] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||||
| logits = model.predict(input_ids, input_mask) # (batch_size, seq_len, vocab_size) | |||||
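| # The cross-entropy loss is computed over the final word of each passage only, then averaged across the batch. | |||||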
| avg_batch_loss = calculate_final_word_loss(logits, | |||||
| gpt2_net_cfg.batch_size, | |||||
| input_ids, | |||||
| input_length, | |||||
| cross_entropy) | |||||
| total_loss += avg_batch_loss | |||||
| avg_total_loss = total_loss / num_data | |||||
| print(" | Current AVG loss:", avg_total_loss) | |||||
| print(" | Current AVG ppl:", math.exp(avg_total_loss)) | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**********************************************************") | |||||
| print("Average PPL: {:.6f}".format(math.exp(avg_total_loss))) | |||||
| print("******************** Testing Finished ********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported, support: [accuracy, ppl]") | |||||
| def run_lambada(): | |||||
| """ | |||||
| Run Lambada task. | |||||
| """ | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate languagemodel") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=2, | |||||
| help="ID of target device.") | |||||
| parser.add_argument("--metric_method", type=str, default="PPL", | |||||
| help="The eval method including [Accuracy, PPL]. Default: Accuracy.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: false.") | |||||
| parser.add_argument("--eval_type", type=str, default="finetuned", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--epoch_num", type=int, default=3, | |||||
| help="Epoch number. Default: 1.") | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="false", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--generate_length_dynamically", type=str, default="true", | |||||
| help="Enable generate_length_Dynamically. Default: true.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path.") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path.") | |||||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||||
| help="pretrained vocab and merge file path.") | |||||
| parser.add_argument("--stop_word_file_path", type=str, default="", | |||||
| help="The stop word file path.") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||||
| device = args_opt.device_target | |||||
| if device == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| gpt2_loss = GPT2Lambada(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| get_train_setting(cfg) | |||||
| get_model_setting(cfg, gpt2_net_cfg) | |||||
| print("============== Start Loading Train Dataset ============") | |||||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| get_model_setting(cfg, gpt2_net_cfg) | |||||
| print("============== Start Loading Evaluation Dataset ============") | |||||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||||
| eval_dataset = create_lambada_control_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.eval_data_file_path) | |||||
| do_eval(eval_dataset, GPT2Lambada, metric, load_finetune_ckpt_path, args_opt.eval_type, | |||||
| args_opt.stop_word_file_path, args_opt.generate_length_dynamically, args_opt.tokenizer_file_path) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_lambada() | |||||
| print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,255 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 finetune and evaluation script for Language Modeling task. | |||||
| """ | |||||
| import argparse | |||||
| import math | |||||
| import time | |||||
| from mindspore import context | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2LM | |||||
| from src.utils.lr_schedule import GPT2LearningRate | |||||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||||
| from src.dataset import create_language_model_dataset | |||||
| from src.utils.get_config_setting import get_train_setting, get_model_setting | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
| dataset: the train dataset. | |||||
| network: the network with loss | |||||
| load_checkpoint_path: the file path which saved pretrained model checkpoint. | |||||
| save_checkpoint_path: the file path which will save finetuned model checkpoint. | |||||
| epoch_num: the number of epoch. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_language_model_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load pretrained parameter successfully!\n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("==================== Starting Finetuning ====================") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("==================== Finetuning Success ====================") | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None): | |||||
| """ | |||||
| Do eval | |||||
| Args: | |||||
| dataset: the eval dataset. | |||||
| network: the network with loss. | |||||
| metric: the evaluation method. | |||||
| load_checkpoint_path: the file path which saved finetuned model checkpoint. | |||||
| eval_type: the evaluation type, i.e. zero-shot or finetuned. | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| if metric.lower() == "ppl": | |||||
| print("Prepare to calculate the ppl score ...") | |||||
| gpt2_loss = network(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| gpt2_loss.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| if eval_type == "zero-shot": | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(gpt2_loss, final_param_dict) | |||||
| print("load pretrained parameter successfully!\n") | |||||
| elif eval_type == "finetuned": | |||||
| load_param_into_net(gpt2_loss, param_dict) | |||||
| print("load finetuned parameter successfully!\n") | |||||
| else: | |||||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||||
| model = Model(gpt2_loss) | |||||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||||
| print("==================== [PPL] Testing ====================") | |||||
| num_data = 1 | |||||
| total_loss = 0.0 | |||||
| avg_loss = 0.0 | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, input_mask, label_ids = input_data | |||||
| loss = model.predict(input_ids, input_mask, label_ids) | |||||
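| # The network returns the average cross-entropy loss of the batch; perplexity is exp of the running mean loss. | |||||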
| loss = float(loss.asnumpy()) | |||||
| total_loss += loss | |||||
| avg_loss = float(total_loss / num_data) | |||||
| print(" | Current Loss: {:.6f}".format(avg_loss)) | |||||
| print(" | Current PPL: {}\n\n".format(math.exp(avg_loss))) | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**************************************************************") | |||||
| print("Average Loss: {:.6f}".format(avg_loss)) | |||||
| print("Average PPL: {:.6f}".format(math.exp(avg_loss))) | |||||
| print("********************** Testing Finished **********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported, support: [ppl]") | |||||
| def run_languagemodel(): | |||||
| """ | |||||
| run Language Modeling task | |||||
| """ | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate language modelings task") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=1, | |||||
| help="ID of target device. ") | |||||
| parser.add_argument("--metric_method", type=str, default="PPL", | |||||
| help="The eval method including [PPL]. Default: PPL.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: true.") | |||||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--epoch_num", type=int, default=1, | |||||
| help="Epoch number. Default: 1.") | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the finetuned checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path for train.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path for evaluation.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||||
| if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": | |||||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||||
| device_target = args_opt.device_target | |||||
| if device_target == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=device_target, | |||||
| device_id=args_opt.device_id, | |||||
| max_call_depth=3000) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| gpt2_loss = GPT2LM(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| get_train_setting(cfg) | |||||
| get_model_setting(cfg, gpt2_net_cfg) | |||||
| print("==================== Start Loading Train Dataset ==================") | |||||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| get_model_setting(cfg, gpt2_net_cfg) | |||||
| print("==================== Start Loading Evaluation Dataset ==================") | |||||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||||
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
| dataset_path=args_opt.eval_data_file_path) | |||||
| do_eval(eval_dataset, GPT2LM, metric, load_finetune_ckpt_path, args_opt.eval_type) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_languagemodel() | |||||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,296 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 finetune and evaluation script for Summarization task. | |||||
| """ | |||||
| import time | |||||
| import argparse | |||||
| from mindspore import context | |||||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.GPT2ForSummarization import GPT2SummarizationModel | |||||
| from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell | |||||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||||
| from src.utils.metric_method import Rouge | |||||
| from src.dataset import create_language_model_dataset | |||||
| from src.utils.lr_schedule import GPT2LearningRate | |||||
| from src.utils.tokenization import Tokenizer | |||||
| from src.utils.task_utils import clean_hypo, modify_paramdict | |||||
| from src.GPT2_generation import GenerateForSummarization | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
dataset: the train dataset.
network: the network with loss.
load_checkpoint_path: path of the pretrained checkpoint to load.
save_checkpoint_path: path where the finetuned checkpoint will be saved.
epoch_num: the number of training epochs.
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list( | |||||
| filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
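# group parameters so that only those matched by cfg.AdamWeightDecay.decay_filter receive weight decay
# (typically excluding bias and LayerNorm parameters, depending on the filter).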
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
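# remap pretrained parameter names onto the task network (adding the 'gpt2.gpt2.' prefix) and tie
# the LM head weight to the token embedding table, as in the original GPT-2.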
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load pretrained parameter successfully!\n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("============== Starting Finetuning ==============") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("============== Finetuning Success ==============") | |||||
| def eval_result_print(metric="Rouge", callback=None): | |||||
| """ | |||||
| print eval result | |||||
| """ | |||||
| if metric == "Rouge": | |||||
| print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG{:.8f}". | |||||
| format(callback.Rouge1 / callback.total_num, | |||||
| callback.Rouge2 / callback.total_num, | |||||
| callback.RougeL / callback.total_num, | |||||
| (callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num))) | |||||
| else: | |||||
| raise ValueError("metric method '{}' not supported, support: [Rouge]. ".format(str(metric))) | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="", | |||||
| top_k=None, top_p=None, temperature=None, generate_length=None): | |||||
| """ | |||||
| Do evaluation on summarization | |||||
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| if metric.lower() == "rouge": | |||||
| print("Prepare to calculate the Rouge score ...") | |||||
| callback = Rouge() | |||||
| gpt2_loss = network(config=gpt2_net_cfg, | |||||
| is_training=False, | |||||
| use_one_hot_embeddings=False) | |||||
| gpt2_loss.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.") | |||||
| load_param_into_net(gpt2_loss, reorganized_param_dict) | |||||
# wrap the nn.Cell in a Model and initialize the tokenizer and sampler
| model = Model(gpt2_loss) | |||||
| tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json', | |||||
| merge_file=tokenizer_file + 'gpt2-merges.txt') | |||||
| # load data and process text generation | |||||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||||
| summarization_generator = GenerateForSummarization(model, | |||||
| config=gpt2_net_cfg, | |||||
| tokenizer=tokenizer, | |||||
| select_sentence=3, | |||||
| eval_type=eval_type, | |||||
| topk=top_k, | |||||
| topp=float(top_p), | |||||
| temperature=float(temperature), | |||||
| generate_length=generate_length) | |||||
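# select_sentence=3 presumably keeps the first three generated sentences as the summary hypothesis
# (assumption; see src/GPT2_generation.py for the exact behavior).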
| num_data = 1 | |||||
| print("==================== [Summrization] Testing ====================") | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for value in columns_list: | |||||
| input_data.append(data[value]) | |||||
| input_ids, _, label_ids = input_data | |||||
| print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size())) | |||||
| print("input_ids shape: {}".format(input_ids.shape)) | |||||
| print("label_ids shape: {}".format(label_ids.shape)) | |||||
| hypothesis, ref = summarization_generator.generate_for_summarization(input_ids) | |||||
if ref[0] == '' or ref[0] is None:
print("Reference is empty, skipping this batch!")
| continue | |||||
| print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n") | |||||
| for batch_idx in range(gpt2_net_cfg.batch_size): | |||||
| hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx]) | |||||
| for batch_idx in range(gpt2_net_cfg.batch_size): | |||||
| hypothesis[batch_idx] = hypothesis[batch_idx].lower() | |||||
| ref[batch_idx] = ref[batch_idx].lower() | |||||
| callback.update(hypothesis, ref) | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**********************************************************") | |||||
| eval_result_print(metric, callback) | |||||
| print("******************** Testing Finished ********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported in summarization, support: [Rouge]") | |||||
| def run_summarization(): | |||||
| """ | |||||
| Run Summarization task. | |||||
| """ | |||||
| # set argument parser | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate Summrization") | |||||
| # context and task settings | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=4, | |||||
| help="ID of target device.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: false.") | |||||
| parser.add_argument("--eval_type", type=str, default="finetuned", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--metric_method", type=str, default="Rouge", | |||||
| help="The eval method including [Rouge(Rouge1,Rouge2,RougeL,Rouge Avg)]. Default: Rouge.") | |||||
| parser.add_argument("--epoch_num", type=int, default=2, | |||||
| help="Epoch number. Default: 2.") | |||||
| # dataset and params_dict file settings | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| # sampling settings | |||||
| parser.add_argument("--top_k", type=int, default=2, | |||||
| help="top k tokens chosen for sampling") | |||||
| parser.add_argument("--top_p", type=str, default="1.0", | |||||
| help="top p accumulated probability threshold for logit to be counted") | |||||
| parser.add_argument("--generate_length", type=int, default=100, | |||||
| help="the number of generated tokens.") | |||||
| parser.add_argument("--temperature", type=str, default="1.0", | |||||
| help="temperature on logits for sampling") | |||||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||||
| help="vocab & merge file path") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| eval_type = args_opt.eval_type | |||||
| tokenizer_file = args_opt.tokenizer_file_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when running the finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when running the evaluation task")
| device = args_opt.device_target | |||||
| if device == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| if args_opt.do_train.lower() == "true": | |||||
| train_data_file_path = args_opt.train_data_file_path | |||||
| gpt2_loss = GPT2Summarization(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| print("============== Start Loading Train Dataset ============") | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| eval_dataset_file_path = args_opt.eval_data_file_path | |||||
| print("============== Start Loading Evaluation Dataset ============") | |||||
eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
| dataset_path=eval_dataset_file_path) | |||||
| do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file, | |||||
| args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_summarization() | |||||
| print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,298 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 finetune and evaluation script for Translation task. | |||||
| """ | |||||
| import argparse | |||||
| import time | |||||
| from mindspore import context | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.nn import AdamWeightDecay, Lamb, Momentum | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.GPT2ForTranslation import GPT2TranslationModel | |||||
| from src.gpt2_for_finetune import GPT2FinetuneCell, GPT2Translation | |||||
| from src.finetune_eval_config import cfg, gpt2_net_cfg | |||||
| from src.dataset import create_language_model_dataset | |||||
| from src.utils.lr_schedule import GPT2LearningRate | |||||
| from src.utils.tokenization import Tokenizer | |||||
| from src.utils.metric_method import BLEU | |||||
| from src.GPT2_generation import GenerateForTranslation | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1): | |||||
| """ | |||||
| Do train | |||||
| Args: | |||||
| dataset: the train dataset. | |||||
network: the network with loss.
load_checkpoint_path: path of the pretrained checkpoint to load.
save_checkpoint_path: path where the finetuned checkpoint will be saved.
epoch_num: the number of training epochs.
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| # optimizer | |||||
| if cfg.optimizer == 'AdamWeightDecay': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate, | |||||
| end_learning_rate=cfg.AdamWeightDecay.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.AdamWeightDecay.power) | |||||
| params = network.trainable_params() | |||||
| decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params)) | |||||
| other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params)) | |||||
| group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay}, | |||||
| {'params': other_params, 'weight_decay': 0.0}] | |||||
| optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps) | |||||
| elif cfg.optimizer == 'Lamb': | |||||
| lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate, | |||||
| end_learning_rate=cfg.Lamb.end_learning_rate, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| power=cfg.Lamb.power) | |||||
| optimizer = Lamb(network.trainable_params(), lr_schedule) | |||||
| elif cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecay, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| prefix_name = "gpt2_translation_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \ | |||||
| + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=prefix_name, | |||||
| directory=None if save_checkpoint_path == "" else save_checkpoint_path, | |||||
| config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['gpt2.dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(network, final_param_dict) | |||||
| print("Load the pretrained parameter successfully! \n") | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| netwithgrads.set_train(True) | |||||
| loss_cb = LossMonitor(per_print_times=1) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb] | |||||
| print("=================== Starting Training For Translation Task ====================") | |||||
| model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False) | |||||
| print("=================== Translation Training Success ====================") | |||||
| def eval_result_print(metric="BLEU", callback=None): | |||||
| """ print eval result""" | |||||
| if metric == "BLEU": | |||||
| print(" | BLEU: {:.6f}".format(callback.bleu / float(callback.total_num))) | |||||
| else: | |||||
| raise ValueError("metric method '{}' not supported, support: [BLEU]. ".format(str(metric))) | |||||
| def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file_path="", | |||||
| generate_length=1, top_k=1, top_p=1.0, temperature=1.0): | |||||
| """ | |||||
| Do evaluation on Translation | |||||
| Args: | |||||
| dataset: the eval dataset. | |||||
| network: the network with loss. | |||||
| metric: the evaluation method. | |||||
load_checkpoint_path: path of the checkpoint to load for evaluation.
| """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| if metric.lower() == "bleu": | |||||
| print("Prepare to calculate the BLEU score ...") | |||||
| gpt2_translation = network(config=gpt2_net_cfg, | |||||
| is_training=False, | |||||
| use_one_hot_embeddings=False) | |||||
| gpt2_translation.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
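# zero-shot: remap raw pretrained parameter names and tie dense1 to the embedding table;
# finetuned: the checkpoint already matches the task network, so it loads directly.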
| if eval_type == "zero-shot": | |||||
| final_param_dict = {} | |||||
| for name, _ in param_dict.items(): | |||||
| final_param_dict['gpt2.' + name] = param_dict[name] | |||||
| final_param_dict['dense1.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| load_param_into_net(gpt2_translation, final_param_dict) | |||||
| print("load pretrained parameter successfully!\n") | |||||
| elif eval_type == "finetuned": | |||||
| load_param_into_net(gpt2_translation, param_dict) | |||||
| print("load finetuned parameter successfully!\n") | |||||
| else: | |||||
| raise ValueError("Evaluation type missed, eval_type should be [zero-shot, finetuned]") | |||||
| model = Model(gpt2_translation) | |||||
| tokenizer = Tokenizer(vocab_file=tokenizer_file_path + 'gpt2-vocab.json', | |||||
| merge_file=tokenizer_file_path + 'gpt2-merges.txt') | |||||
| callback = BLEU(tokenizer) | |||||
| translation_generator = GenerateForTranslation(decoder=model, | |||||
| config=gpt2_net_cfg, | |||||
| tokenizer=tokenizer, | |||||
generate_length=generate_length,
| use_hint=True, | |||||
| select_first_sentence=True, | |||||
| topk_num=top_k, | |||||
| topp_prob=float(top_p), | |||||
| temperature=float(temperature) | |||||
| ) | |||||
| columns_list = ["input_ids", "input_mask", "label_ids"] | |||||
| print("==================== [BLEU] Testing ====================") | |||||
| num_data = 1 | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(data[i]) | |||||
| input_ids, input_mask, label_ids = input_data | |||||
| print("| Data count: {}".format(num_data * gpt2_net_cfg.batch_size)) | |||||
| print("input_ids shape: {}".format(input_ids.shape)) | |||||
| print("input_mask shape: {}".format(input_mask.shape)) | |||||
| print("label_ids shape: {}".format(label_ids.shape)) | |||||
| ts_predict_list, ref_list = translation_generator.generate_for_translation(input_ids) | |||||
| print("| Batch Reference translation:\n{}\n".format(ref_list)) | |||||
if not ref_list:
print("Reference list is empty, skipping this batch!")
| continue | |||||
| else: | |||||
| print(" | Batch Predict translation:\n{}\n".format(ts_predict_list)) | |||||
| callback.update(ref_list, ts_predict_list) | |||||
| num_data += 1 | |||||
| print("\n\n") | |||||
| print("**************************************************************") | |||||
| eval_result_print(metric, callback) | |||||
| print("********************** Testing Finished **********************") | |||||
| else: | |||||
| raise ValueError("metric method not supported in translation, support: [BLEU]") | |||||
| def run_translation(): | |||||
| """ | |||||
| run translation task | |||||
| """ | |||||
| parser = argparse.ArgumentParser(description="Finetune and Evaluate translation") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", | |||||
| help="Device type. Default: Ascend.") | |||||
| parser.add_argument("--device_id", type=int, default=0, | |||||
| help="ID of target device. ") | |||||
| parser.add_argument("--metric_method", type=str, default="BLEU", | |||||
| help="The eval method including [BLEU]. Default: BLEU.") | |||||
| parser.add_argument("--do_train", type=str, default="false", | |||||
| help="Enable train. Default: false.") | |||||
| parser.add_argument("--do_eval", type=str, default="true", | |||||
| help="Enable evaluation. Default: false.") | |||||
| parser.add_argument("--eval_type", type=str, default="zero-shot", | |||||
| help="The type of evaluation including [zero-shot, finetuned]. Default: zero-shot.") | |||||
| parser.add_argument("--epoch_num", type=int, default=1, | |||||
| help="Epoch number. Default: 1.") | |||||
| parser.add_argument("--train_data_shuffle", type=str, default="true", | |||||
| help="Enable train data shuffle. Default: true.") | |||||
| parser.add_argument("--eval_data_shuffle", type=str, default="false", | |||||
| help="Enable eval data shuffle. Default: false.") | |||||
| parser.add_argument("--save_finetune_ckpt_path", type=str, default="", | |||||
| help="Save the checkpoint path.") | |||||
| parser.add_argument("--load_pretrain_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--load_finetune_ckpt_path", type=str, default="", | |||||
| help="Load the checkpoint file path.") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--tokenizer_file_path", type=str, default="", | |||||
| help="pretrained vocab and merge file path.") | |||||
| parser.add_argument("--generate_length", type=int, default=150, | |||||
| help="The generation length of translation sentence.") | |||||
| parser.add_argument("--top_k", type=int, default=1, | |||||
| help="Parameter for Top-K sampling.") | |||||
| parser.add_argument("--top_p", type=str, default="1.0", | |||||
| help="parameter for Top-P sampling.") | |||||
| parser.add_argument("--temperature", type=str, default="1.0", | |||||
| help="Parameter for generation, greater if generation more diverse. ") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| metric = args_opt.metric_method | |||||
| save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path | |||||
| load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path | |||||
| load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
raise ValueError("'train_data_file_path' must be set when running the finetune task")
if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
raise ValueError("'eval_data_file_path' must be set when running the evaluation task")
| device_target = args_opt.device_target | |||||
| if device_target == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target=device_target, | |||||
| device_id=args_opt.device_id, | |||||
| max_call_depth=3000) | |||||
| context.set_auto_parallel_context(parallel_mode="stand_alone") | |||||
| print(" | Device: {} | Device id: {}".format(device_target, args_opt.device_id)) | |||||
| else: | |||||
| raise Exception("Device target error, Ascend is supported.") | |||||
| gpt2_loss = GPT2Translation(config=gpt2_net_cfg, | |||||
| is_training=True, | |||||
| use_one_hot_embeddings=False) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| print("============== Start Loading Translation Train Dataset ==============") | |||||
| print(" | Train Dataset: {}".format(args_opt.train_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_pretrain_ckpt_path)) | |||||
| train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.train_data_file_path) | |||||
| do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| print("============ Start Loading Translation Evaluation Dataset ============") | |||||
| print(" | Eval Dataset: {}".format(args_opt.eval_data_file_path)) | |||||
| print(" | Checkpoint: {}".format(args_opt.load_finetune_ckpt_path)) | |||||
| eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"), | |||||
| dataset_path=args_opt.eval_data_file_path) | |||||
| do_eval(eval_dataset, GPT2TranslationModel, metric, load_finetune_ckpt_path, args_opt.eval_type, | |||||
| args_opt.tokenizer_file_path, args_opt.generate_length, args_opt.top_k, args_opt.top_p, | |||||
| args_opt.temperature) | |||||
| if __name__ == "__main__": | |||||
| print("Start Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| run_translation() | |||||
| print("End Time: \n", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) | |||||
| @@ -0,0 +1,60 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_cbt.sh" | |||||
| echo "for example: bash scripts/run_cbt.sh" | |||||
| echo "metric method: Accuracy" | |||||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_cbt.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_CBT_task.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=4 \ | |||||
| --num_choice=10 \ | |||||
| --metric_method="Accuracy" \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --eval_type="zero-shot" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,68 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_lambada.sh" | |||||
| echo "for example: bash scripts/run_lambada.sh" | |||||
| echo "method metric include: [Accuracy, PPL]" | |||||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_lambada.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| # tokenizer path | |||||
| tokenizer_file_path="" | |||||
| # stopword path | |||||
| stop_word_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_lambada.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=1 \ | |||||
| --metric_method="PPL" \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --eval_type="zero-shot" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --generate_length_dynamically="true" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path \ | |||||
| --tokenizer_file_path=$tokenizer_file_path \ | |||||
| --stop_word_file_path=$stop_word_file_path >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,59 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_language_model.sh" | |||||
| echo "for example: bash scripts/run_language_model.sh" | |||||
| echo "metric method: PPL" | |||||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_language_model.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_language_model.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=4 \ | |||||
| --metric_method="PPL" \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --eval_type="zero-shot" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,67 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_read_comprehension.sh" | |||||
| echo "for example: bash scripts/run_read_comprehension.sh" | |||||
| echo "metric method: F1" | |||||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_read_comprehension.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| # tokenizer path | |||||
| tokenizer_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_ReadComprehension.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=7 \ | |||||
| --metric_method="F1" \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --eval_type="zero-shot" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path \ | |||||
| --tokenizer_file_path=$tokenizer_file_path \ | |||||
| --generate_length=55 \ | |||||
| --top_k=1 \ | |||||
| --top_p="1.0" \ | |||||
| --temperature="1.0" >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,66 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_summarization.sh" | |||||
| echo "for example: bash scripts/run_summarization.sh" | |||||
| echo "eval_load_param_mode include: [zero-shot, finetuned]. Default: finetuned" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_summarization.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| # tokenizer path | |||||
| tokenizer_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_summarization.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=0 \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --metric_method="Rouge" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --top_k=2 \ | |||||
| --top_p="1.0" \ | |||||
| --generate_length=100 \ | |||||
| --temperature="1.0" \ | |||||
| --eval_type="finetuned" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path \ | |||||
| --tokenizer_file_path=$tokenizer_file_path >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,67 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the script as: " | |||||
| echo "bash scripts/run_translation.sh" | |||||
| echo "for example: bash scripts/run_translation.sh" | |||||
| echo "metric method: BLEU" | |||||
| echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot" | |||||
| echo "==============================================================================================================" | |||||
| CUR_DIR=`pwd` | |||||
| mkdir -p ms_log | |||||
| output_log="${CUR_DIR}/ms_log/gpt2_translation.log" | |||||
| # create file and head line | |||||
| echo " | Eval log file: " > $output_log | |||||
| echo $output_log >> $output_log | |||||
| # checkpoint path | |||||
| save_finetune_ckpt_path="" | |||||
| load_pretrain_ckpt_path="" | |||||
| load_eval_ckpt_path="" | |||||
| # dataset path | |||||
| train_data_file_path="" | |||||
| eval_data_file_path="" | |||||
| # tokenizer path | |||||
| tokenizer_file_path="" | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_translation.py \ | |||||
| --device_target="Ascend" \ | |||||
| --device_id=4 \ | |||||
| --metric_method="BLEU" \ | |||||
| --do_train="false" \ | |||||
| --do_eval="true" \ | |||||
| --eval_type="zero-shot" \ | |||||
| --epoch_num=1 \ | |||||
| --train_data_shuffle="true" \ | |||||
| --eval_data_shuffle="false" \ | |||||
| --save_finetune_ckpt_path=$save_finetune_ckpt_path \ | |||||
| --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \ | |||||
| --load_finetune_ckpt_path=$load_eval_ckpt_path \ | |||||
| --train_data_file_path=$train_data_file_path \ | |||||
| --eval_data_file_path=$eval_data_file_path \ | |||||
| --tokenizer_file_path=$tokenizer_file_path \ | |||||
| --generate_length=100 \ | |||||
| --top_k=1 \ | |||||
| --top_p="1.0" \ | |||||
| --temperature="1.0" >> $output_log 2>&1 & | |||||
| @@ -0,0 +1,84 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (CBT) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2CBTModel(nn.Cell): | |||||
| """ | |||||
| GPT2CBTModel is responsible for Children's Book Test (CBT) task, i.e. CBT-CN, CBT-NE datasets. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| """ | |||||
| Args: | |||||
| config: the configuration of GPT-2 model | |||||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||||
| use_one_hot_embeddings (bool): default False. | |||||
| """ | |||||
| super(GPT2CBTModel, self).__init__() | |||||
| if not is_training: | |||||
| config.summary_first_dropout = 0.0 | |||||
| self.is_training = is_training | |||||
| self.d_model = config.d_model | |||||
| self.batch_size = config.batch_size | |||||
| self.seq_length = config.seq_length | |||||
| self.vocab_size = config.vocab_size | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| self.reshape = P.Reshape() | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.lm_head = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=TruncatedNormal(config.initializer_range), | |||||
| has_bias=False).to_float(config.compute_type) | |||||
| self.first_dropout = nn.Dropout(1 - config.summary_first_dropout) | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
input_ids (Tensor): shape with [batch_size, seq_len].
input_mask (Tensor): shape with [batch_size, seq_len], where 0 indicates padding position.
| Returns: | |||||
| lm_logits (Tensor): language model distribution with log_softmax, | |||||
| shape with [batch_size, seq_len, vocab_size] | |||||
| """ | |||||
| output, _ = self.gpt2(input_ids, input_mask) # output shape is [batch_size, seq_len, d_model] | |||||
| output = self.cast(output, self.dtype) | |||||
| output = self.reshape(output, (-1, self.d_model)) | |||||
| output = self.first_dropout(output) | |||||
| lm_logits = self.lm_head(output) # [batch_size * seq_len, vocab_size] | |||||
| lm_logits = self.reshape(lm_logits, (self.batch_size, self.seq_length, self.vocab_size)) | |||||
| lm_logits = self.cast(lm_logits, self.dtype) | |||||
| lm_logits = self.log_softmax(lm_logits) | |||||
| return lm_logits | |||||
| def get_lm_head(self): | |||||
| return self.lm_head.weight | |||||
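# Minimal usage sketch (illustrative only; gpt2_net_cfg comes from src.finetune_eval_config):
# model = GPT2CBTModel(gpt2_net_cfg, is_training=False)
# log_probs = model(input_ids, input_mask)  # [batch_size, seq_len, vocab_size]
# CBT evaluation then scores each candidate answer by the summed log-probability of its tokens.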
| @@ -0,0 +1,70 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (LAMBADA) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2LambadaModel(nn.Cell): | |||||
| """ | |||||
| GPT2LambadaModel is responsible for Lambada task, i.e. Lambada-train, Lambada-test datasets. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| """ | |||||
| Args: | |||||
| config: the configuration of GPT-2 model | |||||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||||
| use_one_hot_embeddings (bool): default False. | |||||
| """ | |||||
| super(GPT2LambadaModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout = 0.0 | |||||
| self.vocab_size = config.vocab_size | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.dense1 = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=TruncatedNormal(config.initializer_range)).to_float(mstype.float16) | |||||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Args: | |||||
input_ids (Tensor): shape with [batch_size, seq_len].
input_mask (Tensor): shape with [batch_size, seq_len], where 0 indicates padding position.
| Returns: | |||||
| lm_logits (Tensor): language model distribution with log_softmax, | |||||
| shape with [batch_size, seq_len, vocab_size] | |||||
| """ | |||||
| output, _ = self.gpt2(input_ids, input_mask) | |||||
| output = self.cast(output, self.dtype) | |||||
| output = self.dropout(output) | |||||
| batch_size, seq_length, d_model = self.shape(output) | |||||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||||
| logits = self.dense1(output_reshape) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| logits = self.log_softmax(logits) | |||||
| lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||||
| return lm_logits | |||||
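# LAMBADA evaluation uses these log-probabilities either to check the predicted final word
# of the passage (Accuracy) or to compute perplexity over the sequence (PPL).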
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (Language Modeling) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2LanguageModel(nn.Cell): | |||||
| """ | |||||
| GPT2LanguageModel is responsible for Language Modeling task, i.e. WikiText2, WikiText103, PTB, 1BW datasets. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| """ | |||||
| Args: | |||||
| config: the configuration of GPT-2 model | |||||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||||
| use_one_hot_embeddings (bool): default False. | |||||
| """ | |||||
| super(GPT2LanguageModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout = 0.0 | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.vocab_size = config.vocab_size | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| self.dtype = config.dtype | |||||
| self.dense1 = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=TruncatedNormal(config.initializer_range), | |||||
| has_bias=False).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||||
| where 0 indicates padding position. | |||||
| Returns: | |||||
| lm_logits (Tensor): language model distribution with log_softmax, shape with [batch_size, seq_len, vocab_size]. | |||||
| """ | |||||
| output, _ = self.gpt2(input_ids, input_mask) | |||||
| output = self.cast(output, self.dtype) | |||||
| batch_size, seq_length, d_model = self.shape(output) | |||||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||||
| logits = self.dense1(output_reshape) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| logits = self.log_softmax(logits) | |||||
| lm_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) # [batch_size, seq_len, vocab] | |||||
| return lm_logits | |||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (Reading Comprehension) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from mindspore.ops import operations as P | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2CoQAModel(nn.Cell): | |||||
| """ | |||||
| This class is responsible for the CoQA reading comprehension task. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| super(GPT2CoQAModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout = 0.0 | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.dense1 = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=self.weight_init, | |||||
| has_bias=False).to_float(config.compute_type) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.vocab_size = config.vocab_size | |||||
| self.dtype = config.dtype | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||||
| where 0 indicates padding position. | |||||
| Returns: | |||||
| logits (Tensor): language model distribution with log_softmax, shape with [batch_size, seq_len, vocab_size]. | |||||
| """ | |||||
| decoder_output, _ = self.gpt2(input_ids, input_mask) | |||||
| decoder_output = P.Cast()(decoder_output, self.dtype) | |||||
| batch_size, seq_length, d_model = P.Shape()(decoder_output) | |||||
| reshaped_output = P.Reshape()(decoder_output, (-1, d_model)) # [batch_size * seq_length, d_model] | |||||
| logits = self.dense1(reshaped_output) | |||||
| logits = P.Cast()(logits, self.dtype) | |||||
| logits = self.log_softmax(logits) | |||||
| logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||||
| return logits | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (Summarization) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2SummarizationModel(nn.Cell): | |||||
| """ | |||||
| GPT2SummarizationModel is responsible for the summarization task, i.e. the cnn_dailymail dataset. | |||||
| Args: | |||||
| config: the configuration of GPT-2 model | |||||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||||
| use_one_hot_embeddings (bool): default False. | |||||
| """ | |||||
| def __init__(self, config, is_training=True, use_one_hot_embeddings=False): | |||||
| super(GPT2SummarizationModel, self).__init__() | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.lm_head = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=TruncatedNormal(config.initializer_range), | |||||
| has_bias=False).to_float(mstype.float16) | |||||
| self.reshape = P.Reshape() | |||||
| self.dtype = config.dtype | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||||
| where 0 indicates padding position. | |||||
| Returns: | |||||
| lm_logits (Tensor): language model distribution without log_softmax, | |||||
| shape with [batch_size, seq_len, vocab_size]. | |||||
| """ | |||||
| output, _ = self.gpt2(input_ids, input_mask) | |||||
| output = self.cast(output, self.dtype) | |||||
| batch_size, seq_length, d_model = self.shape(output) | |||||
| hidden_state = self.reshape(output, (-1, d_model)) | |||||
| hidden_state = self.cast(hidden_state, self.dtype) | |||||
| lm_logits = self.lm_head(hidden_state) | |||||
| lm_logits = self.cast(lm_logits, self.dtype) | |||||
| lm_logits = self.reshape(lm_logits, (batch_size, seq_length, -1)) | |||||
| return lm_logits | |||||
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 downstream task (Translation) model script. | |||||
| """ | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from .GPT2_model import GPT2Model | |||||
| class GPT2TranslationModel(nn.Cell): | |||||
| """ | |||||
| GPT2TranslationModel is responsible for translation task, i.e. WMT-14 En-Fr, WMT-14 Fr-En datasets. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| """ | |||||
| Args: | |||||
| config: the configuration of GPT-2 model | |||||
| is_training (bool): `True` for train (finetune), `False` for evaluation. | |||||
| use_one_hot_embeddings (bool): default False. | |||||
| """ | |||||
| super(GPT2TranslationModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout = 0.0 | |||||
| self.gpt2 = GPT2Model(config, is_training, use_one_hot_embeddings) | |||||
| self.vocab_size = config.vocab_size | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| self.dtype = config.dtype | |||||
| self.dense1 = nn.Dense(config.d_model, | |||||
| config.vocab_size, | |||||
| weight_init=TruncatedNormal(config.initializer_range), | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - config.hidden_dropout) | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], where 0 indicates padding position. | |||||
| Returns: | |||||
| translation_logits (Tensor): language model distribution without log_softmax, | |||||
| shape with [batch_size, seq_len, vocab_size] | |||||
| """ | |||||
| output, _ = self.gpt2(input_ids, input_mask) | |||||
| output = self.cast(output, self.dtype) | |||||
| output = self.dropout(output) | |||||
| batch_size, seq_length, d_model = self.shape(output) | |||||
| output_reshape = P.Reshape()(output, (-1, d_model)) # [batch_size * seq_len, d_model] | |||||
| logits = self.dense1(output_reshape) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| translation_logits = P.Reshape()(logits, (batch_size, seq_length, self.vocab_size)) | |||||
| return translation_logits | |||||
| @@ -0,0 +1,366 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Generation classes for the downstream tasks (Summarization, LAMBADA, Translation, Reading Comprehension). | |||||
| """ | |||||
| import numpy as np | |||||
| from .utils.task_utils import extract_logits | |||||
| from .utils.generation_utils import Sample | |||||
| from .utils.tensor_manipulations import extract_string_from_tensor | |||||
| INF = 1. * 1e9 | |||||
| class GenerateForSummarization(): | |||||
| """ | |||||
| generate for summarization task | |||||
| """ | |||||
| def __init__(self, | |||||
| decoder, | |||||
| config=None, | |||||
| tokenizer=None, | |||||
| select_sentence=3, | |||||
| eval_type="finetuned", | |||||
| temperature=1.0, | |||||
| generate_length=100, | |||||
| topk=2, | |||||
| topp=1.0): | |||||
| self.decoder = decoder | |||||
| self.config = config | |||||
| self.tokenizer = tokenizer | |||||
| self.select_sentence = select_sentence | |||||
| self.eval_type = eval_type | |||||
| self.generator = Sample(decoder, | |||||
| tokenizer=tokenizer, | |||||
| config=config, | |||||
| topk_num=topk, | |||||
| topp_prob=topp, | |||||
| min_tokens_to_keep=1, | |||||
| demo_mode=False, | |||||
| temperature=temperature) | |||||
| self.generate_length = generate_length | |||||
| def generate_for_summarization(self, input_ids): | |||||
| """generation function for summarization task""" | |||||
| # prepare input_str | |||||
| article_str, summary_str = extract_string_from_tensor(input_ids=input_ids, | |||||
| mode="pair", | |||||
| config=self.config, | |||||
| tokenizer=self.tokenizer) | |||||
| generated_summary_list = [""] * self.config.batch_size | |||||
| # clip overflow | |||||
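| # The article is truncated at its last sentence boundary so that the "TL;DR:" prompt below | |||||
| # is appended after a complete sentence; the +2 offset keeps the matched period/space pair. | |||||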
| for batch_idx in range(self.config.batch_size): | |||||
| last_dot_pos = max(article_str[batch_idx].rfind(' .'), article_str[batch_idx].rfind('. ')) + 2 | |||||
| article_str[batch_idx] = article_str[batch_idx][:last_dot_pos] | |||||
| # append the "TL;DR:" prompt after the article string. | |||||
| tldr_str = "TL;DR:" | |||||
| if self.eval_type == "finetuned": | |||||
| for batch_idx in range(self.config.batch_size): | |||||
| article_str[batch_idx] += (" " + tldr_str) | |||||
| generate_str_list, _ = self.generator.generate(input_str=article_str, generate_length=self.generate_length) | |||||
| for batch_idx in range(self.config.batch_size): | |||||
| generate_str = generate_str_list[batch_idx] | |||||
| generated_summary = "" | |||||
| if self.select_sentence > 0: | |||||
| # take the first `select_sentence` sentences of the generated text; | |||||
| # if there are not enough sentences, fall back to the full generated string. | |||||
| len_generate_str = len(generate_str) | |||||
| search_index = -1 | |||||
| for _ in range(self.select_sentence): | |||||
| search_index = generate_str.find('.', search_index + 1) | |||||
| if search_index == -1 or search_index >= len_generate_str: | |||||
| search_index = len_generate_str | |||||
| break | |||||
| # increase search_index to add period token('.') if search_index does not overflow. | |||||
| search_index = search_index + 1 if search_index < len_generate_str else len_generate_str | |||||
| generated_summary = generate_str[:search_index] | |||||
| if generated_summary.find(self.tokenizer.eos_token) != -1: | |||||
| cut_pos = generated_summary.find(self.tokenizer.eos_token, 0) | |||||
| generated_summary = generated_summary[:cut_pos] | |||||
| else: | |||||
| generated_summary = generate_str | |||||
| # if the whole string has been clipped, restore it to the full generated string. | |||||
| if generated_summary == '': | |||||
| generated_summary = generate_str | |||||
| # empty str check | |||||
| if generated_summary == '': | |||||
| generated_summary = '<empty>' | |||||
| generated_summary_list[batch_idx] = generated_summary | |||||
| return generated_summary_list, summary_str # Hypo and Ref | |||||
| class GenerateForLambada(): | |||||
| """ | |||||
| generate class for the LAMBADA task, which predicts the final word of a sentence. | |||||
| """ | |||||
| def __init__(self, | |||||
| decoder, | |||||
| config=None, | |||||
| tokenizer=None, | |||||
| generate_length_dynamic=True, | |||||
| generate_length=1, | |||||
| max_iterations=200, | |||||
| stop_word_file=""): | |||||
| """ | |||||
| Args: | |||||
| decoder (Model): GPT2 model used for generation. | |||||
| config (object): configuration of the given GPT2 model. | |||||
| tokenizer (object): required when the input_str parameter of self.generate() is used. | |||||
| generate_length_dynamic (bool): if True, the generation length adapts to the token count of the target word; if False, it is fixed. Default: True. | |||||
| generate_length (int): maximum number of tokens generated for the final word. | |||||
| max_iterations (int): number of top-k candidate tokens (ranked by probability) to try, i.e. k = `max_iterations`. | |||||
| stop_word_file (str): path to a stop-word file used to filter candidate words. | |||||
| """ | |||||
| self.decoder = decoder | |||||
| self.config = config | |||||
| self.batch_size = config.batch_size | |||||
| self.tokenizer = tokenizer | |||||
| self.generate_length_dynamic = generate_length_dynamic | |||||
| self.generate_length = generate_length | |||||
| self.max_iterations = max_iterations | |||||
| self.stop_word_set = self.build_stop_word(stop_word_file) | |||||
| self.generator = Sample(decoder=decoder, | |||||
| config=config, | |||||
| batch_size=1, | |||||
| tokenizer=tokenizer, | |||||
| topk_num=1, | |||||
| topp_prob=1, | |||||
| return_ids=True | |||||
| ) | |||||
| self.stop_eos = ['.', ',', '!', '?', '"', " '", " and", " says", " said"] | |||||
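| # Heuristic end-of-word delimiters: a candidate's generated continuation is cut at the | |||||
| # first of these, since LAMBADA only requires a single final word. | |||||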
| def build_stop_word(self, stop_word_file): | |||||
| stop_words_set = set() | |||||
| with open(stop_word_file, 'r', encoding="utf8") as file: | |||||
| for line in file.readlines(): | |||||
| line = line.strip('\n') | |||||
| stop_words_set.add(line) | |||||
| return stop_words_set | |||||
| def is_stop_word(self, word): | |||||
| return word in self.stop_word_set | |||||
| def generate_for_lambada(self, input_ids, logits, input_length): | |||||
| """ | |||||
| generation function for lambada task | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| logits (Tensor): the language model distribution. | |||||
| input_length (Tensor): shape [batch_size, 2]; stores the context length (excluding the final word) and the whole sentence length. | |||||
| Returns: | |||||
| batch_predict_words (list): the list of predicted words. | |||||
| """ | |||||
| batch_predict_words = ["" for _ in range(self.batch_size)] | |||||
| input_len_np = input_length.asnumpy() | |||||
| input_ids_list = input_ids.asnumpy().tolist() | |||||
| extracted_logits = extract_logits(logits=logits, position=input_len_np) # [batch_size, vocab_size] | |||||
| extracted_logits = extracted_logits.asnumpy() | |||||
| sorted_ids = np.argsort(-extracted_logits, axis=-1)[::, :self.max_iterations] # [batch_size, max_iterations] | |||||
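| # argsort of the negated logits ranks token ids by descending probability, so the slice | |||||
| # above keeps the top `max_iterations` candidate ids for each sample. | |||||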
| for batch_idx in range(self.batch_size): | |||||
| final_word_spos = input_len_np[batch_idx, 0] | |||||
| context_ids = input_ids_list[batch_idx][1:final_word_spos] # 1 for dropping <bos> token | |||||
| last_word_token_num = input_len_np[batch_idx, 1] - input_len_np[batch_idx, 0] | |||||
| if self.generate_length_dynamic: | |||||
| generate_length = last_word_token_num | |||||
| else: | |||||
| generate_length = self.generate_length | |||||
| for num in range(self.max_iterations): | |||||
| id_ = sorted_ids[batch_idx][num] | |||||
| source_ids = context_ids + [id_] | |||||
| source_string = self.tokenizer.decode(source_ids) | |||||
| generated_ids_list = self.generator.generate(input_str=source_string, | |||||
| generate_length=generate_length, | |||||
| do_sample=False) | |||||
| predict_tokens_ids = [id_] + generated_ids_list[0] | |||||
| predict_word = self.tokenizer.decode(predict_tokens_ids) | |||||
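| # find the earliest stop delimiter in the decoded candidate; if none appears within the | |||||
| # generated tokens, the candidate is rejected and the next top-k id is tried. | |||||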
| eos_pos = min(predict_word.find(word) if predict_word.find(word) >= 0 | |||||
| else INF for word in self.stop_eos) | |||||
| if eos_pos == INF: | |||||
| continue | |||||
| else: | |||||
| predict_word = predict_word[:eos_pos] | |||||
| predict_word = predict_word.strip() | |||||
| if predict_word.find(" ") == -1: | |||||
| if self.is_stop_word(word=predict_word.lower()): | |||||
| continue | |||||
| batch_predict_words[batch_idx] = predict_word | |||||
| print("predict word: {}".format(predict_word)) | |||||
| break | |||||
| return batch_predict_words | |||||
| class GenerateForTranslation(): | |||||
| """ | |||||
| generate class for translation task | |||||
| """ | |||||
| def __init__(self, | |||||
| decoder, | |||||
| config=None, | |||||
| tokenizer=None, | |||||
| generate_length=1, | |||||
| use_hint=True, | |||||
| select_first_sentence=True, | |||||
| topk_num=None, | |||||
| topp_prob=None, | |||||
| temperature=None | |||||
| ): | |||||
| self.decoder = decoder | |||||
| self.config = config | |||||
| self.batch_size = config.batch_size | |||||
| self.tokenizer = tokenizer | |||||
| self.generate_length = generate_length | |||||
| self.use_hint = use_hint | |||||
| self.select_first_sentence = select_first_sentence | |||||
| self.generator = Sample(decoder=decoder, | |||||
| config=config, | |||||
| tokenizer=tokenizer, | |||||
| topk_num=topk_num, | |||||
| topp_prob=topp_prob, | |||||
| temperature=temperature, | |||||
| min_tokens_to_keep=1, | |||||
| early_stop=False) | |||||
| def generate_for_translation(self, input_ids): | |||||
| """generation function for translation task""" | |||||
| source_str_list, ref_str_list = extract_string_from_tensor(input_ids=input_ids, | |||||
| mode="pair", | |||||
| config=self.config, | |||||
| tokenizer=self.tokenizer) | |||||
| final_predict_translation_list = [""] * self.batch_size | |||||
| if self.use_hint: | |||||
| for index in range(self.batch_size): | |||||
| source_str_list[index] += " =" # now source_str is "english sentence =" | |||||
| translation_str_list, _ = self.generator.generate(input_str=source_str_list, | |||||
| generate_length=self.generate_length, | |||||
| do_sample=False) | |||||
| for index in range(self.batch_size): | |||||
| generate_str = translation_str_list[index].replace('<|endoftext|>', '') | |||||
| predict_translation = "" | |||||
| # According to the GPT-2 paper, select_first_sentence is set to True: only the first generated sentence is kept as the translation. | |||||
| if self.select_first_sentence: | |||||
| # keep the generated text up to its first sentence terminator; | |||||
| # if none is found, fall back to the full generated string. | |||||
| search_index = generate_str.find('.', 0, len(generate_str)) | |||||
| if search_index == -1: | |||||
| search_index = len(generate_str) | |||||
| else: | |||||
| search_index = search_index + 1 | |||||
| predict_translation = generate_str[:search_index] | |||||
| else: | |||||
| predict_translation = generate_str | |||||
| if predict_translation == '': | |||||
| predict_translation = '<empty>' | |||||
| final_predict_translation_list[index] = predict_translation | |||||
| return final_predict_translation_list, ref_str_list | |||||
| class GenerateForReadComprehension(): | |||||
| """ | |||||
| generate class for Reading Comprehension task. | |||||
| Args: | |||||
| decoder (Model): GPT2 model used for generation. | |||||
| config (object): configuration of the given GPT2 model. | |||||
| tokenizer (object): required when the input_str parameter of self.generate() is used. | |||||
| generate_length (int): length of the generated answer, in tokens. | |||||
| """ | |||||
| def __init__(self, | |||||
| decoder, | |||||
| config=None, | |||||
| tokenizer=None, | |||||
| generate_length=1, | |||||
| topk_num=None, | |||||
| topp_prob=None, | |||||
| temperature=None | |||||
| ): | |||||
| self.decoder = decoder | |||||
| self.config = config | |||||
| self.batch_size = config.batch_size | |||||
| self.tokenizer = tokenizer | |||||
| self.generate_length = generate_length | |||||
| self.generator = Sample(decoder=decoder, | |||||
| config=config, | |||||
| tokenizer=tokenizer, | |||||
| topk_num=topk_num, | |||||
| topp_prob=topp_prob, | |||||
| temperature=temperature, | |||||
| min_tokens_to_keep=1, | |||||
| ) | |||||
| def generate_for_read_comprehension(self, input_ids): | |||||
| """generation function for reading comprehension task""" | |||||
| passage_str_list, answer_str_list = extract_string_from_tensor(input_ids=input_ids, | |||||
| mode="pair", | |||||
| config=self.config, | |||||
| tokenizer=self.tokenizer) | |||||
| passage = passage_str_list[:] | |||||
| generate_str_list, _ = self.generator.generate(input_str=passage_str_list, | |||||
| generate_length=self.generate_length, | |||||
| do_sample=False) | |||||
| pred_answer = [] | |||||
| for batch_id in range(self.batch_size): | |||||
| new_str = generate_str_list[batch_id].replace('<|endoftext|>', '') | |||||
| index_a = new_str.find('.') | |||||
| index_b = new_str.find('Q:') | |||||
| if index_a != -1 or index_b != -1: | |||||
| index = max(index_a, index_b) | |||||
| pred_answer += [new_str[1:index]] # start from index 1 to skip the leading space of the sentence | |||||
| else: | |||||
| pred_answer += [new_str] | |||||
| return passage, pred_answer, answer_str_list | |||||
| @@ -0,0 +1,896 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| GPT-2 base model | |||||
| """ | |||||
| import math | |||||
| import copy | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.functional as F | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common.parameter import Parameter | |||||
| from .weight_init import normal_weight, zero_weight | |||||
| class GPT2Config: | |||||
| """ | |||||
| Configuration for `GPT2Model`. | |||||
| Args: | |||||
| batch_size (int): Batch size of input dataset. Default: 512. | |||||
| seq_length (int): Length of input sequence. Default: 1024. | |||||
| vocab_size (int): Size of the dictionary of embeddings. Default: 50257. | |||||
| d_model (int): Size of the GPT-2 decoder layers. Default: 768. | |||||
| num_hidden_layers (int): Number of hidden layers in the GPT2Transformer decoder block. Default: 12. | |||||
| num_attention_heads (int): Number of attention heads in the GPT2Transformer decoder block. Default: 12. | |||||
| intermediate_size (int): Size of intermediate layer in the GPT2Transformer decoder block. Default: 3072. | |||||
| hidden_act (str): Activation function used in the GPT2Transformer decoder block. Default: "gelu". | |||||
| hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1. | |||||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1. | |||||
| max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024. | |||||
| initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. | |||||
| input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from dataset. | |||||
| Default: True. | |||||
| summary_first_dropout (float): The dropout probability for GPT2CBTModel. Default: 0.1. | |||||
| dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. | |||||
| compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16. | |||||
| """ | |||||
| def __init__(self, | |||||
| batch_size=512, | |||||
| seq_length=1024, | |||||
| vocab_size=50257, | |||||
| d_model=768, | |||||
| num_hidden_layers=12, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout=0.1, | |||||
| attention_dropout=0.1, | |||||
| max_position_embeddings=1024, | |||||
| initializer_range=0.02, | |||||
| input_mask_from_dataset=True, | |||||
| summary_first_dropout=0.1, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| ): | |||||
| self.batch_size = batch_size | |||||
| self.seq_length = seq_length | |||||
| self.vocab_size = vocab_size | |||||
| self.d_model = d_model | |||||
| self.num_hidden_layers = num_hidden_layers | |||||
| self.num_attention_heads = num_attention_heads | |||||
| self.intermediate_size = intermediate_size | |||||
| self.hidden_act = hidden_act | |||||
| self.hidden_dropout = hidden_dropout | |||||
| self.attention_dropout = attention_dropout | |||||
| self.max_position_embeddings = max_position_embeddings | |||||
| self.initializer_range = initializer_range | |||||
| self.input_mask_from_dataset = input_mask_from_dataset | |||||
| self.summary_first_dropout = summary_first_dropout | |||||
| self.dtype = dtype | |||||
| self.compute_type = compute_type | |||||
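| # A minimal usage sketch (values here are illustrative, not taken from a released config): | |||||
| # the defaults above correspond to the GPT-2 small (117M) geometry, e.g. | |||||
| # config = GPT2Config(batch_size=1, seq_length=1024, d_model=768, num_hidden_layers=12) | |||||
| # model = GPT2Model(config, is_training=False) # GPT2Model is defined later in this file | |||||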
| class EmbeddingLookup(nn.Cell): | |||||
| """ | |||||
| An embedding lookup table with a fixed dictionary and size. | |||||
| Args: | |||||
| vocab_size (int): Size of the dictionary of embeddings. | |||||
| embedding_dim (int): The size of each embedding vector. | |||||
| use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| vocab_size, | |||||
| embedding_dim, | |||||
| use_one_hot_embeddings=False, | |||||
| compute_type=mstype.float16): | |||||
| super(EmbeddingLookup, self).__init__() | |||||
| self.vocab_size = vocab_size | |||||
| self.embedding_dim = embedding_dim | |||||
| self.use_one_hot_embeddings = use_one_hot_embeddings | |||||
| self.compute_type = compute_type | |||||
| self.embedding_table = Parameter(normal_weight([vocab_size, embedding_dim], embedding_dim), | |||||
| name='embedding_table') | |||||
| self.expand = P.ExpandDims() | |||||
| self.shape_flat = (-1,) | |||||
| self.gather = P.GatherV2() | |||||
| self.one_hot = P.OneHot() | |||||
| self.on_value = Tensor(1.0, mstype.float32) | |||||
| self.off_value = Tensor(0.0, mstype.float32) | |||||
| self.array_mul = P.MatMul() | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, input_ids): | |||||
| """ | |||||
| get embedding according to input_ids. | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| Returns: | |||||
| output (Tensor): the embedding matrix according to the input_ids. | |||||
| self.embedding_table (Parameter): the whole embedding table of GPT-2 model. | |||||
| """ | |||||
| input_shape = self.shape(input_ids) # [batch_size, seq_length] | |||||
| flat_ids = self.reshape(input_ids, self.shape_flat) # [batch_size * seq_length] | |||||
| if self.use_one_hot_embeddings: | |||||
| one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) | |||||
| # precision transition fp32 -> fp16 | |||||
| one_hot_ids = self.cast(one_hot_ids, self.compute_type) | |||||
| self.embedding_table = self.cast(self.embedding_table, self.compute_type) | |||||
| output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table) | |||||
| output_for_reshape = self.cast(output_for_reshape, mstype.float32) | |||||
| else: | |||||
| # [batch_size * seq_length * embedding_dim] | |||||
| output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) | |||||
| out_shape = input_shape + (self.embedding_dim,) | |||||
| output = self.reshape(output_for_reshape, out_shape) # [batch_size, seq_length, embedding_dim] | |||||
| return output, self.embedding_table | |||||
| class EmbeddingPostprocessor(nn.Cell): | |||||
| """ | |||||
| Postprocessor that applies positional embeddings to word embeddings. | |||||
| Args: | |||||
| embedding_dim (int): The size of each embedding vector. | |||||
| seq_length (int): the length of input sequence. | |||||
| max_position_embeddings (int): Maximum length of sequences used in this model. Default: 1024. | |||||
| dropout_prob (float): The dropout probability. Default: 0.1. | |||||
| """ | |||||
| def __init__(self, | |||||
| embedding_dim=None, | |||||
| seq_length=None, | |||||
| max_position_embeddings=1024, | |||||
| dropout_prob=0.1): | |||||
| super(EmbeddingPostprocessor, self).__init__() | |||||
| self.position_embedding_table = Parameter( | |||||
| normal_weight([max_position_embeddings, embedding_dim], embedding_dim), name='position_embeddings') | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.add = P.TensorAdd() | |||||
| self.gather = P.GatherV2() | |||||
| self.input_indices = Tensor(np.arange(seq_length), mindspore.int32) | |||||
| self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32) | |||||
| self.use_dropout = dropout_prob > 0 | |||||
| def construct(self, word_embeddings): | |||||
| """ | |||||
| Add the position embedding table to token embedding table | |||||
| Args: | |||||
| word_embeddings (Tensor): the token embedding matrix | |||||
| Returns: | |||||
| output (Tensor): the final embedding matrix by adding the position embedding table | |||||
| to token embedding table. | |||||
| """ | |||||
| position_embeddings = self.gather(self.position_embedding_table, self.input_indices, 0) | |||||
| position_embeddings = self.expand_dims(position_embeddings, 0) | |||||
| output = self.add(word_embeddings, position_embeddings) | |||||
| if self.use_dropout: | |||||
| output = self.dropout(output) | |||||
| return output | |||||
| class CastWrapper(nn.Cell): | |||||
| """ | |||||
| Cast wrapper | |||||
| """ | |||||
| def __init__(self, | |||||
| dst_type=mstype.float32): | |||||
| super(CastWrapper, self).__init__() | |||||
| self.cast = P.Cast() | |||||
| self.dst_type = dst_type | |||||
| def construct(self, x): | |||||
| """ | |||||
| type cast | |||||
| Args: | |||||
| x (Tensor): the input which need to be cast. | |||||
| Returns: | |||||
| Tensor, the cast output. | |||||
| """ | |||||
| return self.cast(x, self.dst_type) | |||||
| class LayerNorm(nn.Cell): | |||||
| """ | |||||
| Do layer norm | |||||
| Args: | |||||
| in_channels (int): In channels number of layer norm | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels=None): | |||||
| super(LayerNorm, self).__init__() | |||||
| self.layer_norm = nn.LayerNorm((in_channels,)) | |||||
| self.cast = P.Cast() | |||||
| self.get_dtype = P.DType() | |||||
| def construct(self, input_tensor): | |||||
| """ | |||||
| layer norm | |||||
| Args: | |||||
| input_tensor (Tensor): the input of layernorm. | |||||
| Returns: | |||||
| Tensor, the output after layernorm. | |||||
| """ | |||||
| output = self.cast(input_tensor, mstype.float32) | |||||
| output = self.layer_norm(output) | |||||
| output = self.cast(output, self.get_dtype(input_tensor)) | |||||
| return output | |||||
| class ResidualConnection(nn.Cell): | |||||
| """ | |||||
| Add residual to output. | |||||
| Args: | |||||
| dropout_prob (float): Dropout rate. | |||||
| """ | |||||
| def __init__(self, dropout_prob=0.0): | |||||
| super(ResidualConnection, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| self.use_dropout = dropout_prob > 0 | |||||
| def construct(self, hidden_tensor, input_tensor): | |||||
| """ | |||||
| Args: | |||||
| hidden_tensor (Tensor): the output of sublayer. | |||||
| input_tensor (Tensor): the input tensor. | |||||
| Returns: | |||||
| output (Tensor): with the same shape of hidden_tensor. | |||||
| """ | |||||
| output = hidden_tensor | |||||
| if self.use_dropout: | |||||
| output = self.dropout(output) | |||||
| output = self.add(output, input_tensor) | |||||
| return output | |||||
| class Conv1D(nn.Cell): | |||||
| """ | |||||
| 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). | |||||
| Basically works like a linear layer but the weights are transposed. | |||||
| Args: | |||||
| nx (int): The number of input features. | |||||
| nf (int): The number of output features. | |||||
| """ | |||||
| def __init__(self, | |||||
| nx, | |||||
| nf): | |||||
| super(Conv1D, self).__init__() | |||||
| self.nx = nx | |||||
| self.nf = nf | |||||
| self.weight = Parameter(normal_weight([nx, nf], nf), name='projection_weight') | |||||
| self.bias = Parameter(zero_weight(nf), name='projection_bias') | |||||
| self.matmul = P.MatMul() | |||||
| self.bias_add = P.BiasAdd() | |||||
| self.cast = P.Cast() | |||||
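| # Unlike nn.Dense, which stores its weight as [out_channels, in_channels] and transposes | |||||
| # it in the matmul, Conv1D stores the weight as [nx, nf] and computes y = x @ W + b directly. | |||||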
| def construct(self, input_tensor): | |||||
| """ | |||||
| Args: | |||||
| input_tensor (Tensor): the input tensor of Conv1D with shape [batch_size * seq_length, nx] | |||||
| Returns: | |||||
| output_tensor (Tensor): the output tensor with shape [batch_size * seq_length, self.nf] | |||||
| """ | |||||
| # precision transition fp32 -> fp16 | |||||
| input_tensor = self.cast(input_tensor, mstype.float16) | |||||
| fp16_weight = self.cast(self.weight, mstype.float16) | |||||
| output_tensor = self.matmul(input_tensor, fp16_weight) # [batch_size * seq_length, self.nf] | |||||
| output_tensor = self.cast(output_tensor, mstype.float32) | |||||
| output_tensor = self.bias_add(output_tensor, self.bias) | |||||
| return output_tensor | |||||
| class MaskedSelfAttention(nn.Cell): | |||||
| """ | |||||
| Apply masked multi-head attention. | |||||
| Args: | |||||
| batch_size (int): Batch size of input datasets. Default: 512. | |||||
| d_model (int): Size of last dim of input tensor. Default: 768. | |||||
| seq_length (int): Length of input tensor sequence. Default: 1024. | |||||
| num_attention_heads (int): Number of attention heads. Default: 12. | |||||
| dim_per_head (int): Size of each attention head. Default: 64. | |||||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||||
| attention_dropout (float): The dropout probability for MultiheadAttention. Default: 0.0. | |||||
| compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float16. | |||||
| Returns: | |||||
| Tensor, with shape [batch_size * seq_length, d_model] when `do_return_2d_tensor` is True (default), otherwise [batch_size, seq_length, d_model]. | |||||
| """ | |||||
| def __init__(self, | |||||
| batch_size=512, | |||||
| d_model=768, | |||||
| seq_length=1024, | |||||
| num_attention_heads=12, | |||||
| dim_per_head=64, | |||||
| has_attention_mask=True, | |||||
| do_return_2d_tensor=True, | |||||
| attention_dropout=0.0, | |||||
| compute_type=mstype.float16): | |||||
| super(MaskedSelfAttention, self).__init__() | |||||
| self.batch_size = batch_size | |||||
| self.d_model = d_model | |||||
| self.seq_length = seq_length | |||||
| self.num_heads = num_attention_heads | |||||
| self.dim_per_head = dim_per_head | |||||
| self.has_attention_mask = has_attention_mask | |||||
| self.compute_type = compute_type | |||||
| assert has_attention_mask | |||||
| self.scale = Tensor([1.0 / math.sqrt(float(self.dim_per_head))], dtype=compute_type) # attention scale | |||||
| self.mask_data = Tensor([-10000.0,], dtype=compute_type) | |||||
| self.split_head_shape = (-1, self.seq_length, self.num_heads, self.dim_per_head) | |||||
| self.c_attn = Conv1D(d_model, d_model * 3) | |||||
| self.c_proj = Conv1D(d_model, d_model) | |||||
| self.split_for_qkv = P.Split(1, 3) | |||||
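| # c_attn projects the input to 3 * d_model; Split(axis=1, output_num=3) then carves the | |||||
| # result into the query, key and value projections, each of width d_model. | |||||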
| self.reshape = P.Reshape() | |||||
| self.transpose = P.Transpose() | |||||
| self.trans_shape = (0, 2, 1, 3) | |||||
| self.matmul_trans_b = P.BatchMatMul(transpose_b=True) | |||||
| self.matmul = P.BatchMatMul() | |||||
| self.multiply = P.Mul() | |||||
| if self.has_attention_mask: | |||||
| self.expand_dims = P.ExpandDims() | |||||
| self.sub = P.Sub() | |||||
| self.add = P.TensorAdd() | |||||
| self.cast = P.Cast() | |||||
| self.get_dtype = P.DType() | |||||
| if do_return_2d_tensor: | |||||
| self.shape_return = (-1, d_model) | |||||
| else: | |||||
| self.shape_return = (-1, seq_length, d_model) | |||||
| self.softmax = nn.Softmax() | |||||
| self.softmax_cast = P.Cast() | |||||
| self.dropout = nn.Dropout(1 - attention_dropout) | |||||
| self.use_attention_dropout = attention_dropout > 0 | |||||
| def construct(self, input_tensor, attention_mask): | |||||
| """ | |||||
| do masked self-attention | |||||
| Args: | |||||
| input_tensor (Tensor): the embedding of input sequence tokens, | |||||
| shape with [batch_size * seq_length, d_model] | |||||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||||
| shape with [batch_size, seq_len, seq_len]. | |||||
| Returns: | |||||
| outputs (Tensor): the output of masked self-attention, shape with [batch_size * seq_len, d_model]. | |||||
| """ | |||||
| input_tensor = self.c_attn(input_tensor) # [batch_size * seq_length, d_model*3]---> eg.[1 * 3, 2304] | |||||
| input_tensor = self.split_for_qkv(input_tensor) | |||||
| query = input_tensor[0] # [batch_size * seq_length, d_model] ---> eg. [1 * 3, 768] | |||||
| key = input_tensor[1] | |||||
| value = input_tensor[2] | |||||
| # split head | |||||
| query = self.reshape(query, self.split_head_shape) | |||||
| # query shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||||
| query = self.transpose(query, self.trans_shape) | |||||
| key = self.reshape(key, self.split_head_shape) | |||||
| # key shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||||
| key = self.transpose(key, self.trans_shape) | |||||
| value = self.reshape(value, self.split_head_shape) | |||||
| # value shape [batch_size, num_heads, seq_len, dim_per_head] ---> eg. [1, 12, 3, 64] | |||||
| value = self.transpose(value, self.trans_shape) | |||||
| # attention and mask | |||||
| # precision transition fp32 -> fp16 | |||||
| query = self.cast(query, self.compute_type) | |||||
| key = self.cast(key, self.compute_type) | |||||
| attention_scores = self.matmul_trans_b(query, key) # [batch_size, num_heads, seq_len, seq_len] | |||||
| attention_scores = self.cast(attention_scores, self.compute_type) | |||||
| attention_scores = self.multiply(attention_scores, self.scale) | |||||
| if self.has_attention_mask: | |||||
| attention_mask = self.expand_dims(attention_mask, 1) # [batch_size, 1, seq_length, seq_length] | |||||
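| # standard additive-mask trick: (1 - mask) * -10000 is added to the raw scores, so | |||||
| # masked positions receive a large negative value and near-zero weight after softmax. | |||||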
| multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), | |||||
| self.cast(attention_mask, self.get_dtype(attention_scores))) # fp16 | |||||
| adder = self.multiply(multiply_out, self.mask_data) | |||||
| adder = self.cast(adder, mstype.float32) | |||||
| attention_scores = self.cast(attention_scores, mstype.float32) | |||||
| attention_scores = self.add(adder, attention_scores) | |||||
| attention_scores = self.softmax_cast(attention_scores, mstype.float32) | |||||
| attention_probs = self.softmax(attention_scores) # [batch_size, num_heads, seq_len, seq_len] | |||||
| attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key)) | |||||
| if self.use_attention_dropout: | |||||
| attention_probs = self.dropout(attention_probs) | |||||
| value = self.cast(value, mstype.float16) | |||||
| attention_probs = self.cast(attention_probs, self.compute_type) | |||||
| outputs = self.matmul(attention_probs, value) # [batch_size, num_heads, seq_len, dim_per_head] | |||||
| outputs = self.cast(outputs, mstype.float32) | |||||
| # merge heads | |||||
| outputs = self.transpose(outputs, self.trans_shape) # [batch_size, seq_len, num_heads, dim_per_head] | |||||
| outputs = self.reshape(outputs, | |||||
| self.shape_return) # default True, the outputs shape [batch_size * seq_len, d_model] | |||||
| # project | |||||
| outputs = self.c_proj(outputs) | |||||
| return outputs | |||||
| class FeedForward(nn.Cell): | |||||
| """ | |||||
| Apply two-layer feed forward | |||||
| Args: | |||||
| in_channels (int): Size of the input layer. Default: 768. | |||||
| out_channels (int): Size of the output layers. Default: 768. | |||||
| hidden_size (int): Size of the hidden layer. Default: 3072. | |||||
| hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1. | |||||
| """ | |||||
| def __init__(self, | |||||
| in_channels=768, | |||||
| out_channels=768, | |||||
| hidden_size=3072, | |||||
| hidden_dropout=0.1): | |||||
| super(FeedForward, self).__init__() | |||||
| self.c_fc = Conv1D(in_channels, hidden_size) | |||||
| self.c_proj = Conv1D(hidden_size, out_channels) | |||||
| self.layernorm = LayerNorm(in_channels=in_channels) | |||||
| self.residual_connect = ResidualConnection(dropout_prob=hidden_dropout) | |||||
| self.gelu_act = P.Gelu() | |||||
| self.dropout = nn.Dropout(1 - hidden_dropout) | |||||
| self.use_dropout = hidden_dropout > 0 | |||||
| self.reshape = P.Reshape() | |||||
| def construct(self, input_tensor): | |||||
| """ | |||||
| FeedForward construct function with layernorm and residual connection. | |||||
| Args: | |||||
| input_tensor (Tensor): the input of FeedForward layer, shape with [batch_size * seq_len, d_model]. | |||||
| Returns: | |||||
| output (Tensor): the output of FeedForward layer, shape with [batch_size * seq_len, d_model] | |||||
| """ | |||||
| # LayerNorm | |||||
| output = self.layernorm(input_tensor) | |||||
| # Feed Forward | |||||
| output = self.c_fc(output) # [batch_size * seq_len, d_model * 4] | |||||
| output = self.gelu_act(output) | |||||
| output = self.c_proj(output) # [batch_size * seq_len, d_model] | |||||
| if self.use_dropout: | |||||
| output = self.dropout(output) | |||||
| # Add | |||||
| output = self.residual_connect(output, input_tensor) | |||||
| return output | |||||
| class MaskedMultiHeadAttention(nn.Cell): | |||||
| """ | |||||
| Masked multi-head attention block. | |||||
| """ | |||||
| def __init__(self, | |||||
| batch_size=512, | |||||
| seq_length=1024, | |||||
| d_model=768, | |||||
| num_attention_heads=12, | |||||
| attention_dropout=0.02, | |||||
| hidden_dropout=0.1, | |||||
| has_attention_mask=True, | |||||
| compute_type=mstype.float16 | |||||
| ): | |||||
| super(MaskedMultiHeadAttention, self).__init__() | |||||
| if d_model % num_attention_heads != 0: | |||||
| raise ValueError("The hidden size (%d) is not a multiple of the number " | |||||
| "of attention heads (%d)" % (d_model, num_attention_heads)) | |||||
| self.dim_per_head = int(d_model / num_attention_heads) # 64 | |||||
| self.masked_self_attention = MaskedSelfAttention( | |||||
| batch_size=batch_size, | |||||
| d_model=d_model, | |||||
| seq_length=seq_length, | |||||
| num_attention_heads=num_attention_heads, | |||||
| dim_per_head=self.dim_per_head, | |||||
| has_attention_mask=has_attention_mask, | |||||
| do_return_2d_tensor=True, | |||||
| attention_dropout=attention_dropout, | |||||
| compute_type=compute_type | |||||
| ) | |||||
| self.layer_norm = LayerNorm(in_channels=d_model) | |||||
| self.residual_connection = ResidualConnection() | |||||
| self.reshape = P.Reshape() | |||||
| self.new_shape = (-1, d_model) | |||||
| def construct(self, input_tensor, attention_mask): | |||||
| """ | |||||
| do masked multi head self-attention with layernorm and residual_connection. | |||||
| Args: | |||||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||||
| shape with [batch_size * seq_length, d_model] | |||||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||||
| shape with [batch_size, seq_len, seq_len]. | |||||
| Returns: | |||||
| outputs (Tensor): the output of MaskedMultiHeadAttention, shape with [batch_size * seq_len, d_model]. | |||||
| """ | |||||
| # LayerNorm | |||||
| output_tensor = self.layer_norm(input_tensor) | |||||
| # masked multi-head attention | |||||
| # attention_output shape [batch_size * seq_length, d_model] | |||||
| attention_output = self.masked_self_attention(output_tensor, attention_mask) | |||||
| # residual connection | |||||
| output = self.residual_connection(attention_output, input_tensor) | |||||
| return output | |||||
| class DecoderBlock(nn.Cell): | |||||
| """ | |||||
| decoder block used in GPT2. | |||||
| Args: | |||||
| batch_size (int): Batch size of input dataset. Default: 512. | |||||
| seq_length (int): Length of input sequence. Default: 1024. | |||||
| d_model (int): Size of the GPT2 decoder layers. Default: 768. | |||||
| num_attention_heads (int): Number of attention heads. Default: 12. | |||||
| intermediate_size (int): Size of intermediate layer. Default: 3072. | |||||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.02. | |||||
| hidden_dropout (float): The dropout probability for hidden outputs. Default: 0.1. | |||||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||||
| compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float16. | |||||
| """ | |||||
| def __init__(self, | |||||
| batch_size=512, | |||||
| seq_length=1024, | |||||
| d_model=768, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| attention_dropout=0.02, | |||||
| hidden_dropout=0.1, | |||||
| has_attention_mask=True, | |||||
| compute_type=mstype.float16 | |||||
| ): | |||||
| super(DecoderBlock, self).__init__() | |||||
| if d_model % num_attention_heads != 0: | |||||
| raise ValueError("The hidden size (%d) is not a multiple of the number " | |||||
| "of attention heads (%d)" % (d_model, num_attention_heads)) | |||||
| self.dim_per_head = int(d_model / num_attention_heads) # 64 | |||||
| self.masked_multi_head_attention = MaskedMultiHeadAttention( | |||||
| batch_size=batch_size, | |||||
| seq_length=seq_length, | |||||
| d_model=d_model, | |||||
| num_attention_heads=num_attention_heads, | |||||
| attention_dropout=attention_dropout, | |||||
| hidden_dropout=hidden_dropout, | |||||
| has_attention_mask=has_attention_mask, | |||||
| compute_type=compute_type | |||||
| ) | |||||
| self.feedforward = FeedForward( | |||||
| in_channels=d_model, | |||||
| out_channels=d_model, | |||||
| hidden_size=intermediate_size, | |||||
| hidden_dropout=hidden_dropout | |||||
| ) | |||||
| self.reshape = P.Reshape() | |||||
| self.new_shape = (-1, d_model) | |||||
| def construct(self, input_tensor, attention_mask): # input tensor shape[batch_size, seq_length, d_model] | |||||
| """ | |||||
| DecoderBlock with masked_multi_head_attention and feedforward. | |||||
| Args: | |||||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||||
| shape with [batch_size * seq_length, d_model] | |||||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||||
| shape with [batch_size, seq_len, seq_len]. | |||||
| Returns: | |||||
| outputs (Tensor): the output of DecoderBlock, shape with [batch_size * seq_len, d_model]. | |||||
| """ | |||||
| input_tensor = self.reshape(input_tensor, self.new_shape) | |||||
| # masked multi head attention with ln, res | |||||
| attention_output = self.masked_multi_head_attention(input_tensor, attention_mask) | |||||
| # feed forward with ln, res | |||||
| output = self.feedforward(attention_output) | |||||
| return output | |||||
| class GPT2Transformer(nn.Cell): | |||||
| """ | |||||
| Multi-layer GPT2 transformer. | |||||
| Args: | |||||
| batch_size (int): Batch size of input dataset. Default: 512. | |||||
| d_model (int): Size of the decoder layers. Default: 768. | |||||
| seq_length (int): Length of input sequence. Default: 1024. | |||||
| num_hidden_layers (int): Number of hidden layers in decoder cells. Default: 12. | |||||
| num_attention_heads (int): Number of attention heads in decoder cells. Default: 12. | |||||
| intermediate_size (int): Size of intermediate layer in decoder cells. Default: 3072. | |||||
| has_attention_mask (bool): Specifies whether to use attention mask. Default: True. | |||||
| attention_dropout (float): The dropout probability for MaskedMultiHeadAttention. Default: 0.1. | |||||
| hidden_dropout (float): The dropout probability for GPT2Output. Default: 0.1. | |||||
| compute_type (:class:`mindspore.dtype`): Compute type in GPT2Transformer. Default: mstype.float16. | |||||
| """ | |||||
| def __init__(self, | |||||
| batch_size=512, | |||||
| d_model=768, | |||||
| seq_length=1024, | |||||
| num_hidden_layers=12, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| has_attention_mask=True, | |||||
| attention_dropout=0.1, | |||||
| hidden_dropout=0.1, | |||||
| compute_type=mstype.float16): | |||||
| super(GPT2Transformer, self).__init__() | |||||
| layers = [] | |||||
| for _ in range(num_hidden_layers): | |||||
| layer = DecoderBlock(batch_size=batch_size, | |||||
| seq_length=seq_length, | |||||
| d_model=d_model, | |||||
| num_attention_heads=num_attention_heads, | |||||
| intermediate_size=intermediate_size, | |||||
| attention_dropout=attention_dropout, | |||||
| hidden_dropout=hidden_dropout, | |||||
| has_attention_mask=has_attention_mask, | |||||
| compute_type=compute_type) | |||||
| layers.append(layer) | |||||
| self.layers = nn.CellList(layers) | |||||
| self.reshape = P.Reshape() | |||||
| self.new_shape = (-1, d_model) | |||||
| self.out_shape = (-1, seq_length, d_model) | |||||
| def construct(self, input_tensor, attention_mask): | |||||
| """ | |||||
| Do Multi DecoderBlock. | |||||
| Args: | |||||
| input_tensor (Tensor): the embedding matrix of input sequence tokens, | |||||
| shape with [batch_size * seq_length, d_model] | |||||
| attention_mask (Tensor): mask to avoid performing attention on padding token indices, | |||||
| shape with [batch_size, seq_len, seq_len]. | |||||
| Returns: | |||||
| outputs (Tensor): the output of GPT2Transformer, shape with [batch_size * seq_len, d_model]. | |||||
| """ | |||||
| prev_output = self.reshape(input_tensor, self.new_shape) | |||||
| for layer_module in self.layers: | |||||
| layer_output = layer_module(prev_output, attention_mask) | |||||
| prev_output = layer_output | |||||
| output = self.reshape(prev_output, self.out_shape) | |||||
| return output | |||||
| class CreateAttentionMaskFromInputMask(nn.Cell): | |||||
| """ | |||||
| Create attention mask according to input mask. | |||||
| Args: | |||||
| config (Class): Configuration for GPT2Model. | |||||
| """ | |||||
| def __init__(self, config): | |||||
| super(CreateAttentionMaskFromInputMask, self).__init__() | |||||
| self.input_mask_from_dataset = config.input_mask_from_dataset | |||||
| self.input_mask = None | |||||
| self.compute_type = config.compute_type | |||||
| assert self.input_mask_from_dataset | |||||
| self.cast = P.Cast() | |||||
| self.shape = P.Shape() | |||||
| self.reshape = P.Reshape() | |||||
| self.matmul = P.BatchMatMul() | |||||
| self.multiply = P.Mul() | |||||
| # mask future positions | |||||
| ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length)) | |||||
| self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32) | |||||
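| # e.g. for seq_length = 3, np.tril yields [[1, 0, 0], [1, 1, 0], [1, 1, 1]]: each token | |||||
| # may attend only to itself and to earlier positions. | |||||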
| def construct(self, input_mask, mask_future=True): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_mask (Tensor): Tensor mask vectors with shape [batch_size, seq_len]. | |||||
| mask_future (bool): Whether mask future (for decoder training). Default: True. | |||||
| Returns: | |||||
| attention_mask (Tensor): shape [batch_size, seq_len, seq_len]. | |||||
| """ | |||||
| input_shape = self.shape(input_mask) | |||||
| shape_right = (input_shape[0], 1, input_shape[1]) # [batch_size, 1, seq_len] | |||||
| shape_left = input_shape + (1,) # [batch_size, seq_len, 1] | |||||
| input_mask = self.cast(input_mask, mstype.float32) | |||||
| mask_left = self.reshape(input_mask, shape_left) | |||||
| mask_right = self.reshape(input_mask, shape_right) | |||||
| # precision transition fp32 -> fp16 | |||||
| mask_left = self.cast(mask_left, self.compute_type) | |||||
| mask_right = self.cast(mask_right, self.compute_type) | |||||
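| # batched outer product of the padding mask with itself: entry (i, j) is 1 only when | |||||
| # both token i and token j are real (non-padding) tokens. | |||||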
| attention_mask = self.matmul(mask_left, mask_right) # [batch_size, seq_len, seq_len] | |||||
| attention_mask = self.cast(attention_mask, mstype.float32) | |||||
| if mask_future: | |||||
| attention_mask = self.multiply(attention_mask, self.lower_triangle_mask) | |||||
| return attention_mask | |||||
| class GPT2Model(nn.Cell): | |||||
| """ | |||||
| GPT-2 model: a decoder-only Transformer for language modeling. | |||||
| Args: | |||||
| config (Class): Configuration for GPT2Model. | |||||
| is_training (bool): True for training mode. False for eval mode. | |||||
| use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. | |||||
| """ | |||||
| def __init__(self, | |||||
| config, | |||||
| is_training, | |||||
| use_one_hot_embeddings=False | |||||
| ): | |||||
| super(GPT2Model, self).__init__() | |||||
| self.config = copy.deepcopy(config) | |||||
| self.is_training = is_training | |||||
| if not is_training: | |||||
| self.config.hidden_dropout = 0.0 | |||||
| self.config.attention_dropout = 0.0 | |||||
| self.input_mask_from_dataset = self.config.input_mask_from_dataset | |||||
| self.batch_size = self.config.batch_size | |||||
| self.seq_length = self.config.seq_length | |||||
| self.d_model = self.config.d_model | |||||
| self.num_hidden_layers = self.config.num_hidden_layers | |||||
| self.embedding_dim = self.config.d_model | |||||
| self.last_idx = self.num_hidden_layers - 1 | |||||
| self.gpt2_embedding_lookup = EmbeddingLookup( | |||||
| vocab_size=self.config.vocab_size, | |||||
| embedding_dim=self.embedding_dim, | |||||
| use_one_hot_embeddings=use_one_hot_embeddings, | |||||
| compute_type=self.config.compute_type | |||||
| ) | |||||
| self.gpt2_embedding_postprocess = EmbeddingPostprocessor( | |||||
| embedding_dim=self.embedding_dim, | |||||
| seq_length=self.seq_length, | |||||
| max_position_embeddings=self.config.max_position_embeddings, | |||||
| dropout_prob=self.config.hidden_dropout | |||||
| ) | |||||
| self.gpt2_decoder = GPT2Transformer( | |||||
| batch_size=self.batch_size, | |||||
| d_model=self.d_model, | |||||
| seq_length=self.seq_length, | |||||
| num_hidden_layers=self.num_hidden_layers, | |||||
| num_attention_heads=self.config.num_attention_heads, | |||||
| intermediate_size=self.config.intermediate_size, | |||||
| has_attention_mask=True, | |||||
| attention_dropout=self.config.attention_dropout, | |||||
| hidden_dropout=self.config.hidden_dropout, | |||||
| compute_type=self.config.compute_type | |||||
| ) | |||||
| self.cast_compute_type = CastWrapper(dst_type=self.config.compute_type) | |||||
| self.layer_norm = LayerNorm(in_channels=self.d_model) | |||||
| self.dropout = nn.Dropout(1 - self.config.hidden_dropout) | |||||
| self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(self.config) | |||||
| self.reshape = P.Reshape() | |||||
| self.new_shape = (-1, self.d_model) | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| Construct network. | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| input_mask (Tensor): input sentences padding mask with shape [batch_size, seq_len], | |||||
| where 0 indicates padding position. | |||||
| Returns: | |||||
            decoder_output (Tensor): shape [batch_size, seq_len, d_model].
            embedding_tables (Tensor): word embeddings with shape [vocab_size, d_model].
| """ | |||||
| # Embedding | |||||
| word_embeddings, embedding_tables = self.gpt2_embedding_lookup(input_ids) | |||||
| embedding_output = self.gpt2_embedding_postprocess(word_embeddings) | |||||
| embedding_output = self.dropout(embedding_output) | |||||
| # Attention mask with shape [batch_size, seq_len, seq_len] | |||||
| attention_mask = self._create_attention_mask_from_input_mask(input_mask, True) | |||||
| # GPT2 decoder | |||||
| decoder_output = self.gpt2_decoder( | |||||
| self.cast_compute_type(embedding_output), | |||||
| self.cast_compute_type(attention_mask) | |||||
| ) | |||||
| # LayerNorm | |||||
| decoder_output = self.reshape(decoder_output, self.new_shape) | |||||
| decoder_output = self.layer_norm(decoder_output) | |||||
| decoder_output = self.reshape(decoder_output, (-1, self.seq_length, self.d_model)) | |||||
| return decoder_output, embedding_tables | |||||
| def get_token_embeddings(self): | |||||
| return self.gpt2_embedding_lookup.embedding_table.asnumpy() | |||||
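# --- illustrative usage sketch (documentation aid, not original source) ---
# A minimal smoke test for GPT2Model; it assumes the same GPT2Config keyword
# arguments used in finetune_eval_config.py and scales the network down so the
# forward pass is cheap. float16 compute is intended for Ascend/GPU devices.
if __name__ == "__main__":
    demo_config = GPT2Config(
        batch_size=1,
        seq_length=8,
        vocab_size=50257,
        d_model=768,
        num_hidden_layers=2,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=8,
        initializer_range=0.02,
        input_mask_from_dataset=True,
        summary_first_dropout=0.1,
        dtype=mstype.float32,
        compute_type=mstype.float16,
    )
    demo_model = GPT2Model(demo_config, is_training=False)
    demo_ids = Tensor(np.ones((1, 8)), mstype.int32)
    demo_mask = Tensor(np.ones((1, 8)), mstype.int32)
    hidden_states, embedding_tables = demo_model(demo_ids, demo_mask)
    print(hidden_states.shape)  # expect (1, 8, 768)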
| @@ -0,0 +1,48 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """clip gradient""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| GRADIENT_CLIP_TYPE = 1 | |||||
| GRADIENT_CLIP_VALUE = 1.0 | |||||
| clip_grad = C.MultitypeFuncGraph("clip_grad") | |||||
| # pylint: disable=consider-using-in | |||||
| @clip_grad.register("Number", "Number", "Tensor") | |||||
| def _clip_grad(clip_type, clip_value, grad): | |||||
| """ | |||||
| Clip gradients. | |||||
| Inputs: | |||||
| clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. | |||||
| clip_value (float): Specifies how much to clip. | |||||
| grad (tuple[Tensor]): Gradients. | |||||
| Outputs: | |||||
| tuple[Tensor], clipped gradients. | |||||
| """ | |||||
| if clip_type != 0 and clip_type != 1: | |||||
| return grad | |||||
| dt = F.dtype(grad) | |||||
| if clip_type == 0: | |||||
| new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), | |||||
| F.cast(F.tuple_to_array((clip_value,)), dt)) | |||||
| else: | |||||
| new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) | |||||
| return new_grad | |||||
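# --- illustrative usage sketch (documentation aid, not original source) ---
# clip_grad is a MultitypeFuncGraph, so it is applied to a gradient tuple
# through C.HyperMap, exactly as gpt2_for_finetune.py does during training.
if __name__ == "__main__":
    import numpy as np
    from mindspore.common.tensor import Tensor
    import mindspore.common.dtype as mstype

    grads = (Tensor(np.array([3.0, -4.0]), mstype.float32),)
    hyper_map = C.HyperMap()
    clipped = hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    print(clipped[0])  # L2 norm 5.0 clipped to 1.0 -> [0.6, -0.8]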
| @@ -0,0 +1,95 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Data operations""" | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset as de | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from .finetune_eval_config import gpt2_net_cfg | |||||
| def create_language_model_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""): | |||||
| """create dataset like language model task""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = de.MindDataset(dataset_path, | |||||
| columns_list=["input_ids", "input_mask", "label_ids"], | |||||
| shuffle=do_shuffle, | |||||
| num_shards=device_num, | |||||
| shard_id=rank_id) | |||||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="label_ids") | |||||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_count) | |||||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||||
| print("output shape: {}".format(ds.output_shapes())) | |||||
| print("output type: {}".format(ds.output_types())) | |||||
| print("============== create dataset successful ===============") | |||||
| return ds | |||||
| def create_cbt_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=False, dataset_path=""): | |||||
| """create dataset for cbt task""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = de.MindDataset(dataset_path, | |||||
| columns_list=["input_ids", "input_mask", "input_length", "mc_labels"], | |||||
| shuffle=do_shuffle, | |||||
| num_shards=device_num, | |||||
| shard_id=rank_id) | |||||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_length") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="mc_labels") | |||||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_count) | |||||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||||
| print("output shape: {}".format(ds.output_shapes())) | |||||
| print("output type: {}".format(ds.output_types())) | |||||
| print("============== create CBT LM dataset successful ===============") | |||||
| return ds | |||||
| def create_lambada_control_dataset(device_num=1, repeat_count=1, rank_id=0, do_shuffle=True, dataset_path=""): | |||||
| """create dataset for lambada task""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = de.MindDataset(dataset_path, | |||||
| columns_list=["input_ids", "input_mask", "input_length"], | |||||
| shuffle=do_shuffle, | |||||
| num_shards=device_num, | |||||
| shard_id=rank_id) | |||||
| print("batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_ids") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_mask") | |||||
| ds = ds.map(operations=type_cast_op, input_columns="input_length") | |||||
| ds = ds.batch(gpt2_net_cfg.batch_size, drop_remainder=True) | |||||
| ds = ds.repeat(repeat_count) | |||||
| print("dataset size: {}".format(ds.get_dataset_size())) | |||||
| print("repeat count: {}".format(ds.get_repeat_count())) | |||||
| print("output shape: {}".format(ds.output_shapes())) | |||||
| print("output type: {}".format(ds.output_types())) | |||||
| print("============== create dataset successful ===============") | |||||
| return ds | |||||
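# --- illustrative usage sketch (documentation aid, not original source) ---
# Typical single-device call; the MindRecord path below is a placeholder and
# must point at data produced by the preprocessing scripts.
if __name__ == "__main__":
    lm_dataset = create_language_model_dataset(dataset_path="/path/to/wikitext2.train.mindrecord")
    for item in lm_dataset.create_dict_iterator():
        print(item["input_ids"].shape, item["input_mask"].shape, item["label_ids"].shape)
        break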
| @@ -0,0 +1,104 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GPT-2 finetune config and GPT-2 model config""" | |||||
| from easydict import EasyDict as edict | |||||
| import mindspore.common.dtype as mstype | |||||
| from .GPT2_model import GPT2Config | |||||
| cfg = edict({ | |||||
| 'gpt2_network': 'large', | |||||
| 'optimizer': 'Lamb', | |||||
| 'AdamWeightDecay': edict({ | |||||
| 'learning_rate': 1e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 0.01, | |||||
| 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), | |||||
| 'eps': 1e-6, | |||||
| }), | |||||
| 'Lamb': edict({ | |||||
| 'learning_rate': 1e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 0.01, | |||||
| 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(), | |||||
| }), | |||||
| 'Momentum': edict({ | |||||
| 'learning_rate': 2e-5, | |||||
| 'momentum': 0.9, | |||||
| }), | |||||
| }) | |||||
| """ | |||||
| three kinds of GPT2 model version | |||||
| """ | |||||
| if cfg.gpt2_network == 'small': | |||||
| gpt2_net_cfg = GPT2Config( | |||||
| batch_size=8, | |||||
| seq_length=1024, | |||||
| vocab_size=50257, | |||||
| d_model=768, | |||||
| num_hidden_layers=12, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout=0.1, | |||||
| attention_dropout=0.1, | |||||
| max_position_embeddings=1024, | |||||
| initializer_range=0.02, | |||||
| input_mask_from_dataset=True, | |||||
| summary_first_dropout=0.1, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| ) | |||||
elif cfg.gpt2_network == 'medium':
| gpt2_net_cfg = GPT2Config( | |||||
| batch_size=8, | |||||
| seq_length=1024, | |||||
| vocab_size=50257, | |||||
| d_model=1024, | |||||
| num_hidden_layers=24, | |||||
| num_attention_heads=16, | |||||
| intermediate_size=4096, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout=0.1, | |||||
| attention_dropout=0.1, | |||||
| max_position_embeddings=1024, | |||||
| initializer_range=0.02, | |||||
| input_mask_from_dataset=True, | |||||
| summary_first_dropout=0.1, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| ) | |||||
elif cfg.gpt2_network == 'large':
| gpt2_net_cfg = GPT2Config( | |||||
| batch_size=6, | |||||
| seq_length=1024, | |||||
| vocab_size=50257, | |||||
| d_model=1280, | |||||
| num_hidden_layers=36, | |||||
| num_attention_heads=20, | |||||
| intermediate_size=5120, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout=0.1, | |||||
| attention_dropout=0.1, | |||||
| max_position_embeddings=1024, | |||||
| initializer_range=0.02, | |||||
| input_mask_from_dataset=True, | |||||
| summary_first_dropout=0.1, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| ) | |||||
| @@ -0,0 +1,464 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """GPT-2 finetune for downstream task""" | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common.parameter import Parameter | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| from mindspore import context | |||||
| from mindspore.context import ParallelMode | |||||
| from mindspore.communication.management import get_group_size | |||||
| from .utils.CrossEntropy import CrossEntropyCalculationWithMask | |||||
| from .clip_grad_utils import clip_grad | |||||
| from .GPT2ForLanguageModel import GPT2LanguageModel | |||||
| from .GPT2ForLambada import GPT2LambadaModel | |||||
| from .GPT2ForCBT import GPT2CBTModel | |||||
| from .GPT2ForTranslation import GPT2TranslationModel | |||||
| from .GPT2ForReadComprehension import GPT2CoQAModel | |||||
| from .GPT2ForSummarization import GPT2SummarizationModel | |||||
| GRADIENT_CLIP_TYPE = 1 | |||||
| GRADIENT_CLIP_VALUE = 1.0 | |||||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
| reciprocal = P.Reciprocal() | |||||
| @grad_scale.register("Tensor", "Tensor") | |||||
| def tensor_grad_scale(scale, grad): | |||||
| return grad * reciprocal(scale) | |||||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||||
| grad_overflow = P.FloatStatus() | |||||
| @_grad_overflow.register("Tensor") | |||||
| def _tensor_grad_overflow(grad): | |||||
| return grad_overflow(grad) | |||||
| class GPT2FinetuneCell(nn.Cell): | |||||
| """ | |||||
    Specifically defined for finetuning, where only three input tensors are needed.
| Args: | |||||
| network (Cell): The training network. Note that loss function should have been added. | |||||
| optimizer (Optimizer): Optimizer for updating the weights. | |||||
| scale_update_cell (Cell): Cell to do the loss scale. Default: None. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||||
| super(GPT2FinetuneCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_grad() | |||||
| self.weights = optimizer.parameters | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation(get_by_list=True, | |||||
| sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.allreduce = P.AllReduce() | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = None | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("gradients_mean") | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.cast = P.Cast() | |||||
| self.gpu_target = False | |||||
| if context.get_context("device_target") == "GPU": | |||||
| self.gpu_target = True | |||||
| self.float_status = P.FloatStatus() | |||||
| self.addn = P.AddN() | |||||
| self.reshape = P.Reshape() | |||||
| else: | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||||
| name="loss_scale") | |||||
| def construct(self, | |||||
| input_ids, | |||||
| input_mask, | |||||
| label_ids, | |||||
| sens=None): | |||||
| """ | |||||
| GPT2 Finetune. | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
| label_ids (Tensor): the indices of input sequence tokens in the vocabulary | |||||
| """ | |||||
| weights = self.weights | |||||
| init = False | |||||
| loss = self.network(input_ids, | |||||
| input_mask, | |||||
| label_ids) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| if not self.gpu_target: | |||||
| init = self.alloc_status() | |||||
| clear_before_grad = self.clear_before_grad(init) | |||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||||
| grads = self.grad(self.network, weights)(input_ids, | |||||
| input_mask, | |||||
| label_ids, | |||||
| self.cast(scaling_sens, | |||||
| mstype.float32)) | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| if not self.gpu_target: | |||||
| flag = self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| else: | |||||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||||
| flag_sum = self.addn(flag_sum) | |||||
            flag_sum = self.reshape(flag_sum, ())
| if self.is_distributed: | |||||
| flag_reduce = self.allreduce(flag_sum) | |||||
| cond = self.less_equal(self.base, flag_reduce) | |||||
| else: | |||||
| cond = self.less_equal(self.base, flag_sum) | |||||
| overflow = cond | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond) | |||||
| return F.depend(ret, succ) | |||||
| class GPT2LM(nn.Cell): | |||||
| """ | |||||
| Train interface for Language Modeling finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||||
| super(GPT2LM, self).__init__() | |||||
| self.gpt2 = GPT2LanguageModel(config, is_training, use_one_hot_embeddings) | |||||
| self.num_labels = config.vocab_size | |||||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||||
| num_labels=self.num_labels, | |||||
| config=config) | |||||
| self.is_training = is_training | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, input_ids, input_mask, label_ids): | |||||
| """ | |||||
| construct function for Language Modeling | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.
| Returns: | |||||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||||
| otherwise, return the computed loss. | |||||
| """ | |||||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||||
| if self.is_training: | |||||
| shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||||
| label_ids = label_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| return lm_logits | |||||
| class GPT2Lambada(nn.Cell): | |||||
| """ | |||||
| Train interface for Lambada finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| super(GPT2Lambada, self).__init__() | |||||
| self.gpt2 = GPT2LambadaModel(config, is_training, use_one_hot_embeddings) | |||||
| self.num_labels = config.vocab_size | |||||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||||
| num_labels=self.num_labels, | |||||
| config=config) | |||||
| self.is_training = is_training | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, input_ids, input_mask, label_ids=None): | |||||
| """ | |||||
| construct function for Lambada task | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
            input_mask (Tensor): input sequence padding mask, where 0 indicates padding position.
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary. Default: None.
| Returns: | |||||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||||
| otherwise, return the computed loss. | |||||
| """ | |||||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||||
| if self.is_training: | |||||
| shift_logits = lm_logits[:, :-1, :] # [batch_size, seq_length - 1, vocab_size] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||||
| label_ids = label_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| return lm_logits | |||||
| class GPT2CBT(nn.Cell): | |||||
| """ | |||||
| Train interface for Children's Book Test finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||||
| super(GPT2CBT, self).__init__() | |||||
| self.gpt2 = GPT2CBTModel(config, is_training, use_one_hot_embeddings) | |||||
| self.num_labels = config.vocab_size | |||||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||||
| num_labels=self.num_labels, | |||||
| config=config) | |||||
| self.is_training = is_training | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.cast = P.Cast() | |||||
| def construct(self, input_ids, input_mask): | |||||
| """ | |||||
| construct function for CBT task | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
| Returns: | |||||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||||
| otherwise, return the computed loss. | |||||
| """ | |||||
| lm_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||||
| if self.is_training: | |||||
| shift_logits = lm_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||||
| label_ids = input_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| return lm_logits | |||||
| class GPT2Translation(nn.Cell): | |||||
| """ | |||||
| Train interface for Translation finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| super(GPT2Translation, self).__init__() | |||||
| self.gpt2 = GPT2TranslationModel(config, is_training, use_one_hot_embeddings) | |||||
| self.num_labels = config.vocab_size | |||||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||||
| num_labels=self.num_labels, | |||||
| config=config) | |||||
| self.is_training = is_training | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| def construct(self, input_ids, input_mask, label_ids): | |||||
| """ | |||||
| construct function for Translation task | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.
| Returns: | |||||
| translation_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||||
| otherwise, return the computed loss. | |||||
| """ | |||||
| translation_logits = self.gpt2(input_ids, input_mask) # [batch_size, seq_length, vocab_size] | |||||
| translation_logits = self.log_softmax(translation_logits) | |||||
| if self.is_training: | |||||
| shift_logits = translation_logits[::, :-1, ::] # [batch_size, seq_length - 1, vocab_size] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) # [batch * (seq_length - 1), vocab_size] | |||||
| label_ids = label_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| return translation_logits | |||||
| class GPT2Summarization(nn.Cell): | |||||
| """ | |||||
| Train interface for Summary finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config=None, is_training=None, use_one_hot_embeddings=False): | |||||
| super(GPT2Summarization, self).__init__() | |||||
| self.gpt2 = GPT2SummarizationModel(config, is_training, use_one_hot_embeddings) | |||||
| self.is_training = is_training | |||||
| self.last_idx = (-1,) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = P.Shape() | |||||
| self.batch_size = config.batch_size | |||||
| self.seq_length = config.seq_length | |||||
| self.vocab_size = config.vocab_size | |||||
| self.cast = P.Cast() | |||||
| self.loss_function = CrossEntropyCalculationWithMask(num_labels=self.vocab_size, | |||||
| is_training=self.is_training, | |||||
| config=config) | |||||
| def construct(self, input_ids, input_mask, label_ids): | |||||
| """ | |||||
        construct function for Summarization task
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.
| Returns: | |||||
            loss (mstype.float32): the computed loss.
| """ | |||||
| output = self.gpt2(input_ids, input_mask) | |||||
| shift_logits = output[::, :-1, ::] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.vocab_size)) | |||||
| shift_logits = self.log_softmax(shift_logits) | |||||
| label_ids = label_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss_function(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| class GPT2CoQA(nn.Cell): | |||||
| """ | |||||
| Train interface for Reading Comprehension finetuning task. | |||||
| Args: | |||||
| config (class): the configuration of GPT-2 model. | |||||
| is_training (bool): whether to train. | |||||
| use_one_hot_embeddings (bool): whether to use onehot embeddings. | |||||
| """ | |||||
| def __init__(self, config, is_training, use_one_hot_embeddings=False): | |||||
| super(GPT2CoQA, self).__init__() | |||||
| self.gpt2 = GPT2CoQAModel(config, is_training, use_one_hot_embeddings) | |||||
| self.num_labels = config.vocab_size | |||||
| self.loss = CrossEntropyCalculationWithMask(is_training=is_training, | |||||
| num_labels=self.num_labels, | |||||
| config=config) | |||||
| self.is_training = is_training | |||||
| self.reshape = P.Reshape() | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| def construct(self, input_ids, input_mask, label_ids=None): | |||||
| """ | |||||
| construct function for reading comprehension task | |||||
| Args: | |||||
| input_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sequence padding mask, where 0 indicates padding position. | |||||
            label_ids (Tensor): the indices of label sequence tokens in the vocabulary.
| Returns: | |||||
| lm_logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits, | |||||
| otherwise, return the computed loss. | |||||
| """ | |||||
| lm_logits = self.gpt2(input_ids, input_mask) | |||||
| lm_logits = self.log_softmax(lm_logits) | |||||
| if self.is_training: | |||||
| shift_logits = lm_logits[::, :-1, ::] | |||||
| shift_logits = self.reshape(shift_logits, (-1, self.num_labels)) | |||||
| label_ids = label_ids[::, 1:] | |||||
| input_mask = input_mask[::, 1:] | |||||
| loss = self.loss(shift_logits, label_ids, input_mask) | |||||
| return loss | |||||
| return lm_logits | |||||
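# --- illustrative wiring sketch (documentation aid, not original source) ---
# How the pieces above are typically combined for LM finetuning: the network
# with loss is wrapped by GPT2FinetuneCell together with an optimizer and a
# dynamic loss-scale update cell. The Momentum hyper-parameters mirror
# cfg.Momentum in finetune_eval_config.py; run with `python -m` from the
# package root so the relative import resolves.
if __name__ == "__main__":
    from .finetune_eval_config import gpt2_net_cfg

    netwithloss = GPT2LM(config=gpt2_net_cfg, is_training=True)
    optimizer = nn.Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
    update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32,
                                                scale_factor=2, scale_window=1000)
    train_net = GPT2FinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    train_net.set_train()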
| @@ -0,0 +1,82 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Calculate Cross Entropy With Mask""" | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| import mindspore.nn as nn | |||||
| class CrossEntropyCalculationWithMask(nn.Cell): | |||||
| """ | |||||
| Cross Entropy loss | |||||
| """ | |||||
| def __init__(self, is_training=None, num_labels=None, config=None): | |||||
| super(CrossEntropyCalculationWithMask, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0, mstype.float32) | |||||
| self.off_value = Tensor(0.0, mstype.float32) | |||||
| self.reduce_sum = P.ReduceSum() | |||||
| self.reduce_mean = P.ReduceMean() | |||||
| self.reshape = P.Reshape() | |||||
| self.last_idx = (-1,) | |||||
| self.neg = P.Neg() | |||||
| self.cast = P.Cast() | |||||
| self.is_training = is_training | |||||
| self.num_labels = num_labels | |||||
| if config is not None: | |||||
| # for PPL calculation in evaluation | |||||
| self.input_mask_length = Tensor(config.batch_size * (config.seq_length - 1), mstype.float32) | |||||
| def construct(self, logits, label_ids, input_mask=None): | |||||
| """ | |||||
| Calculate loss | |||||
| Args: | |||||
            logits (Tensor): the log-probability distribution over the vocabulary.
| label_ids (Tensor): the indices of input sequence tokens in the vocabulary. | |||||
| input_mask (Tensor): input sentences padding mask, where 0 indicates padding position. | |||||
| Returns: | |||||
| return_value (Tensor, mstype.float32): if is_training is False, directly return the logits, otherwise, | |||||
| return the computed loss. | |||||
| """ | |||||
| # logits [batch * (seq_length-1), vocab_size] label_ids [batch, seq_length-1] | |||||
| if self.is_training: | |||||
| label_ids = self.reshape(label_ids, self.last_idx) # label_ids [batch * (seq_length-1)] | |||||
| one_hot_labels = self.onehot(label_ids, self.num_labels, self.on_value, | |||||
| self.off_value) # [batch * (seq_length-1), vocab_size] | |||||
| per_example_loss = self.neg( | |||||
| self.reduce_sum(one_hot_labels * logits, self.last_idx)) # [batch * (seq_length-1)] | |||||
| # for PPL calculation in evaluation | |||||
| if input_mask is not None: | |||||
| input_mask = self.cast(self.reshape(input_mask, self.last_idx), | |||||
| mstype.float32) # [batch * (seq_length-1)] | |||||
| valid_loss_sum = self.reduce_sum(input_mask * per_example_loss, ()) | |||||
| valid_element_sum = self.reduce_sum(input_mask, ()) + self.cast(F.tuple_to_array((1e-5,)), | |||||
| mstype.float32) | |||||
| loss = valid_loss_sum / valid_element_sum | |||||
| else: | |||||
| loss = self.reduce_mean(per_example_loss, self.last_idx) # a number | |||||
| return_value = self.cast(loss, mstype.float32) | |||||
| else: | |||||
| return_value = logits * 1.0 # [batch * (seq_length-1), vocab_size] | |||||
| return return_value | |||||
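# --- illustrative numeric check (documentation aid, not original source) ---
# The masked loss above is sum(mask * nll) / (sum(mask) + 1e-5); a plain
# NumPy re-computation of the same formula for two tokens, one of them padding.
if __name__ == "__main__":
    import numpy as np

    log_probs = np.log(np.array([[0.7, 0.3], [0.5, 0.5]]))  # [2 tokens, 2-word "vocab"]
    labels = np.array([0, 1])
    mask = np.array([1.0, 0.0])  # second token is padding
    nll = -log_probs[np.arange(2), labels]
    print(nll @ mask / (mask.sum() + 1e-5))  # ~0.3567 = -log(0.7)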
| @@ -0,0 +1,488 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """data preprocess for downstream task""" | |||||
| import re | |||||
| import json | |||||
| import random | |||||
| def lambada_detokenizer(string): | |||||
| string = re.sub(r"``", "-DQ-", string) | |||||
| string = re.sub(r"`", "-SQ-", string) | |||||
| string = re.sub(r"''", "-DQ-", string) | |||||
| string = re.sub(r" '", "-SQ-", string) | |||||
| string = re.sub("-DQ-", '"', string) | |||||
| string = re.sub("-SQ-", "'", string) | |||||
| string = re.sub(r"([,?!.]['\"])(\w)", "\g<1> \g<2>", string) | |||||
| # contractions | |||||
| string = string.replace("s '", "s'") | |||||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||||
| # number separators | |||||
| string = string.replace(" @-@ ", "-") | |||||
| string = string.replace(" @,@ ", ",") | |||||
| string = string.replace(" @.@ ", ".") | |||||
| # miscellaneous | |||||
| string = string.replace("= = = =", "====") | |||||
| string = string.replace("= = =", "===") | |||||
| string = string.replace("= =", "==") | |||||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||||
| string = string.replace(" \n", "\n") | |||||
| string = string.replace("\n ", "\n") | |||||
| string = string.replace(" N ", " 1 ") | |||||
| string = string.replace(" 's", "'s") | |||||
| string = string.replace(" 'd", "'d") | |||||
| string = string.replace(" '", "'") | |||||
| string = string.replace(" n't", "n't") | |||||
| string = string.replace(" .", ".") | |||||
| string = string.replace(" ,", ",") | |||||
| string = string.replace(" !", "!") | |||||
| string = string.replace(" ?", "?") | |||||
| string = string.replace(" :", ":") | |||||
| string = string.replace(" ;", ";") | |||||
| string = string.replace(" : ", ": ") | |||||
| string = string.replace(" ; ", "; ") | |||||
| string = string.replace(" ,'", ",'") | |||||
| string = string.replace(" .'", ".'") | |||||
| string = string.replace(" !'", "!'") | |||||
| string = string.replace(" ?'", "?'") | |||||
| string = string.replace("~", "") | |||||
| string = string.replace("---", "") | |||||
| string = string.replace("<", "") | |||||
| string = string.replace(">", "") | |||||
| string = string.replace("#", "") | |||||
| string = string.replace(', "', ',"') | |||||
| string = string.replace('. "', '."') | |||||
| string = string.replace('! "', '!"') | |||||
| string = string.replace('? "', '?"') | |||||
| string = string.replace('"" ', '" "') | |||||
| string = string.replace('• • •', '') | |||||
| # sensitive word process | |||||
| string = string.replace("f ** k", "fuck") | |||||
| string = string.replace("f ** king", "fucking") | |||||
| string = string.replace("f ** ked", "fucked") | |||||
| string = string.replace("c ** k", "cock") | |||||
| string = string.replace("br ** sts", "breasts") | |||||
| string = string.replace("n ** ples", "nipples") | |||||
| string = string.replace("ni ** les", "nipples") | |||||
| string = string.replace("a ** hole", "asshole") | |||||
| string = string.replace("ass ** le", "asshole") | |||||
| string = string.replace("p ** sy", "pussy") | |||||
| string = string.replace("pu ** y", "pussy") | |||||
| string = string.replace("na ** d", "naked") | |||||
| string = string.replace("nak * d", "naked") | |||||
| string = string.replace("cli ** x", "climax") | |||||
| string = string.replace("h * ps", "hips") | |||||
| string = string.replace("c * ck", "cock") | |||||
| string = string.replace("coc ** ne", "cocaine") | |||||
| string = string.replace("*", "") | |||||
| string = re.sub(" "," ",string) | |||||
| string = re.sub(" "," ",string) | |||||
| string = re.sub(" "," ",string) | |||||
| return string | |||||
| def lambada_dataset_preprocess(input_file, output_file): | |||||
| sentences = [] | |||||
| count = 0 | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| line = lambada_detokenizer(line) | |||||
| split_sentence_list = line.split() | |||||
| final_word = split_sentence_list[-1] | |||||
| context = split_sentence_list[:-1] | |||||
| new_sentence = ' '.join(context) + '\t' + ' ' + final_word | |||||
| sentences.append(new_sentence) | |||||
| count += 1 | |||||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||||
    write_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            sentence = sentence.strip()
            if sentence:
                f.write(sentence)
                f.write('\n')
                write_count += 1
    print('write {} file finished!\n total count = {}'.format(output_file, write_count))
def get_gold_answer_id(gold_answer, candidate_answer_list):
    """Return the index of gold_answer within candidate_answer_list, or None if absent."""
    for id_, candidate in enumerate(candidate_answer_list):
        if gold_answer == candidate:
            return id_
    return None
| def get_passage_string(passage_string, candidate_answer, final_sentence, gold_answer_id): | |||||
| """ | |||||
| concat each candidate answer to the rest_sentence | |||||
| Args: | |||||
| candidate_answer (list): store each candidate answers | |||||
| final_sentence (str): the 21st sentence string with "XXXXX" | |||||
| gold_answer_id (int): the id of correct answer. | |||||
| return: | |||||
| candidate_passage (list): the length of candidate_sentence equals to length of candidate_answer. | |||||
| """ | |||||
| candidate_passage = [] | |||||
| for answer in candidate_answer: | |||||
| passage = passage_string + " " + final_sentence | |||||
| passage = passage.replace(" XXXXX", "\t XXXXX") | |||||
| final_passage = passage.replace("XXXXX", answer) | |||||
| whole_passage = final_passage + "\t" + str(gold_answer_id) | |||||
| candidate_passage.append(whole_passage) | |||||
| return candidate_passage | |||||
| def cbt_dataset_preprocess(input_file, output_file): | |||||
| passages = [] | |||||
| passage_string = "" | |||||
| count = 0 | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| single_sentence = line.split(' ', 1) | |||||
| line_id = int(single_sentence[0]) | |||||
| string = single_sentence[1] | |||||
| if line_id == 21: | |||||
| string = string.replace("\t\t", "\t") | |||||
| mini_string = string.split("\t") | |||||
| candidate_answer = mini_string[-1] | |||||
| candidate_answer_list = candidate_answer.split("|") | |||||
| gold_answer_id = get_gold_answer_id(mini_string[-2], candidate_answer_list) | |||||
| candidate_passage = get_passage_string(passage_string, | |||||
| candidate_answer_list, | |||||
| mini_string[0], | |||||
| gold_answer_id) | |||||
| assert len(candidate_passage) == 10 | |||||
| count += 10 | |||||
| else: | |||||
| passage_string = passage_string + " " + string | |||||
| else: | |||||
| passages.append(candidate_passage) | |||||
| passage_string = "" | |||||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||||
    write_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for passage in passages:
            for candidate_passage in passage:
                candidate_passage = candidate_passage.replace(" \t ", "\t ")
                f.write(candidate_passage.strip())
                f.write("\n")
                write_count += 1
    print('write {} file finished!\n total count = {}'.format(output_file, write_count))
| def wikitext_detokenizer(string): | |||||
| # contractions | |||||
| string = string.replace("s '", "s'") | |||||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||||
| # number separators | |||||
| string = string.replace(" @-@ ", "-") | |||||
| string = string.replace(" @,@ ", ",") | |||||
| string = string.replace(" @.@ ", ".") | |||||
| # punctuation | |||||
| string = string.replace(" : ", ": ") | |||||
| string = string.replace(" ; ", "; ") | |||||
| string = string.replace(" . ", ". ") | |||||
| string = string.replace(" .", ".") | |||||
| string = string.replace(" ! ", "! ") | |||||
| string = string.replace(" ? ", "? ") | |||||
| string = string.replace(" , ", ", ") | |||||
| # double brackets | |||||
| string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) | |||||
| string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) | |||||
| string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) | |||||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||||
| # miscellaneous | |||||
| string = string.replace("= = = =", "====") | |||||
| string = string.replace("= = =", "===") | |||||
| string = string.replace("= =", "==") | |||||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||||
| string = string.replace(" \n", "\n") | |||||
| string = string.replace("\n ", "\n") | |||||
| string = string.replace(" N ", " 1 ") | |||||
| string = string.replace(" 's", "'s") | |||||
| return string | |||||
| def wikitext_dataset_preprocess(input_file, output_file): | |||||
| dataset_test = [] | |||||
| passage = [] | |||||
| count = 0 | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| if line.startswith('=') and line.endswith('=') and len(passage) != 0: | |||||
| dataset_test.append(passage) | |||||
| count += 1 | |||||
| passage = [] | |||||
| elif line.startswith('=') and line.endswith('='): | |||||
| continue | |||||
                else:
                    passage.append(line)
    if passage:
        dataset_test.append(passage)  # keep the final passage at end of file
        count += 1
    print('read {} file finished!\n total count = {}'.format(input_file, count))
| with open(output_file, 'w', encoding='utf-8') as f: | |||||
| for line in dataset_test: | |||||
| text = "" | |||||
| for sentence in line: | |||||
| sentence = wikitext_detokenizer(sentence) | |||||
| text = text + " " + sentence | |||||
| text = text.strip() | |||||
| f.write(text) | |||||
| f.write("\n") | |||||
| print('write {} file finished!\n total count = {}'.format(output_file, count)) | |||||
| def ptb_detokenizer(string): | |||||
| string = string.replace(" '", "'") | |||||
| string = string.replace(" \n", "\n") | |||||
| string = string.replace("\n ", "\n") | |||||
| string = string.replace(" n't", "n't") | |||||
| string = string.replace(" N ", "1 ") | |||||
| string = string.replace("$ 1", "$1") | |||||
| string = string.replace("# 1", "#1") | |||||
| string = string.replace("\/abc", "") | |||||
| string = string.replace("\/ua", "") | |||||
| string = string.replace("s '", "s'") | |||||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||||
| # punctuation | |||||
| string = string.replace(" : ", ": ") | |||||
| string = string.replace(" ; ", "; ") | |||||
| string = string.replace(" . ", ". ") | |||||
| string = string.replace(" ! ", "! ") | |||||
| string = string.replace(" ? ", "? ") | |||||
| string = string.replace(" , ", ", ") | |||||
| string = string.replace(" 's", "'s") | |||||
| return string | |||||
| def ptb_dataset_preprocess(input_file, output_file): | |||||
| sentences = [] | |||||
| count = 0 | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| line = ptb_detokenizer(line) | |||||
| sentences.append(line) | |||||
| count += 1 | |||||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||||
    write_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in sentences:
            sentence = sentence.strip()
            if sentence:
                f.write(sentence)
                f.write("\n")
                write_count += 1
    print('write {} file finished!\n total count = {}'.format(output_file, write_count))
| def onebw_detokenizer(string): | |||||
| # contractions | |||||
| string = string.replace("s '", "s'") | |||||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||||
| # number separators | |||||
| string = string.replace(" @-@ ", "-") | |||||
| string = string.replace(" @,@ ", ",") | |||||
| string = string.replace(" @.@ ", ".") | |||||
| # punctuation | |||||
| string = string.replace(" : ", ": ") | |||||
| string = string.replace(" ; ", "; ") | |||||
| string = string.replace(" . ", ". ") | |||||
| string = string.replace(" ! ", "! ") | |||||
| string = string.replace(" ? ", "? ") | |||||
| string = string.replace(" , ", ", ") | |||||
| # double brackets | |||||
| string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) | |||||
| string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) | |||||
| string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) | |||||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||||
| # miscellaneous | |||||
| string = string.replace("= = = =", "====") | |||||
| string = string.replace("= = =", "===") | |||||
| string = string.replace("= =", "==") | |||||
| string = string.replace(" --", "") | |||||
| string = string.replace("--", "") | |||||
| string = string.replace("? ? ?", " ?") | |||||
| string = string.replace(" " + chr(176) + " ", chr(176)) | |||||
| string = string.replace(" \n", "\n") | |||||
| string = string.replace("\n ", "\n") | |||||
| string = string.replace(" 't", "'t") | |||||
| string = string.replace(" N ", " 1 ") | |||||
| string = string.replace(" 's", "'s") | |||||
| string = string.replace(" '", "'") | |||||
| string = string.replace(" n't", "n't") | |||||
| string = string.replace("$ 1", "$1") | |||||
| string = string.replace("# 1", "#1") | |||||
| return string | |||||
def test_length(string):
    """Return the number of whitespace-separated tokens in string."""
    return len(string.split())
| def onebw_dataset_preprocess(condition, input_file, output_file): | |||||
| sentences = [] | |||||
| count = 0 | |||||
| if condition.lower() == "test": | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| sentences.append(line) | |||||
| count += 1 | |||||
| print('read {} file finished!\n total count = {}'.format(input_file, count)) | |||||
        write_count = 0
        with open(output_file, 'w', encoding='utf-8') as f:
            for sentence in sentences:
                sentence = sentence.strip()
                if sentence:
                    sentence = onebw_detokenizer(sentence)
                    f.write(sentence)
                    f.write("\n")
                    write_count += 1
        print('write {} file finished!\n total count = {}'.format(output_file, write_count))
| elif condition.lower() == "train": | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| for line in f: | |||||
| line = line.strip() | |||||
| if line: | |||||
| line = onebw_detokenizer(line) | |||||
| length = test_length(line) | |||||
| if length > 10 and length < 60: | |||||
| sentences.append(line) | |||||
| count += 1 | |||||
| print('read finished! count = ', count) | |||||
| sample_result_list = random.sample(range(0, count), 30000) | |||||
| sample_result_list.sort() | |||||
        count_sample = 0
        with open(output_file, 'w', encoding='utf-8') as f:
            for sample_id in sample_result_list:
                f.write(sentences[sample_id])
                f.write("\n")
                count_sample += 1
        print('write finished! ', count_sample)
| else: | |||||
        raise ValueError("unsupported condition '{}', expected 'train' or 'test'".format(condition))
| def coqa_dataset_preprocess(input_file, output_file): | |||||
| with open(input_file, 'r', encoding='utf-8') as f: | |||||
| source_data = json.load(f) | |||||
    instances = []
| end_sep = [',', '.', ';'] | |||||
| question_before_sep = " " | |||||
| question_after_sep = " A: " | |||||
| answer_sep = " A:\t" | |||||
    for dialog in source_data["data"]:
        story = dialog["story"].replace("\n", "")
| concat_ = "" | |||||
| concat_ += story | |||||
| for question, answer in zip(dialog["questions"], dialog["answers"]): | |||||
| question = question["input_text"] | |||||
| answer = answer["input_text"] | |||||
| concat_ += question_before_sep | |||||
| concat_ += question | |||||
| tmp = concat_ + question_after_sep | |||||
| concat_ += answer_sep | |||||
| concat_ += answer | |||||
| instances.append(concat_) | |||||
| concat_ = tmp + answer | |||||
| if concat_[-1] not in end_sep: | |||||
| concat_ += "." | |||||
| instances.append("") | |||||
    write_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for instance in instances:
            if instance:
                f.write(instance)
                f.write("\n")
                write_count += 1
    print('write {} file finished!\n total count = {}'.format(output_file, write_count))
| def wmt14_en_fr_preprocess(input_file, output_file): | |||||
| input_file = input_file + "/newstest2014-fren-ref" | |||||
| output_file = output_file + "/wmt14" | |||||
| language = ['.en.sgm', '.fr.sgm'] | |||||
| count = 0 | |||||
| # en-fr | |||||
| with open(input_file + language[0], "r", encoding='utf-8') as english, \ | |||||
| open(input_file + language[1], "r", encoding='utf-8') as french, \ | |||||
| open(output_file + '.en_fr.txt', "a", encoding='utf-8') as enfr_f, \ | |||||
| open(output_file + '.fr_en.txt', "a", encoding='utf-8') as fren_f: | |||||
        for en, fr in zip(english, french):
            if en[:7] == '<seg id':
                en_start = en.find('>', 0)
                en_end = en.find('</seg>', 0)
                en_ = en[en_start + 1:en_end]
                fr_start = fr.find('>', 0)
                fr_end = fr.find('</seg>', 0)
                fr_ = fr[fr_start + 1:fr_end]
| en_fr_str = en_ + "\t" + fr_ + "\n" | |||||
| enfr_f.write(en_fr_str) | |||||
| fr_en_str = fr_ + "\t" + en_ + "\n" | |||||
| fren_f.write(fr_en_str) | |||||
| count += 1 | |||||
| print('write {} file finished!\n total count = {}'.format(output_file + '.en_fr.txt', count)) | |||||
| print('write {} file finished!\n total count = {}'.format(output_file + '.fr_en.txt', count)) | |||||
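# --- illustrative usage sketch (documentation aid, not original source) ---
# Each helper turns a raw downstream-task file into the plain-text form that
# the MindRecord conversion step expects; all file names below are placeholders.
if __name__ == "__main__":
    lambada_dataset_preprocess("lambada_test_plain_text.txt", "lambada_test_clean.txt")
    cbt_dataset_preprocess("cbtest_CN_valid_2000ex.txt", "cbt_cn_valid_clean.txt")
    wikitext_dataset_preprocess("wiki.test.tokens", "wikitext2_test_clean.txt")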
| @@ -0,0 +1,542 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| generation utils | |||||
| """ | |||||
| import numpy as np | |||||
| from scipy.special import softmax | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import dtype as mstype | |||||
| from mindspore.common.tensor import Tensor | |||||
| from .tensor_manipulations import extract_single_token_logits, add_last_token | |||||
| INF = 1. * 1e9 | |||||
| class TopKTopP_Filter(): | |||||
| """ | |||||
Top-K sampling combined with Top-P sampling (Nucleus Sampling).
Keep the ids with the top-k probabilities, restrict them to the top-p cumulative probability mass,
and sample from the result with np.random.multinomial.
| Args: | |||||
| batch_size (int): batch size of input dataset. | |||||
| vocab_size (int): the shape of each embedding vector. | |||||
| k (int): parameter for Top-K sampling, k should be in range of [0, vocab_size]. | |||||
0 means no Top-K filtering (do nothing). Default: 0.
p (float) [Optional]: parameter for Top-P sampling a.k.a. Nucleus Sampling, p is in between 0.0 and 1.0.
Default: 1.0.
temperature (float) [Optional]: softmax temperature; larger values make generation more diverse. Default: 1.0.
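min_tokens_to_keep (int) [Optional]: minimum number of tokens kept in the candidate set. Default: 1.
Example:
a minimal sketch with toy logits (shapes and values are illustrative only):
>>> import numpy as np
>>> flt = TopKTopP_Filter(batch_size=2, vocab_size=10, k=5, p=0.9)
>>> logits = np.random.randn(2, 10).astype(np.float32)
>>> sampled_ids = flt.calculate(logits)  # list of 2 sampled token ids, one per batch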
| """ | |||||
| def __init__(self, | |||||
| batch_size=None, | |||||
| vocab_size=None, | |||||
| k=0, | |||||
| p=1.0, | |||||
| temperature=1.0, | |||||
| min_tokens_to_keep=1, | |||||
| ): | |||||
| self.k = k | |||||
| self.p = p | |||||
| self.temp = temperature | |||||
| self.batch_size = batch_size | |||||
| self.vocab_size = vocab_size | |||||
| self.min_tokens_to_keep = min_tokens_to_keep | |||||
| assert self.temp > 0.0, 'temperature must be positive' | |||||
assert self.k >= 0, 'the top_k number must be non-negative.'
| if self.k > 0: | |||||
assert self.min_tokens_to_keep <= self.k, 'k must be larger than or equal to min_tokens_to_keep ' \
| 'for Top-p sampling' | |||||
| if self.k == 0: | |||||
| self.k = self.vocab_size | |||||
| self.safety_mask = np.concatenate((np.ones((self.batch_size, self.min_tokens_to_keep)), | |||||
| np.zeros((self.batch_size, self.k - self.min_tokens_to_keep))), | |||||
axis=1).astype(np.bool_)  # np.bool_ instead of the deprecated np.bool alias
| def calculate(self, distribution): | |||||
| """ | |||||
| calculate sampling procedure with setting initialized before, return a list of sampled ids. | |||||
| Args: | |||||
| distribution (numpy.ndarray): with shape (batch_size,vocab_size) | |||||
| Returns: | |||||
| sampled ids: a list, with length of batch_size | |||||
| """ | |||||
| if self.temp != 1.0: | |||||
| distribution = distribution / float(self.temp) | |||||
| distribution_sorted = -np.sort(-distribution, axis=1) | |||||
| index_sorted = np.argsort(-distribution, axis=1) | |||||
| topk_distribution = distribution_sorted[::, :self.k if self.k > 0 else self.vocab_size] | |||||
| topk_indices = index_sorted[::, :self.k if self.k > 0 else self.vocab_size] | |||||
| # safety check of probability | |||||
| self.p = max(0.0, min(1.0, self.p)) | |||||
| cum_sum = np.cumsum(softmax(topk_distribution, axis=1), axis=1) | |||||
| bool_map = np.logical_or((cum_sum <= self.p), self.safety_mask).astype(np.float32) | |||||
| topk_distribution = topk_distribution * bool_map + np.float32(-1e5) * (1.0 - bool_map) | |||||
| topk_distribution = softmax(topk_distribution, axis=1) | |||||
| # normalize for np.float64 | |||||
| # choose np.float64 to avoid overflow in softmax operation | |||||
| topk_distribution = topk_distribution.astype(np.float64) | |||||
| for batch_idx in range(self.batch_size): | |||||
| topk_distribution[batch_idx] = topk_distribution[batch_idx] / np.sum(topk_distribution[batch_idx]) | |||||
| ret_ids = [] | |||||
| for batch_idx in range(self.batch_size): | |||||
| select_index = np.argmax(np.random.multinomial(1, topk_distribution[batch_idx])) | |||||
| ret_ids.append(topk_indices[batch_idx][select_index]) | |||||
| return ret_ids | |||||
| class Sample(): | |||||
| """ | |||||
| Initiate a Sample object for sampling next token(s) from previous text. | |||||
| Args: | |||||
| decoder (Model): GPT2 model to do generation. | |||||
| config (GPT2Config): configuration of given GPT2 model. | |||||
| tokenizer (GPT2Tokenizer): if choose to use input_str parameter in self.generate(), a tokenizer is compulsory. | |||||
| generate_length (int): number of tokens which should be generated. Default: 1. | |||||
topk_num (int): number of k in Top-k Sampling; 0 means no Top-k filtering,
equivalent to k = self.vocab_size. Default: 0
topp_prob (float): probability parameter of Top-p (nucleus) sampling;
p = 1.0 disables the filter. Default: 1.0
| temperature (float): temperature for Top-k sampling. Default: 1.0 | |||||
min_tokens_to_keep (int): guarantee that at least min_tokens_to_keep token(s) remain in the candidate set. Default: 1
early_stop (bool): whether to stop once the model generates an <EOS> token; with batch_size > 1,
generation stops when every sample has produced <EOS>. Default: False
demo_mode (bool): True if input_str is a single str rather than a list of str.
self.batch_size must be 1 if it is True. Default: False
| return_ids (bool): whether return ids generated from Sample. Default: False | |||||
| return_last_token_logits (bool): whether return logits of last token for each time step during generation. | |||||
| Default: False | |||||
| append_eos (bool): whether append <EOS> token id to input_ids pass directly to GPT2Model class. Default: False | |||||
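Example:
a minimal sketch; `model` (e.g. a mindspore Model wrapping GPT-2), `gpt2_config` and
`tokenizer` are assumed to be created elsewhere, and gpt2_config.batch_size is 1:
>>> sampler = Sample(decoder=model, config=gpt2_config, tokenizer=tokenizer,
...                  generate_length=32, topk_num=40, topp_prob=0.9, demo_mode=True)
>>> generate_str, prompt_str = sampler.generate(input_str="Once upon a time")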
| """ | |||||
| def __init__(self, | |||||
| decoder, | |||||
| config=None, | |||||
| batch_size=None, | |||||
| tokenizer=None, | |||||
| generate_length=1, | |||||
| topk_num=0, | |||||
| topp_prob=1.0, | |||||
| temperature=1.0, | |||||
| min_tokens_to_keep=1, | |||||
| early_stop=False, | |||||
| demo_mode=False, | |||||
| return_ids=False, | |||||
| return_last_token_logits=False, | |||||
| append_eos=False, | |||||
| ): | |||||
| assert config is not None, 'Config is a must for sampling.' | |||||
| self.decoder = decoder | |||||
| self.config = config | |||||
| self.tokenizer = tokenizer | |||||
| self.generate_length = generate_length | |||||
| self.topk_num = topk_num | |||||
| self.topp_prob = topp_prob | |||||
| self.temperature = temperature | |||||
| self.min_tokens_to_keep = min_tokens_to_keep | |||||
| self.early_stop = early_stop | |||||
| self.demo_mode = demo_mode | |||||
| self.return_ids = return_ids | |||||
| self.return_last_token_logits = return_last_token_logits | |||||
| self.append_eos = append_eos | |||||
| self.seq_length = config.seq_length | |||||
| self.batch_size = config.batch_size if batch_size is None else batch_size | |||||
| self.vocab_size = config.vocab_size | |||||
| self.on_value = Tensor(1.0, mstype.float32) | |||||
| self.off_value = Tensor(0.0, mstype.float32) | |||||
| self.reshape = P.Reshape() | |||||
| self.cumsum = P.CumSum() | |||||
| self.onehot = P.OneHot() | |||||
| self.cast = P.Cast() | |||||
| self.concat = P.Concat() | |||||
| self.sample_function = P.RandomCategorical(mstype.int32) | |||||
| self.filter_distribution = TopKTopP_Filter(batch_size=self.batch_size, | |||||
| vocab_size=self.vocab_size, | |||||
| k=self.topk_num, | |||||
| p=self.topp_prob, | |||||
| temperature=self.temperature, | |||||
| min_tokens_to_keep=self.min_tokens_to_keep) | |||||
| if self.tokenizer is not None: | |||||
| self.eos_id = self.tokenizer.eos_token_id | |||||
| else: | |||||
| self.eos_id = config.vocab_size - 1 | |||||
| if self.tokenizer is not None: | |||||
| self.eos_text = self.tokenizer.eos_token | |||||
| else: | |||||
| self.eos_text = "<|endoftext|>" | |||||
| if self.demo_mode is True: | |||||
assert self.batch_size == 1, 'Demo mode requires batch_size equal to 1, but got batch_size={}'.format(
| self.batch_size) | |||||
| def _extract_string_from_tensor(self, input_ids, mode="pair"): | |||||
| """ | |||||
| Args: | |||||
| input_ids(Tensor): input sentences with shape [self.batch_size, self.seq_len] | |||||
| mode (str): ["pair", "single"] | |||||
| "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`, | |||||
| such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`, | |||||
| reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`. | |||||
| "single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task. | |||||
| Returns: | |||||
| source_list (list): the list of source_text or first part of text. | |||||
| target_list (list): the list of target_text or second part of text. | |||||
| If self.batch_size is 1, it will return the first sentence of list, that is to say, the string. | |||||
| Example: | |||||
| for pair mode, if self.demo_mode is True, it will return source_list[0], target_list[0] | |||||
| """ | |||||
| assert self.tokenizer is not None, 'There is no tokenizer' | |||||
| source_list = [""] * self.batch_size | |||||
| target_list = [""] * self.batch_size | |||||
| eos_text = self.tokenizer.eos_token | |||||
| len_eos_text = len(eos_text) | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| input_ids_np = input_ids_np.reshape((self.batch_size, self.seq_length)) | |||||
| if mode == "pair": | |||||
| for batch_idx in range(self.batch_size): | |||||
| sentence_tensor = input_ids_np[batch_idx] | |||||
| sentence_list = sentence_tensor.tolist()[1:] | |||||
| sentence = self.tokenizer.decode(sentence_list) | |||||
| source_start = 0 | |||||
| source_end = sentence.find(eos_text, 0) | |||||
| target_start = source_end + len_eos_text | |||||
| target_end = sentence[target_start:].find(eos_text, 0) + target_start | |||||
| source_list[batch_idx] = sentence[source_start:source_end] | |||||
| target_list[batch_idx] = sentence[target_start:target_end] | |||||
| if self.batch_size == 1 and self.demo_mode is True: | |||||
| return source_list[0], target_list[0] | |||||
| return source_list, target_list | |||||
| if mode == "single": | |||||
| for batch_idx in range(self.batch_size): | |||||
| sentence_tensor = input_ids_np[batch_idx] | |||||
| sentence_list = sentence_tensor.tolist()[1:] | |||||
| sentence = self.tokenizer.decode(sentence_list) | |||||
| source_start = 0 | |||||
| source_end = sentence.find(eos_text, 0) | |||||
| source_list[batch_idx] = sentence[source_start:source_end] | |||||
| if self.batch_size == 1 and self.demo_mode is True: | |||||
| return source_list[0] | |||||
| else: | |||||
| raise ValueError('mode:{} not supported, only support [pair, single].'.format(mode)) | |||||
| return source_list | |||||
| def _tensorize_ids_with_masks(self, src_str): | |||||
| """ | |||||
| Transform from string to tensor | |||||
| Args: | |||||
| src_str: string or list of strings | |||||
| Return: | |||||
| input_ids (Tensor): shape with [self.batch_size, self.seq_length] | |||||
| input_mask (Tensor): shape with [self.batch_size, self.seq_length] | |||||
src_len_list (list): the number of tokens of each string in src_str after encoding by self.tokenizer
| """ | |||||
| if isinstance(src_str, str): | |||||
| src_str = [src_str] | |||||
| single_sentence_shape = (1, self.seq_length) | |||||
| src_len_list = list() | |||||
| input_ids = None | |||||
| input_mask = None | |||||
| for batch_idx in range(self.batch_size): | |||||
| src_ids_list = self.tokenizer.encode(src_str[batch_idx]) | |||||
| src_ids_len = len(src_ids_list) | |||||
| if src_ids_len > self.seq_length: | |||||
| src_ids_list = src_ids_list[:self.seq_length] | |||||
| src_ids_len = self.seq_length | |||||
| src_len_list.append(src_ids_len) | |||||
| return_dict = self.tokenizer.prepare_for_model(src_ids_list, | |||||
| max_length=self.config.seq_length, | |||||
| add_special_tokens=False) | |||||
| input_ids_list = return_dict['input_ids'] | |||||
| input_mask_list = return_dict['attention_mask'] | |||||
| input_ids_np = np.array(input_ids_list, dtype=int) | |||||
| input_mask_np = np.array(input_mask_list, dtype=int) | |||||
| input_ids_np = input_ids_np.reshape(single_sentence_shape) | |||||
| input_mask_np = input_mask_np.reshape(single_sentence_shape) | |||||
| if batch_idx == 0: | |||||
| input_ids_np_ = input_ids_np | |||||
| input_mask_np_ = input_mask_np | |||||
| else: | |||||
| input_ids_np_ = np.concatenate((input_ids_np_, input_ids_np), axis=0) | |||||
| input_mask_np_ = np.concatenate((input_mask_np_, input_mask_np), axis=0) | |||||
| input_ids = Tensor(input_ids_np_, dtype=mstype.int32) | |||||
| input_mask = Tensor(input_mask_np_, dtype=mstype.int32) | |||||
| return input_ids, input_mask, src_len_list | |||||
| class LastTokenPos(): | |||||
| """ | |||||
class for recording input strings and the positions of their last tokens
| Args: | |||||
| input_ (Union[list, Tensor]): list if input is a list containing strings, | |||||
| Tensor with shape (batch_size, seq_length) representing input_mask. | |||||
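Example:
a small sketch (inside Sample this class is accessed as self.LastTokenPos):
>>> import numpy as np
>>> mask = Tensor(np.array([[1, 1, 1, 0], [1, 1, 0, 0]]), mstype.int32)
>>> last = LastTokenPos(mask, seq_length=4)
>>> last.get_pos(shift=0)  # [2, 1], index of the last valid token per row
>>> last.get_pos(shift=1)  # [3, 2], position where the next token goes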
| """ | |||||
| def __init__(self, input_, seq_length=1024): | |||||
| if isinstance(input_, list): | |||||
| self.input_strs = input_ | |||||
| self.input_mask = None | |||||
| else: | |||||
| self.input_strs = None | |||||
| self.input_mask = input_ | |||||
| self.seq_length = seq_length | |||||
| if self.input_strs is not None: | |||||
| self.pos_list = [len(input_str) - 1 for input_str in self.input_strs] | |||||
| else: | |||||
| input_mask_ = P.Cast()(self.input_mask, mstype.float32) | |||||
| temp_pos_list = P.ReduceSum(keep_dims=False)(input_mask_, axis=1).asnumpy().astype(np.int32).tolist() | |||||
| # minimum value is always 0 for safety | |||||
| self.pos_list = [max(0, pos - 1) for pos in temp_pos_list] | |||||
| def get_pos(self, shift: int = 0): | |||||
| # return last token if overflow | |||||
| shift_list = [min(self.seq_length - 1, pos + shift) for pos in self.pos_list] | |||||
| return shift_list | |||||
| def _sample_from_distribution(self, distribution): | |||||
| """ | |||||
| sample one token per batch from self.sample_function(). | |||||
| Arg: | |||||
| distribution (Tensor): the distribution or logits of the last token of different batches. | |||||
| shape with [batch_size, vocab_size] | |||||
| Return: | |||||
| word_index (Tensor): shape with [batch_size, ] | |||||
| """ | |||||
| distribution = self.reshape(distribution, (self.vocab_size, self.batch_size)) | |||||
| topk_distribution = distribution[:self.topk_num, ::] | |||||
| topk_distribution = self.reshape(topk_distribution, (self.batch_size, -1)) | |||||
| word_index = self.sample_function(P.Softmax()(topk_distribution), 1, 1) | |||||
| word_index = self.reshape(word_index, (-1,)) | |||||
| return word_index | |||||
| def _demo_mode_check(self, input_str): | |||||
| """ | |||||
type check for demo_mode: batch_size is 1 and input_str is a single string (or a one-element list of strings)
| """ | |||||
| if self.batch_size == 1 and self.demo_mode is True: | |||||
| assert input_str is not None, "demo mode should have input str" | |||||
| # type check | |||||
| if isinstance(input_str, list): | |||||
| assert isinstance(input_str[0], str), "type of input_str is {}, " \ | |||||
| "which should be str instead.".format(type(input_str[0])) | |||||
| if len(input_str) != 1: | |||||
| print("[WARNING] Sample.generate: length of input_str is larger than 1, " | |||||
| "choose input_str[0] as input_str.") | |||||
| input_str = input_str[0] | |||||
| assert isinstance(input_str, str), "type of input_str is {}, " \ | |||||
| "which should be str instead.".format(input_str) | |||||
| input_str = [input_str] | |||||
| return input_str | |||||
| def _input_check_and_normalize(self, input_str=None, input_ids=None, input_mask=None, generate_length=None): | |||||
| """ | |||||
| input check function | |||||
| """ | |||||
| if input_str is not None: | |||||
| assert self.tokenizer is not None, 'if choose to give input_str, a tokenizer is necessary.' | |||||
| input_str = self._demo_mode_check(input_str) | |||||
| if input_ids is not None: | |||||
assert input_mask is not None, 'if input_ids is given, input_mask is required as well.'
| if input_str is not None and input_ids is not None and input_mask is not None: | |||||
| print('[WARNING] Sample.generate got input_str, input_ids and input_mask, ' | |||||
| 'choose input_str as default for input') | |||||
| if input_ids is None and input_mask is None: | |||||
| input_ids, input_mask, _ = self._tensorize_ids_with_masks(input_str) | |||||
| else: | |||||
| if input_str is None: | |||||
| if input_ids is not None: | |||||
| input_str = self._extract_string_from_tensor(input_ids, mode="full") | |||||
| if generate_length is not None: | |||||
| # reload generate_length | |||||
| generate_length = int(generate_length) | |||||
| assert generate_length >= 0, 'generate_length can not be negative.' | |||||
| else: | |||||
| generate_length = self.generate_length | |||||
| return input_str, input_ids, input_mask, generate_length | |||||
| def generate(self, input_str=None, input_ids=None, input_mask=None, generate_length=None, do_sample=True): | |||||
| """ | |||||
base function for text generation given a list of batch_size strings, or a single string when demo mode is on
Args:
| input_str (list(str) or str): prompt string. | |||||
| generate_length: number of tokens to generate. | |||||
| Returns: | |||||
| generate_str: string generated by the GPT-2 model. | |||||
| full_str: input_str appended with generate_str. | |||||
| """ | |||||
| input_str, input_ids, input_mask, generate_length = self._input_check_and_normalize(input_str, | |||||
| input_ids, | |||||
| input_mask, | |||||
| generate_length) | |||||
return_ids_list = [[] for _ in range(self.batch_size)]  # independent lists per sample (avoid [[]] * n aliasing)
| last_token = self.LastTokenPos(input_mask, seq_length=self.seq_length) | |||||
| for i in range(generate_length): | |||||
| last_token_pos_list = last_token.get_pos(shift=i) | |||||
| early_stop_mask = [0] * self.batch_size | |||||
| # unsorted logits (distribution) of next word | |||||
| logits = self.decoder.predict(input_ids, input_mask) | |||||
| if self.return_last_token_logits is True: | |||||
| if i == 0: | |||||
| # [batch_size, 1, vocab_size] | |||||
| return_last_logits = extract_single_token_logits(logits, last_token_pos_list) | |||||
| else: | |||||
| # [batch_size, 1, vocab_size] + [batch_size, i, vocab_size] --> [batch_size, i+1, vocab_size] | |||||
| return_last_logits = P.Concat(axis=1)((return_last_logits, | |||||
| extract_single_token_logits(logits, last_token_pos_list))) | |||||
| nextword_distribution = self.reshape(logits[0, last_token_pos_list[0]:last_token_pos_list[0]+1:1, ::], | |||||
| (1, -1)) | |||||
| # stack up nextword_distribution if batch_size is larger than 1 | |||||
| if self.batch_size > 1: | |||||
| for batch_idx in range(1, self.batch_size): | |||||
| nextword_distribution_rest = self.reshape( | |||||
| logits[batch_idx, last_token_pos_list[batch_idx]:last_token_pos_list[batch_idx] + 1:1, ::], | |||||
| (1, -1)) | |||||
| nextword_distribution = self.concat((nextword_distribution, nextword_distribution_rest)) | |||||
| if do_sample: | |||||
| # get sampled ids | |||||
| nextword_distribution = nextword_distribution.asnumpy().astype(np.float32) | |||||
| real_next_word_index_list = self.filter_distribution.calculate(nextword_distribution) | |||||
| else: | |||||
| np_nextword_distribution = nextword_distribution.asnumpy() | |||||
| next_word_index = np.argmax(np_nextword_distribution, axis=-1) | |||||
| real_next_word_index_list = next_word_index.tolist() | |||||
| append_ids = [] | |||||
# tokenizer.decode and early_stop (if every sample in the batch has generated an EOS token, stop generating)
| for batch_idx in range(self.batch_size): | |||||
| next_word_index = real_next_word_index_list[batch_idx] | |||||
# early stop if the model generates an EOS token.
| if self.early_stop is True: | |||||
| if next_word_index == self.eos_id: | |||||
| if self.batch_size == 1: | |||||
| break | |||||
| else: | |||||
| early_stop_mask[batch_idx] = 1 | |||||
| continue | |||||
| return_ids_list[batch_idx].append(next_word_index) | |||||
| append_ids.append(next_word_index) | |||||
| # check early_stop mask at the end of each loop | |||||
| if 0 not in early_stop_mask: | |||||
| break | |||||
| input_ids, input_mask = add_last_token(input_ids, | |||||
| input_mask, | |||||
| overflow_strategy="shift", | |||||
| append_ids=append_ids, | |||||
| next_token_pos=last_token.get_pos(shift=i + 1)) | |||||
| # add str to full str | |||||
| generate_str = [""] * self.batch_size | |||||
| full_str = [""] * self.batch_size | |||||
| text_cnt = 0 | |||||
| for text_ids in return_ids_list: | |||||
| text = self.tokenizer.decode(text_ids) | |||||
| generate_str[text_cnt] = text | |||||
| text_cnt += 1 | |||||
| for batch_idx in range(self.batch_size): | |||||
| full_str[batch_idx] = input_str[batch_idx] + generate_str[batch_idx] | |||||
| # return by several conditions | |||||
| if self.batch_size == 1 and self.demo_mode is True: | |||||
| if self.return_ids: | |||||
| return generate_str[0], input_str[0], return_ids_list[0] | |||||
| return generate_str[0], input_str[0] | |||||
| if self.return_ids: | |||||
| if self.return_last_token_logits: | |||||
| return return_ids_list, return_last_logits | |||||
| return return_ids_list | |||||
| return generate_str, full_str | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """get config setting""" | |||||
| def get_train_setting(finetune_config): | |||||
| """get train config setting""" | |||||
| cfg = finetune_config | |||||
| print("Loading GPT2 Finetune Config setting......") | |||||
| print(" | optimizer: {}".format(cfg.optimizer)) | |||||
| opt = cfg['optimizer'] | |||||
| print(" | learning rate: {}".format(cfg[opt]['learning_rate'])) | |||||
| print(" | end learning rate: {}".format( | |||||
| cfg[opt]['end_learning_rate'] if 'end_learning_rate' in cfg[opt] else 'None')) | |||||
| print(" | weight decay: {}\n".format(cfg[opt]['weight_decay'] if 'weight_decay' in cfg[opt] else 'None')) | |||||
| def get_model_setting(finetune_config, model_config): | |||||
| """get GPT-2 model config setting""" | |||||
| cfg = finetune_config | |||||
| gpt2_net_cfg = model_config | |||||
| print("Loading GPT2 Model Config setting......") | |||||
| print(" | model size: {}".format(cfg.gpt2_network)) | |||||
| print(" | batch_size: {}".format(gpt2_net_cfg.batch_size)) | |||||
| print(" | seq_length: {}".format(gpt2_net_cfg.seq_length)) | |||||
| print(" | vocab_size: {}".format(gpt2_net_cfg.vocab_size)) | |||||
| print(" | d_model: {}".format(gpt2_net_cfg.d_model)) | |||||
| print(" | num_hidden_layers: {}".format(gpt2_net_cfg.num_hidden_layers)) | |||||
| print(" | num_attention_heads: {}".format(gpt2_net_cfg.num_attention_heads)) | |||||
| print(" | hidden_dropout: {}".format(gpt2_net_cfg.hidden_dropout)) | |||||
| print(" | attention_dropout: {}".format(gpt2_net_cfg.attention_dropout)) | |||||
| print(" | summary_first_dropout: {}\n".format(gpt2_net_cfg.summary_first_dropout)) | |||||
| @@ -0,0 +1,61 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """learning schedule""" | |||||
| import numpy as np | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR | |||||
| class GPT2LearningRate(LearningRateSchedule): | |||||
| """ | |||||
Implementation of a warmup-polydecay learning rate scheduler.
| Args: | |||||
| learning_rate (float): The initial value of learning rate. | |||||
| end_learning_rate (float): The end value of learning rate. | |||||
| warmup_steps (int): The warm up steps of learning rate. | |||||
decay_steps (int): number of steps over which the learning rate decays to end_learning_rate.
power (float): exponent of the polynomial decay.
| Returns: | |||||
| lr (Tensor): The learning rate value for the current step. | |||||
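Example:
a minimal sketch: warm up for 100 steps, then decay polynomially over 1000 steps:
>>> lr_schedule = GPT2LearningRate(learning_rate=1e-4, end_learning_rate=1e-6,
...                                warmup_steps=100, decay_steps=1000, power=1.0)
>>> lr = lr_schedule(Tensor(10, mstype.int32))  # learning rate at global step 10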
| """ | |||||
| def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power): | |||||
| super(GPT2LearningRate, self).__init__() | |||||
| self.warmup_flag = False | |||||
| if warmup_steps > 0: | |||||
| self.warmup_flag = True | |||||
| self.warmup_lr = WarmUpLR(learning_rate, warmup_steps) | |||||
| self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power) | |||||
| self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32)) | |||||
| self.greater = P.Greater() | |||||
| self.one = Tensor(np.array([1.0]).astype(np.float32)) | |||||
| self.cast = P.Cast() | |||||
| def construct(self, global_step): | |||||
| decay_lr = self.decay_lr(global_step) | |||||
| if self.warmup_flag: | |||||
| is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32) | |||||
| warmup_lr = self.warmup_lr(global_step) | |||||
| lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr | |||||
| else: | |||||
| lr = decay_lr | |||||
| return lr | |||||
| @@ -0,0 +1,185 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """metric method for downstream task""" | |||||
| import string | |||||
| import re | |||||
| from collections import Counter | |||||
| import numpy as np | |||||
| from .rouge_score import get_rouge_score | |||||
| from .bleu import compute_bleu | |||||
| class LastWordAccuracy(): | |||||
| """ | |||||
LastWordAccuracy is for the LAMBADA task (predicting the final word of a sentence)
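Example:
normalization strips case and punctuation before comparing:
>>> metric = LastWordAccuracy()
>>> metric.update("Giraffe,", "giraffe")
>>> metric.acc_num / metric.total_num  # 1.0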
| """ | |||||
| def __init__(self): | |||||
| self.acc_num = 0 | |||||
| self.total_num = 0 | |||||
| def normalize(self, word): | |||||
| """normalization""" | |||||
| word = word.lstrip() | |||||
| word = word.rstrip() | |||||
| def remove_punc(text): | |||||
| exclude = set(string.punctuation) | |||||
| return ''.join(ch for ch in text if ch not in exclude) | |||||
| def lower(text): | |||||
| return text.lower() | |||||
| return remove_punc(lower(word)) | |||||
| def update(self, predict_label, gold_label): | |||||
| if isinstance(predict_label, str) and isinstance(gold_label, str): | |||||
| predict_label = [predict_label] | |||||
| gold_label = [gold_label] | |||||
| for predict_word, gold_word in zip(predict_label, gold_label): | |||||
| self.total_num += 1 | |||||
| if self.normalize(predict_word) == self.normalize(gold_word): | |||||
| self.acc_num += 1 | |||||
| class Accuracy(): | |||||
| """ | |||||
| calculate accuracy | |||||
| """ | |||||
| def __init__(self): | |||||
| self.acc_num = 0 | |||||
| self.total_num = 0 | |||||
| def update(self, logits, labels): | |||||
| """accuracy update""" | |||||
| labels = np.reshape(labels, -1) | |||||
| logits_id = np.argmax(logits, axis=-1) | |||||
| print(" | Preict Label: {} Gold Label: {}".format(logits_id, labels)) | |||||
| self.acc_num += np.sum(labels == logits_id) | |||||
| self.total_num += len(labels) | |||||
| print("\n| Accuracy = {} \n".format(self.acc_num / self.total_num)) | |||||
| class F1(): | |||||
| """calculate F1 score""" | |||||
| def __init__(self): | |||||
| self.f1_score = 0.0 | |||||
| def get_normalize_answer_token(self, string_): | |||||
| """Lower text and remove punctuation, article and extra whitespace.""" | |||||
| def remove_articles(text): | |||||
| regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) | |||||
| return re.sub(regex, ' ', text) | |||||
| def white_space_fix(text): | |||||
| return ' '.join(text.split()) | |||||
| def remove_punc(text): | |||||
| exclude = set(string.punctuation) | |||||
| return ''.join(char for char in text if char not in exclude) | |||||
| def lower(text): | |||||
| return text.lower() | |||||
| return white_space_fix(remove_articles(remove_punc(lower(string_)))).split() | |||||
| def update(self, pred_answer, gold_answer): | |||||
| """F1 update""" | |||||
| common = Counter(pred_answer) & Counter(gold_answer) | |||||
| num_same = sum(common.values()) | |||||
| # the number of same tokens between pred_answer and gold_answer | |||||
| precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0 | |||||
| recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0 | |||||
| if ' '.join(pred_answer).strip() == "" and ' '.join(gold_answer).strip() == "": | |||||
| self.f1_score += 1 | |||||
| else: | |||||
| self.f1_score += 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0 | |||||
| print('| precision: {}, recall: {}\n'.format(precision, recall)) | |||||
| class BLEU(): | |||||
| """calculate BLEU score""" | |||||
| def __init__(self, tokenizer=None, max_order=4, smooth=True): | |||||
| self.bleu = 0.0 | |||||
| self.total_num = 0 | |||||
| self.tokenizer = tokenizer | |||||
| self.max_order = max_order | |||||
| self.smooth = smooth | |||||
| def sum_bleu(self, references, translations, max_order, smooth): | |||||
| """calculate the sum of bleu score""" | |||||
| all_result = [] | |||||
| bleu_avg = 0.0 | |||||
| for refer, trans in zip(references, translations): | |||||
| result = compute_bleu([[refer]], [trans], max_order, smooth) | |||||
| all_result.append(result) | |||||
| bleu_avg += result[0] | |||||
| bleu_avg /= len(references) | |||||
| return bleu_avg, all_result | |||||
| def update(self, hypotheses, references): | |||||
| """BLEU update""" | |||||
| hypo_l = [] | |||||
| ref_l = [] | |||||
| if self.tokenizer is not None: | |||||
| for hypo, ref in zip(hypotheses, references): | |||||
| if ref.strip() == '': | |||||
| print("Reference is None, skip it !") | |||||
| continue | |||||
| if hypo.strip() == '': | |||||
| print("translation is None, skip it !") | |||||
| continue | |||||
| hypo_l.append(self.tokenizer.encode(hypo)) | |||||
| ref_l.append(self.tokenizer.encode(ref)) | |||||
| if hypo_l and ref_l: | |||||
| hypotheses = hypo_l | |||||
| references = ref_l | |||||
| bleu_avg, _ = self.sum_bleu(references, hypotheses, self.max_order, self.smooth) | |||||
| self.bleu += bleu_avg * 100 | |||||
| self.total_num += 1 | |||||
| print("============== BLEU: {} ==============".format(float(self.bleu / self.total_num))) | |||||
| class Rouge(): | |||||
| ''' | |||||
| Get Rouge Score | |||||
| ''' | |||||
| def __init__(self): | |||||
| self.Rouge1 = 0.0 | |||||
| self.Rouge2 = 0.0 | |||||
| self.RougeL = 0.0 | |||||
| self.total_num = 0 | |||||
| def update(self, hypothesis, targets): | |||||
| scores = get_rouge_score(hypothesis, targets) | |||||
| self.Rouge1 += scores['rouge-1']['f'] * 100 | |||||
| self.Rouge2 += scores['rouge-2']['f'] * 100 | |||||
| self.RougeL += scores['rouge-l']['f'] * 100 | |||||
| self.total_num += 1 | |||||
| print("=============== ROUGE: {} ===============".format( | |||||
| (self.Rouge1 + self.Rouge2 + self.RougeL) / float(3.0 * self.total_num))) | |||||
| @@ -0,0 +1,466 @@ | |||||
| , | |||||
| . | |||||
| ? | |||||
| ! | |||||
| # | |||||
| ~ | |||||
| = | |||||
| - | |||||
| " | |||||
| ' | |||||
| : | |||||
| - | |||||
| … | |||||
| -- | |||||
| | | |||||
| a | |||||
| about | |||||
| above | |||||
| across | |||||
| after | |||||
| again | |||||
| against | |||||
| all | |||||
| almost | |||||
| alone | |||||
| along | |||||
| already | |||||
| also | |||||
| although | |||||
| always | |||||
| among | |||||
| an | |||||
| and | |||||
| another | |||||
| any | |||||
| anybody | |||||
| anyone | |||||
| anything | |||||
| anywhere | |||||
| are | |||||
| area | |||||
| areas | |||||
| around | |||||
| as | |||||
| ask | |||||
| asked | |||||
| asking | |||||
| asks | |||||
| at | |||||
| away | |||||
| b | |||||
| back | |||||
| backed | |||||
| backing | |||||
| backs | |||||
| be | |||||
| became | |||||
| because | |||||
| become | |||||
| becomes | |||||
| been | |||||
| before | |||||
| began | |||||
| behind | |||||
| being | |||||
| beings | |||||
| best | |||||
| better | |||||
| between | |||||
| big | |||||
| both | |||||
| bro | |||||
| but | |||||
| by | |||||
| c | |||||
| came | |||||
| can | |||||
| cannot | |||||
| case | |||||
| cases | |||||
| certain | |||||
| certainly | |||||
| clear | |||||
| clearly | |||||
| come | |||||
| could | |||||
| d | |||||
| did | |||||
| differ | |||||
| different | |||||
| differently | |||||
| do | |||||
| does | |||||
| done | |||||
| down | |||||
| down | |||||
| downed | |||||
| downing | |||||
| downs | |||||
| during | |||||
| dr | |||||
| e | |||||
| each | |||||
| early | |||||
| eh | |||||
| either | |||||
| end | |||||
| ended | |||||
| ending | |||||
| ends | |||||
| enough | |||||
| even | |||||
| evenly | |||||
| ever | |||||
| every | |||||
| everybody | |||||
| everyone | |||||
| everything | |||||
| everywhere | |||||
| f | |||||
| fact | |||||
| facts | |||||
| far | |||||
| felt | |||||
| few | |||||
| find | |||||
| finds | |||||
| first | |||||
| for | |||||
| four | |||||
| from | |||||
| full | |||||
| fully | |||||
| further | |||||
| furthered | |||||
| furthering | |||||
| furthers | |||||
| g | |||||
| gave | |||||
| general | |||||
| generally | |||||
| get | |||||
| gets | |||||
| give | |||||
| given | |||||
| gives | |||||
| going | |||||
| good | |||||
| goods | |||||
| got | |||||
| great | |||||
| greater | |||||
| greatest | |||||
| group | |||||
| grouped | |||||
| grouping | |||||
| groups | |||||
| h | |||||
| had | |||||
| has | |||||
| have | |||||
| having | |||||
| he | |||||
| her | |||||
| here | |||||
| herself | |||||
| hey | |||||
| high | |||||
| high | |||||
| high | |||||
| higher | |||||
| highest | |||||
| him | |||||
| himself | |||||
| his | |||||
| house | |||||
| how | |||||
| however | |||||
| i | |||||
| if | |||||
| important | |||||
| in | |||||
| interest | |||||
| interested | |||||
| interesting | |||||
| interests | |||||
| into | |||||
| is | |||||
| it | |||||
| its | |||||
| itself | |||||
| j | |||||
| just | |||||
| k | |||||
| kae | |||||
| keep | |||||
| keeps | |||||
| kind | |||||
| knew | |||||
| know | |||||
| known | |||||
| knows | |||||
| kya | |||||
| l | |||||
| lads | |||||
| large | |||||
| largely | |||||
| last | |||||
| later | |||||
| latest | |||||
| least | |||||
| less | |||||
| let | |||||
| lets | |||||
| like | |||||
| likely | |||||
| long | |||||
| longer | |||||
| longest | |||||
| m | |||||
| made | |||||
| make | |||||
| making | |||||
| man | |||||
| many | |||||
| may | |||||
| me | |||||
| member | |||||
| members | |||||
| men | |||||
| might | |||||
| mister | |||||
| more | |||||
| most | |||||
| mostly | |||||
| mr | |||||
| Mr | |||||
| mrs | |||||
| much | |||||
| must | |||||
| my | |||||
| myself | |||||
| n | |||||
| na | |||||
| necessary | |||||
| need | |||||
| needed | |||||
| needing | |||||
| needs | |||||
| never | |||||
| new | |||||
| new | |||||
| newer | |||||
| newest | |||||
| next | |||||
| no | |||||
| nobody | |||||
| non | |||||
| noone | |||||
| not | |||||
| nothing | |||||
| now | |||||
| nowhere | |||||
| number | |||||
| numbers | |||||
| nt | |||||
| nn | |||||
| nope | |||||
| ny | |||||
| o | |||||
| oi | |||||
| of | |||||
| off | |||||
| often | |||||
| old | |||||
| older | |||||
| oldest | |||||
| on | |||||
| once | |||||
| one | |||||
| only | |||||
| open | |||||
| opened | |||||
| opening | |||||
| opens | |||||
| or | |||||
| order | |||||
| ordered | |||||
| ordering | |||||
| orders | |||||
| other | |||||
| others | |||||
| our | |||||
| out | |||||
| over | |||||
| oh | |||||
| p | |||||
| part | |||||
| parted | |||||
| parting | |||||
| parts | |||||
| per | |||||
| perhaps | |||||
| place | |||||
| places | |||||
| please | |||||
| point | |||||
| pointed | |||||
| pointing | |||||
| points | |||||
| possible | |||||
| present | |||||
| presented | |||||
| presenting | |||||
| presents | |||||
| problem | |||||
| problems | |||||
| put | |||||
| puts | |||||
| q | |||||
| quite | |||||
| r | |||||
| rather | |||||
| really | |||||
| right | |||||
| right | |||||
| room | |||||
| rooms | |||||
| s | |||||
| said | |||||
| same | |||||
| saw | |||||
| say | |||||
| says | |||||
| second | |||||
| seconds | |||||
| see | |||||
| seem | |||||
| seemed | |||||
| seeming | |||||
| seems | |||||
| sees | |||||
| several | |||||
| shall | |||||
| she | |||||
| should | |||||
| show | |||||
| showed | |||||
| showing | |||||
| shows | |||||
| side | |||||
| sides | |||||
| since | |||||
| small | |||||
| smaller | |||||
| smallest | |||||
| so | |||||
| some | |||||
| somebody | |||||
| someone | |||||
| something | |||||
| somewhere | |||||
| state | |||||
| states | |||||
| still | |||||
| still | |||||
| such | |||||
| sure | |||||
| t | |||||
| take | |||||
| taken | |||||
| than | |||||
| that | |||||
| the | |||||
| their | |||||
| them | |||||
| then | |||||
| there | |||||
| therefore | |||||
| these | |||||
| they | |||||
| thing | |||||
| things | |||||
| think | |||||
| thinks | |||||
| this | |||||
| those | |||||
| though | |||||
| thought | |||||
| thoughts | |||||
| three | |||||
| through | |||||
| thus | |||||
| to | |||||
| today | |||||
| together | |||||
| too | |||||
| took | |||||
| toward | |||||
| turn | |||||
| turned | |||||
| turning | |||||
| turns | |||||
| two | |||||
| u | |||||
| uh | |||||
| um | |||||
| under | |||||
| until | |||||
| up | |||||
| upon | |||||
| us | |||||
| use | |||||
| used | |||||
| uses | |||||
| v | |||||
| very | |||||
| w | |||||
| want | |||||
| wanted | |||||
| wanting | |||||
| wants | |||||
| was | |||||
| way | |||||
| ways | |||||
| we | |||||
| well | |||||
| wells | |||||
| went | |||||
| were | |||||
| what | |||||
| when | |||||
| where | |||||
| whether | |||||
| which | |||||
| while | |||||
| who | |||||
| whole | |||||
| whose | |||||
| why | |||||
| will | |||||
| with | |||||
| within | |||||
| without | |||||
| work | |||||
| worked | |||||
| working | |||||
| works | |||||
| would | |||||
| x | |||||
| y | |||||
| ya | |||||
| ye | |||||
| year | |||||
| years | |||||
| yet | |||||
| you | |||||
| young | |||||
| younger | |||||
| youngest | |||||
| your | |||||
| yours | |||||
| z | |||||
| @@ -0,0 +1,39 @@ | |||||
| """Calculate ROUGE score.""" | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from typing import List | |||||
| from rouge import Rouge | |||||
| def get_rouge_score(hypothesis: List[str], target: List[str]): | |||||
| """ | |||||
| Calculate ROUGE score. | |||||
| Args: | |||||
| hypothesis (List[str]): Inference result. | |||||
| target (List[str]): Reference. | |||||
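Example:
a minimal sketch (requires the `rouge` pip package; strings are illustrative):
>>> scores = get_rouge_score(["the cat sat on the mat"], ["a cat sat on the mat"])
>>> scores['rouge-1']['f']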
| """ | |||||
| if not hypothesis or not target: | |||||
| raise ValueError(f"`hypothesis` and `target` can not be None.") | |||||
| _rouge = Rouge() | |||||
| print("hypothesis:", hypothesis) | |||||
| print("target:", target) | |||||
| scores = _rouge.get_scores(hypothesis, target, avg=True) | |||||
| print(" | ROUGE Score:") | |||||
| print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}") | |||||
| print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}") | |||||
| print(f" | RG-L(F): {scores['rouge-l']['f'] * 100:8.2f}") | |||||
| return scores | |||||
| @@ -0,0 +1,186 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| task utils | |||||
| """ | |||||
| import regex as re | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.common.tensor import Tensor | |||||
| # for lambada task | |||||
| def extract_logits(logits=None, position=None): | |||||
| """ | |||||
Args:
logits (Tensor): Tensor(batch_size, seq_length, vocab_size) e.g. (8, 1024, 50257)
position (numpy.ndarray): the array storing the final word positions, shape with [batch_size, 2]
| Return: | |||||
output_logits (Tensor): the logits extracted at the specified positions,
| shape with [batch_size, vocab_size] | |||||
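Example:
a shape-level sketch (values are illustrative):
>>> import numpy as np
>>> logits = Tensor(np.random.randn(8, 1024, 50257), mstype.float32)
>>> position = np.array([[10, 12]] * 8)  # final-word spans per sample
>>> extract_logits(logits, position).shape  # (8, 50257)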
| """ | |||||
| batch_size = logits.shape[0] | |||||
| for batch_idx in range(batch_size): | |||||
| word_logits_pos = int(position[batch_idx, 0] - 1) | |||||
| logit = logits[batch_idx:batch_idx+1:1, word_logits_pos, ::] # [1, vocab_size] | |||||
| if batch_idx == 0: | |||||
| output_logits = logit | |||||
| else: | |||||
| output_logits = P.Concat()((output_logits, logit)) # [batch_size, vocab_size] | |||||
| return output_logits | |||||
| def get_final_word_label(input_ids, input_length, tokenizer=None): | |||||
| """ | |||||
| get whole word label_str from input_ids | |||||
| Args: | |||||
input_ids: Tensor(batch_size, seq_length), indices of input text
input_length: Tensor(batch_size, 2), start and end positions of the final word in input_ids
tokenizer: GPT2Tokenizer, used to decode the final word ids back into a string
Returns:
batch_word_label: list[str], the last word of each sample, used as the LAMBADA label
| """ | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| input_length_np = input_length.asnumpy() | |||||
| batch_word_label = [] | |||||
| for batch_idx in range(len(input_ids_np)): | |||||
| word_spos = input_length_np[batch_idx, 0] | |||||
| word_epos = input_length_np[batch_idx, 1] | |||||
| final_word_ids = input_ids_np[batch_idx, word_spos:word_epos] | |||||
| final_word_str = tokenizer.decode(final_word_ids.tolist()) | |||||
| batch_word_label.append(final_word_str) | |||||
| return batch_word_label | |||||
| def calculate_final_word_loss(logits, batch_size, input_ids, input_length, loss): | |||||
| """ | |||||
| Calculate the last word loss. | |||||
| """ | |||||
| logits = logits.asnumpy() | |||||
| input_len_np = input_length.asnumpy() | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| sum_batch_loss = 0.0 | |||||
| for batch in range(batch_size): | |||||
| lastword_spos = input_len_np[batch, 0] | |||||
| lastword_epos = input_len_np[batch, 1] | |||||
| last_word_logits = logits[batch, lastword_spos - 1:lastword_epos - 1:1, ::] | |||||
| last_word_logits_tensor = Tensor(last_word_logits, mstype.float32) | |||||
| last_word_label = input_ids_np[batch, lastword_spos:lastword_epos:1] | |||||
| print("last word label: ", last_word_label) | |||||
| last_word_label_tensor = Tensor(last_word_label, mstype.int32) | |||||
| last_word_loss = loss(last_word_logits_tensor, last_word_label_tensor) | |||||
| last_word_loss = float(last_word_loss.asnumpy()) | |||||
| sum_batch_loss += last_word_loss | |||||
| print(" | loss: ", last_word_loss) | |||||
| avg_batch_loss = float(sum_batch_loss / batch_size) | |||||
| return avg_batch_loss | |||||
| # for cbt task | |||||
| def calculate_choice_prob_for_cbt(logits, batch_size, input_length, input_ids): | |||||
| """ | |||||
| calculate choice prob for cbt | |||||
| Args: | |||||
logits (Tensor): (batch_size, seq_length, vocab_size) log probabilities
batch_size (int): batch size
input_length (Tensor): (batch_size, 2), start and end positions of the rest of the sentence
input_ids (Tensor): (batch_size, seq_length), token ids
Returns:
choice_prob: List[float], summed log probability of the rest of the sentence per sample
| """ | |||||
| choice_prob = [] # [batch_size] | |||||
| logits = logits.asnumpy() | |||||
| input_len_np = input_length.asnumpy() | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| for batch in range(batch_size): | |||||
| sum_ = 0.0 | |||||
| rest_spos = input_len_np[batch, 0] | |||||
| rest_epos = input_len_np[batch, 1] + 1 | |||||
| for rest_pos in range(rest_spos - 1, rest_epos - 1): | |||||
| rest_token_id = input_ids_np[batch, rest_pos + 1] | |||||
| log_prob = logits[batch, rest_pos, rest_token_id] | |||||
| sum_ = sum_ + log_prob | |||||
| choice_prob.append(sum_) | |||||
| print("rest sentence prob: ", sum_) | |||||
| return choice_prob | |||||
| # for summarization task | |||||
| def modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2."): | |||||
| """ | |||||
modify keys of param_dict to fit the model.
Args:
param_dict: dict, dictionary of parameters imported from a ckpt file
mode: str, "zero-shot" for a pretrained GPT2 model;
"finetuned" for a model finetuned on a certain downstream task.
Return:
reorganized_param_dict: dict, new param_dict to fit in model for different tasks.
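Example:
a sketch for loading a pretrained checkpoint into an evaluation net
(`ckpt_path` and `eval_net` are placeholders):
>>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
>>> param_dict = load_checkpoint(ckpt_path)
>>> param_dict = modify_paramdict(param_dict, mode="zero-shot", model_prefix="gpt2.")
>>> load_param_into_net(eval_net, param_dict)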
| """ | |||||
| final_param_dict = dict() | |||||
| if mode == "zero-shot": | |||||
| for name in param_dict: | |||||
| final_param_dict[model_prefix + name] = param_dict[name] | |||||
| final_param_dict['lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table'] | |||||
| elif mode == "finetuned": | |||||
| embedding_name = "gpt2_embedding_lookup.embedding_table" | |||||
| embedding_name_old = "" | |||||
| for name in param_dict: | |||||
| name_remove_prefix = name[len(model_prefix):] | |||||
| name_prefix = name[:len(model_prefix)] | |||||
| final_param_dict[name_remove_prefix] = param_dict[name] | |||||
| if embedding_name in name and name_prefix == model_prefix: | |||||
| embedding_name_old = name | |||||
| final_param_dict[embedding_name] = param_dict[embedding_name_old] | |||||
| else: | |||||
| raise ValueError("mode should be [zero-shot, finetuned]") | |||||
| return final_param_dict | |||||
| def clean_hypo(text): | |||||
| """ | |||||
lower the text and guard against generating an empty string
| Arg: | |||||
| text: str, input str | |||||
| Return: | |||||
| text: str, cleaned input str | |||||
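Example:
behavior follows directly from the regex check above:
>>> clean_hypo("Hello World")  # returns "hello world"
>>> clean_hypo("!!!")          # returns "<EMPTY>", no letters remain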
| """ | |||||
| text = text.lower() | |||||
| eng_re = re.compile(r'[a-z]+', re.I) | |||||
| length_con = len(eng_re.findall(text)) | |||||
| if length_con == 0: | |||||
| return '<EMPTY>' | |||||
| return text | |||||
| @@ -0,0 +1,217 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| tensor manipulations | |||||
| """ | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore import dtype as mstype | |||||
| from mindspore.ops import operations as P | |||||
| def extract_string_from_tensor(input_ids, mode="single", config=None, tokenizer=None): | |||||
| """ | |||||
| Args: | |||||
| input_ids (Tensor): input sentences with shape [batch_size, seq_len]. | |||||
| mode (str): ["pair", "single"] | |||||
| "pair" for tasks with paired inputs `<bos> A <eos> B <eos>`, | |||||
| such as summarization task, the dataset format `<bos> Article <eos> Summary <eos>`, | |||||
| reading comprehension task, the dataset format `<bos> Passage Question <eos> Answer <eos>`. | |||||
| "single" for tasks with single input `<bos> A <eos>`, such as Language Modeling, Lambada task. | |||||
| config: the configuration of GPT-2 model. | |||||
| tokenizer: the tokenizer of GPT-2 model. | |||||
| Return: | |||||
| prompt_list (list): list of prompt_text | |||||
| reference_list (list): list of reference_text, or second part of text | |||||
| rest_list (list): list of rest_text, or rest part of text | |||||
| """ | |||||
| batch_size = config.batch_size | |||||
| seq_length = config.seq_length | |||||
| prompt_list = [""] * batch_size | |||||
| reference_list = [""] * batch_size | |||||
| eos_text = tokenizer.eos_token | |||||
| len_eos_text = len(eos_text) | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| input_ids_np = input_ids_np.reshape((batch_size, seq_length)) | |||||
| if mode == "pair": | |||||
| for batch_idx in range(batch_size): | |||||
| sentence_tensor = input_ids_np[batch_idx] | |||||
sentence_list = sentence_tensor.tolist()[1:]  # sentence_tensor is already a numpy array
| sentence = tokenizer.decode(sentence_list) | |||||
| prompt_start = 0 | |||||
| prompt_end = sentence.find(eos_text, 0) | |||||
| reference_start = prompt_end + len_eos_text | |||||
| reference_end = sentence[reference_start:].find( | |||||
| eos_text, 0) + reference_start | |||||
| prompt_list[batch_idx] = sentence[prompt_start:prompt_end] | |||||
| reference_list[batch_idx] = sentence[reference_start:reference_end] | |||||
| return prompt_list, reference_list | |||||
| # For single output datasets such as WikiText, etc. | |||||
| if mode == "single": | |||||
| for batch_idx in range(batch_size): | |||||
| sentence_tensor = input_ids_np[batch_idx] | |||||
sentence_list = sentence_tensor.tolist()[1:]  # sentence_tensor is already a numpy array
| sentence = tokenizer.decode(sentence_list) | |||||
| prompt_start = 0 | |||||
| prompt_end = sentence.find(eos_text, 0) | |||||
| prompt_list[batch_idx] = sentence[prompt_start:prompt_end] | |||||
| else: | |||||
| raise NotImplementedError('mode:{} not supported.'.format(mode)) | |||||
| return prompt_list | |||||
| def extract_single_token_logits(logits=None, seq_pos=None): | |||||
| """ | |||||
Args:
logits: (batch_size, seq_length, vocab_size) e.g. when batch_size is 8,
sequence length is 1024 and vocab_size is 50257,
then logits is a Tensor with shape (8, 1024, 50257)
seq_pos (list): (batch_size,) positions of the tokens whose logits are extracted
Return:
output_logits: (batch_size, 1, vocab_size), the logits used to predict the next token at each given position.
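Example:
a shape-level sketch (values are illustrative):
>>> import numpy as np
>>> logits = Tensor(np.random.randn(2, 16, 50257), mstype.float32)
>>> extract_single_token_logits(logits, seq_pos=[3, 7]).shape  # (2, 1, 50257)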
| """ | |||||
| batch_size = logits.shape[0] | |||||
| logits_np = logits.asnumpy() | |||||
| logits_type = P.DType()(logits) | |||||
| for i in range(batch_size): | |||||
| logit_np = logits_np[i:i + 1:1, seq_pos[i]:seq_pos[i] + 1:1, ::] | |||||
| if i == 0: | |||||
| output_logits = logit_np | |||||
| else: | |||||
| output_logits = np.concatenate((output_logits, logit_np), axis=0) | |||||
| output_logits = Tensor(output_logits, dtype=logits_type) | |||||
| return output_logits | |||||
| def get_last_one_pos(input_mask: Tensor): | |||||
| """ | |||||
| Arg: | |||||
| input_mask (Tensor): (batch_size,seq_length) | |||||
| Return: | |||||
| pos (Tensor): (batch_size,) | |||||
| """ | |||||
| input_mask_ = P.Cast()(input_mask, mstype.float32) | |||||
| pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,) | |||||
| pos = P.Cast()(pos, mstype.int32) | |||||
| pos = pos - 1 | |||||
| return pos | |||||
| def get_next_one_pos(input_mask: Tensor): | |||||
| """ | |||||
| Arg: | |||||
| input_mask (Tensor): (batch_size,seq_length) | |||||
| """ | |||||
| input_mask_ = P.Cast()(input_mask, mstype.float32) | |||||
| pos = P.ReduceSum(keep_dims=False)(input_mask_, axis=1) # (batch_size,) | |||||
| pos = P.Cast()(pos, mstype.int32) | |||||
| return pos | |||||
| def add_last_token_mask(input_mask: Tensor, overflow_strategy: str = "shift"): | |||||
| """ | |||||
| add last token mask | |||||
| Args: | |||||
input_mask (Tensor): (batch_size, seq_length) mask of valid tokens
overflow_strategy (str): "shift" or "truncate", behavior when the sequence is already full
Returns:
Tensor, the updated input_mask with one more position unmasked
| """ | |||||
| pos = get_next_one_pos(input_mask).asnumpy() | |||||
| input_mask_np = input_mask.asnumpy() | |||||
| maximum_length = input_mask.shape[1] | |||||
| batch_size = input_mask.shape[0] | |||||
| for idx in range(batch_size): | |||||
| # not overflow | |||||
| if pos[idx] < maximum_length: | |||||
| input_mask_np[idx][pos[idx]] = 1 | |||||
| # overflow | |||||
| else: | |||||
| if overflow_strategy == "shift": | |||||
| continue | |||||
| if overflow_strategy == "truncate": | |||||
| continue | |||||
| else: | |||||
| raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy)) | |||||
| return Tensor(input_mask_np, dtype=mstype.int32) | |||||
| def add_last_token(input_ids: Tensor, input_mask: Tensor, overflow_strategy: str = "shift", append_ids=None, | |||||
| next_token_pos=None): | |||||
| """ | |||||
| add last token | |||||
| Args: | |||||
input_ids (Tensor): (batch_size, seq_length) token ids
input_mask (Tensor): (batch_size, seq_length) mask of valid tokens
overflow_strategy (str): "shift" drops the first token to make room; "truncate" leaves the sequence unchanged
append_ids (list): token id to append for each sample in the batch
next_token_pos (list): positions at which to write the new tokens; computed from input_mask when None
Returns:
(Tensor, Tensor), the updated input_ids and input_mask
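Example:
a toy sketch appending one generated token id per sample:
>>> import numpy as np
>>> ids = Tensor(np.array([[5, 6, 0, 0]]), mstype.int32)
>>> mask = Tensor(np.array([[1, 1, 0, 0]]), mstype.int32)
>>> new_ids, new_mask = add_last_token(ids, mask, append_ids=[9])
>>> # new_ids -> [[5, 6, 9, 0]], new_mask -> [[1, 1, 1, 0]]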
| """ | |||||
| # get positional list/numpy array | |||||
| if next_token_pos is None: | |||||
| pos = get_next_one_pos(input_mask).asnumpy() | |||||
| else: | |||||
| pos = next_token_pos | |||||
| # get numpy of inputs | |||||
| input_mask_np = input_mask.asnumpy() | |||||
| input_ids_np = input_ids.asnumpy() | |||||
| maximum_length = int(input_mask.shape[1]) | |||||
| batch_size = int(input_mask.shape[0]) | |||||
| for idx in range(batch_size): | |||||
| # not overflow | |||||
| if pos[idx] < maximum_length: | |||||
| input_mask_np[idx][int(pos[idx])] = 1 | |||||
| input_ids_np[idx][int(pos[idx])] = append_ids[idx] | |||||
| # overflow | |||||
| else: | |||||
| if overflow_strategy == "shift": | |||||
| # shift one token left: drop the oldest token and append the new one | |||||
| input_ids_np[idx][0:maximum_length - 1] = input_ids_np[idx][1:maximum_length] | |||||
| input_ids_np[idx][maximum_length - 1] = append_ids[idx] | |||||
| continue | |||||
| if overflow_strategy == "truncate": | |||||
| # keep the sequence unchanged and discard the new token | |||||
| continue | |||||
| raise ValueError("{} is not an option in ['shift','truncate'].".format(overflow_strategy)) | |||||
| return Tensor(input_ids_np, dtype=mstype.int32), Tensor(input_mask_np, dtype=mstype.int32) | |||||
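| # Usage sketch (illustrative): with input_ids = [[7, 8, 0, 0]], | |||||
| # input_mask = [[1, 1, 0, 0]] and append_ids = [9], the next free position is 2, | |||||
| # so add_last_token returns input_ids = [[7, 8, 9, 0]] and input_mask = [[1, 1, 1, 0]]. | |||||
| # Once a sequence is full, "shift" keeps a sliding window over the newest tokens, | |||||
| # while "truncate" stops appending. | |||||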
| @@ -0,0 +1,517 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| tokenization | |||||
| """ | |||||
| import json | |||||
| from functools import lru_cache | |||||
| from typing import List, Optional | |||||
| import logging | |||||
| import regex as re | |||||
| logger = logging.getLogger(__name__) | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| bytes to unicode | |||||
| """ | |||||
| bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2 ** 8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2 ** 8 + n) | |||||
| n += 1 | |||||
| cs = [chr(i) for i in cs] | |||||
| return dict(zip(bs, cs)) | |||||
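| # Example (illustrative): printable bytes map to themselves and the rest are | |||||
| # moved above 255, e.g. bytes_to_unicode()[ord('a')] == 'a' while | |||||
| # bytes_to_unicode()[ord(' ')] == chr(256 + 32) == 'Ġ', the leading-space | |||||
| # marker familiar from GPT-2 vocabularies. | |||||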
| def get_pairs(word): | |||||
| """ | |||||
| Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
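| # Example (illustrative): get_pairs(('h', 'e', 'l', 'l', 'o')) returns | |||||
| # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}, the candidate bigrams that | |||||
| # GPT2Tokenizer.bpe ranks against its merge table. | |||||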
| class GPT2Tokenizer(): | |||||
| """ | |||||
| GPT2Tokenizer | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| vocab_file, | |||||
| merge_file, | |||||
| add_prefix_space=False, | |||||
| ): | |||||
| with open(vocab_file, 'r', encoding="utf-8") as vocab_handle: | |||||
| self.encoder = json.load(vocab_handle) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.vocab_size = len(self.decoder) | |||||
| with open(merge_file, 'r', encoding="utf-8") as merge_handle: | |||||
| bpe_merges = merge_handle.read().split('\n')[1:-1] | |||||
| bpe_merges = [tuple(merge.split()) for merge in bpe_merges] | |||||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | |||||
| self.add_prefix_space = add_prefix_space | |||||
| self.cache = {} | |||||
| self.unk_token = "<|endoftext|>" | |||||
| self.unk_token_id = 50256 | |||||
| self.bos_token = "<|endoftext|>" | |||||
| self.bos_token_id = 50256 | |||||
| self.eos_token = "<|endoftext|>" | |||||
| self.eos_token_id = 50256 | |||||
| self.pad_token = "<|endoftext|>" | |||||
| self.pad_token_id = 50256 | |||||
| def bpe(self, token): | |||||
| """ | |||||
| bpe encode | |||||
| """ | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token) | |||||
| pairs = get_pairs(token) | |||||
| if not pairs: | |||||
| return token | |||||
| while True: | |||||
| bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| except ValueError: | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| else: | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| if word[i] == first and i + 1 < len(word) and word[i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = " ".join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def _tokenize(self, text): | |||||
| """ Tokenize a string using bpe encode. """ | |||||
| text = self.prepare_for_tokenization(text, is_pretokenized=False) | |||||
| bpe_tokens = [] | |||||
| for token in re.findall(self.pat, text): | |||||
| token = "".join( | |||||
| self.byte_encoder[b] for b in token.encode("utf-8") | |||||
| ) | |||||
| bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) | |||||
| return bpe_tokens | |||||
| def _convert_token_to_id(self, token): | |||||
| """ the index of the token in the vocabulary. """ | |||||
| return self.encoder.get(token, self.encoder.get(self.unk_token)) | |||||
| def _convert_id_to_token(self, _id): | |||||
| """ return the origin bpe token according to id""" | |||||
| return self.decoder.get(_id) | |||||
| def _convert_tokens_to_string(self, tokens): | |||||
| """ return a string according to the list of tokens""" | |||||
| text = "".join(tokens) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors='ignore') | |||||
| return text | |||||
| def encode(self, text): | |||||
| """ get the index list of text""" | |||||
| text_id = [] | |||||
| bpe_tokens = self._tokenize(text) | |||||
| for token in bpe_tokens: | |||||
| text_id.append(self._convert_token_to_id(token)) | |||||
| return text_id | |||||
| def decode(self, ids): | |||||
| """ return a string according to the index list of tokens""" | |||||
| tokens = [] | |||||
| for id_ in ids: | |||||
| tokens.append(self._convert_id_to_token(id_)) | |||||
| return self._convert_tokens_to_string(tokens) | |||||
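| # Round-trip sketch (illustrative, assuming the vocab and merge files are loaded): | |||||
| #   ids = tokenizer.encode("Hello world") | |||||
| #   assert tokenizer.decode(ids) == "Hello world" | |||||
| # decode() inverts encode() for well-formed UTF-8 text because the | |||||
| # byte-to-unicode mapping from bytes_to_unicode() is reversible. | |||||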
| def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs): | |||||
| """ whether to add a whitespace in the front of text """ | |||||
| add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) | |||||
| if is_pretokenized or add_prefix_space: | |||||
| text = " " + text | |||||
| return text | |||||
| def add_special_tokens(self, special_tokens_dict): | |||||
| """ | |||||
| Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If | |||||
| special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the | |||||
| current vocabulary). | |||||
| Args: | |||||
| special_tokens_dict (dictionary `str` to `str`): | |||||
| Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, | |||||
| ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, | |||||
| ``additional_special_tokens``]. | |||||
| Returns: | |||||
| added_tokens (int): Number of tokens added to the vocabulary | |||||
| """ | |||||
| # special_tokens_dict = {'cls_token': '<CLS>'} | |||||
| if not special_tokens_dict: | |||||
| return 0 | |||||
| added_tokens = 0 | |||||
| for key, value in special_tokens_dict.items(): | |||||
| setattr(self, key, value) | |||||
| assert isinstance(value, str), f"Token {value} for key {key} should be a str instance" | |||||
| added_tokens += self.add_tokens([value], special_tokens=True) | |||||
| return added_tokens | |||||
| def add_tokens(self, new_tokens, special_tokens=False): | |||||
| if not new_tokens: | |||||
| return 0 | |||||
| if not isinstance(new_tokens, (list, tuple)): | |||||
| new_tokens = [new_tokens] | |||||
| return self._add_tokens(new_tokens, special_tokens=special_tokens) | |||||
| def _add_tokens(self, new_tokens, special_tokens=False): | |||||
| """ | |||||
| _add_tokens | |||||
| Args: | |||||
| new_tokens (list[str]): Token(s) to add in vocabulary. | |||||
| special_tokens (bool): Whether or not the tokens should be added as special tokens. | |||||
| Returns: | |||||
| the number of the new added tokens. | |||||
| """ | |||||
| new_tokens = [str(token) for token in new_tokens] | |||||
| tokens_to_add = [] | |||||
| for token in new_tokens: | |||||
| assert isinstance(token, str) | |||||
| if token in self.encoder: | |||||
| # the token is already in the vocabulary, do not assign it a new id | |||||
| continue | |||||
| tokens_to_add.append(token) | |||||
| logger.info("Adding %s to the vocabulary ! ", token) | |||||
| added_tok_encoder = dict((tok, self.vocab_size + i) for i, tok in enumerate(tokens_to_add)) | |||||
| added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} | |||||
| self.encoder.update(added_tok_encoder) | |||||
| self.decoder.update(added_tok_decoder) | |||||
| # keep vocab_size in sync so that later additions get fresh ids | |||||
| self.vocab_size += len(tokens_to_add) | |||||
| return len(tokens_to_add) | |||||
| def num_special_tokens_to_add(self, pair: bool = False): | |||||
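| """Number of special tokens added around a single sequence (2) or a pair (3).""" | |||||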
| token_ids_0 = [] | |||||
| token_ids_1 = [] | |||||
| return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) | |||||
| def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None): | |||||
| """ | |||||
| Build model inputs from a sequence or a pair of sequence by concatenating and adding special tokens. | |||||
| A GPT2 sequence has the following format: | |||||
| - single sequence: ``<bos> X <eos>`` | |||||
| - pair of sequences: ``<bos> A <eos> B <eos>`` | |||||
| Args: | |||||
| token_ids_0 (List[int]): List of IDs to which the special tokens will be added | |||||
| token_ids_1 (List[int], `optional`, defaults to `None`): Optional second list of IDs for sequence pairs. | |||||
| """ | |||||
| bos = [self.bos_token_id] | |||||
| eos = [self.eos_token_id] | |||||
| if token_ids_1 is None: | |||||
| return bos + token_ids_0 + eos | |||||
| return bos + token_ids_0 + eos + token_ids_1 + eos | |||||
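| # Example (illustrative, with hypothetical ids 11 and 22): | |||||
| #   build_inputs_with_special_tokens([11, 22])   -> [50256, 11, 22, 50256] | |||||
| #   build_inputs_with_special_tokens([11], [22]) -> [50256, 11, 50256, 22, 50256] | |||||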
| def truncate_sequences(self, ids, num_tokens_to_remove, truncation_strategy="ONLY_FIRST", direction="RIGHT"): | |||||
| """ | |||||
| truncate sequences | |||||
| Args: | |||||
| ids: Any | |||||
| num_tokens_to_remove: | |||||
| truncation_strategy: str | |||||
| direction: str | |||||
| Returns: | |||||
| (ids, overflowing_tokens): (Any, list) | |||||
| """ | |||||
| if num_tokens_to_remove <= 0: | |||||
| return ids, [] | |||||
| overflowing_tokens = [] | |||||
| if truncation_strategy == "ONLY_FIRST": | |||||
| if len(ids) > num_tokens_to_remove: | |||||
| if direction == "RIGHT": | |||||
| overflowing_tokens = ids[-num_tokens_to_remove:] | |||||
| ids = ids[:-num_tokens_to_remove] | |||||
| elif direction == "LEFT": | |||||
| overflowing_tokens = ids[:num_tokens_to_remove] | |||||
| ids = ids[num_tokens_to_remove:] | |||||
| else: | |||||
| logger.error("The first sequence is shorter than the number of tokens to remove. ") | |||||
| else: | |||||
| logger.error("Please select a correct truncation strategy, for instance 'ONLY_FIRST'") | |||||
| return (ids, overflowing_tokens) | |||||
| def _pad(self, encoded_inputs, max_length=None, padding_strategy=None, | |||||
| return_attention_mask: Optional[bool] = None): | |||||
| """ | |||||
| _pad | |||||
| Args: | |||||
| encoded_inputs: | |||||
| max_length: Any | |||||
| padding_strategy: Any | |||||
| return_attention_mask: Optional[bool] | |||||
| Returns: | |||||
| encoded_inputs: | |||||
| """ | |||||
| needs_to_be_padded = (len(encoded_inputs["input_ids"]) != max_length) | |||||
| if needs_to_be_padded: | |||||
| if padding_strategy == "MAX_LENGTH": | |||||
| difference = max_length - len(encoded_inputs["input_ids"]) | |||||
| if return_attention_mask: | |||||
| encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference | |||||
| encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference | |||||
| else: | |||||
| raise ValueError("Invalid padding strategy") | |||||
| else: | |||||
| if return_attention_mask: | |||||
| encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) | |||||
| return encoded_inputs | |||||
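| # Example (illustrative): _pad({"input_ids": [5, 6, 7]}, max_length=5, | |||||
| # padding_strategy="MAX_LENGTH", return_attention_mask=True) returns | |||||
| # {"input_ids": [5, 6, 7, 50256, 50256], "attention_mask": [1, 1, 1, 0, 0]}, | |||||
| # i.e. right-padding with pad_token_id plus a matching 0/1 mask. | |||||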
| def pad(self, encoded_inputs, max_length: Optional[int] = None, padding_strategy="MAX_LENGTH", | |||||
| return_attention_mask=True): | |||||
| """ | |||||
| pad | |||||
| Args: | |||||
| encoded_inputs: | |||||
| max_length: Optional[int] | |||||
| padding_strategy: str | |||||
| return_attention_mask: bool | |||||
| Returns: | |||||
| batch_outputs: Dict[Any, list] | |||||
| """ | |||||
| # no batch encoded_inputs["input_ids"]--->[98, 67, 32388, 318, 1912, 287, 170, 8496, 318, 905, 2667, 32] | |||||
| if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)): | |||||
| encoded_inputs = self._pad( | |||||
| encoded_inputs, | |||||
| max_length=max_length, | |||||
| padding_strategy=padding_strategy, | |||||
| return_attention_mask=return_attention_mask | |||||
| ) | |||||
| return encoded_inputs | |||||
| # encoded_inputs with batch_size | |||||
| batch_size = len(encoded_inputs["input_ids"]) | |||||
| assert all( | |||||
| len(v) == batch_size for v in encoded_inputs.values() | |||||
| ), "Some items in the output dictionary have a different batch size than others." | |||||
| if padding_strategy == "LONGEST": | |||||
| max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"]) | |||||
| padding_strategy = "MAX_LENGTH" | |||||
| batch_outputs = {} | |||||
| for i in range(batch_size): | |||||
| inputs = dict((k, v[i]) for k, v in encoded_inputs.items()) | |||||
| outputs = self._pad( | |||||
| encoded_inputs=inputs, | |||||
| max_length=max_length, | |||||
| padding_strategy=padding_strategy, | |||||
| return_attention_mask=return_attention_mask | |||||
| ) | |||||
| for key, value in outputs.items(): | |||||
| if key not in batch_outputs: | |||||
| batch_outputs[key] = [] | |||||
| batch_outputs[key].append(value) | |||||
| return batch_outputs | |||||
| def prepare_for_model(self, | |||||
| ids, | |||||
| pair_ids=None, | |||||
| add_special_tokens=True, | |||||
| max_length=None, | |||||
| padding=None, | |||||
| truncate_direction="RIGHT", | |||||
| return_overflowing_tokens=False, | |||||
| return_attention_mask=True): | |||||
| """ | |||||
| prepare for model | |||||
| Args: | |||||
| ids: | |||||
| pair_ids: | |||||
| add_special_tokens: bool | |||||
| max_length: Any | |||||
| padding: Any | |||||
| truncate_direction: str | |||||
| return_overflowing_tokens: bool | |||||
| return_attention_mask: bool | |||||
| Returns: | |||||
| encoded_inputs:Dict | |||||
| """ | |||||
| pair = bool(pair_ids is not None) | |||||
| len_ids = len(ids) | |||||
| len_pair_ids = len(pair_ids) if pair else 0 | |||||
| encoded_inputs = {} | |||||
| # Compute the total size of the returned encodings | |||||
| total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) | |||||
| # Truncation: Handle max sequence length | |||||
| if max_length and total_len > max_length: | |||||
| ids, overflowing_tokens = self.truncate_sequences(ids=ids, | |||||
| num_tokens_to_remove=total_len - max_length, | |||||
| truncation_strategy="ONLY_FIRST", | |||||
| direction=truncate_direction) | |||||
| if return_overflowing_tokens: | |||||
| encoded_inputs["overflowing_tokens"] = overflowing_tokens | |||||
| encoded_inputs["num_truncated_tokens"] = total_len - max_length | |||||
| if add_special_tokens: | |||||
| sequence = self.build_inputs_with_special_tokens(ids, pair_ids) | |||||
| else: | |||||
| sequence = ids + pair_ids if pair else ids | |||||
| # build output dictionary | |||||
| encoded_inputs["input_ids"] = sequence | |||||
| # check lengths | |||||
| if max_length is not None and len(encoded_inputs["input_ids"]) > max_length: | |||||
| logger.warning( | |||||
| "Token indices sequence length is longer than the specified maximum sequence length " | |||||
| "for this model (%d > %d). Running this sequence through the model will result in " | |||||
| "indexing errors", len(encoded_inputs["input_ids"]), max_length | |||||
| ) | |||||
| # padding | |||||
| if padding or return_attention_mask: | |||||
| encoded_inputs = self.pad(encoded_inputs=encoded_inputs, | |||||
| max_length=max_length, | |||||
| padding_strategy="MAX_LENGTH", | |||||
| return_attention_mask=return_attention_mask) | |||||
| return encoded_inputs | |||||
| class CNN_DailyMail_tokenizer(GPT2Tokenizer): | |||||
| """ | |||||
| CNN DailyMail tokenizer | |||||
| """ | |||||
| def prepare_for_model(self, | |||||
| ids, | |||||
| pair_ids, | |||||
| max_length=1024, | |||||
| max_summary_length=150, | |||||
| add_special_tokens=True, | |||||
| padding=None, | |||||
| return_overflowing_tokens=False, | |||||
| return_attention_mask=True): | |||||
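| """ | |||||
| Jointly truncate the article (`ids`) and its highlights (`pair_ids`) so that | |||||
| len(article) + len(summary) + 3 special tokens fit in `max_length`, capping | |||||
| the summary at `max_summary_length` tokens; then add special tokens and pad. | |||||
| """ | |||||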
| len_ids = len(ids) | |||||
| len_pair_ids = len(pair_ids) | |||||
| encoded_inputs = {} | |||||
| # Compute the total size of the returned encodings | |||||
| total_len = len_ids + len_pair_ids | |||||
| ids_overflowing_tokens = [] | |||||
| pair_overflowing_tokens = [] | |||||
| # Truncation: Handle max sequence length | |||||
| if total_len > max_length-3: | |||||
| if len_pair_ids > max_summary_length: | |||||
| num_tokens_to_remove = len_pair_ids - max_summary_length | |||||
| pair_ids, pair_overflowing_tokens = self.truncate_sequences(ids=pair_ids, | |||||
| num_tokens_to_remove=num_tokens_to_remove, | |||||
| truncation_strategy="ONLY_FIRST", | |||||
| direction="RIGHT") | |||||
| if len_ids+max_summary_length > max_length-3: | |||||
| num_tokens_to_remove = (len_ids + max_summary_length) - (max_length - 3) | |||||
| ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids, | |||||
| num_tokens_to_remove=num_tokens_to_remove, | |||||
| truncation_strategy="ONLY_FIRST", | |||||
| direction="RIGHT") | |||||
| else: | |||||
| ids, ids_overflowing_tokens = self.truncate_sequences(ids=ids, | |||||
| num_tokens_to_remove=total_len - (max_length-3), | |||||
| truncation_strategy="ONLY_FIRST", | |||||
| direction="RIGHT") | |||||
| if return_overflowing_tokens: | |||||
| encoded_inputs["article_overflowing_tokens"] = ids_overflowing_tokens | |||||
| encoded_inputs["highlights_overflowing_tokens"] = pair_overflowing_tokens | |||||
| encoded_inputs["num_truncated_tokens"] = total_len - (max_length-3) | |||||
| sequence = self.build_inputs_with_special_tokens(ids, pair_ids) | |||||
| encoded_inputs["input_ids"] = sequence | |||||
| # check lengths | |||||
| if max_length is not None and len(encoded_inputs["input_ids"]) > max_length: | |||||
| logger.warning( | |||||
| "Token indices sequence length is longer than the specified maximum sequence length " | |||||
| "for this model (%d > %d). Running this sequence through the model will result " | |||||
| "in indexing errors", len(encoded_inputs["input_ids"]), max_length | |||||
| ) | |||||
| # padding | |||||
| if padding or return_attention_mask: | |||||
| encoded_inputs = self.pad(encoded_inputs=encoded_inputs, | |||||
| max_length=max_length, | |||||
| padding_strategy="MAX_LENGTH", | |||||
| return_attention_mask=return_attention_mask) | |||||
| return encoded_inputs | |||||
| def Tokenizer(vocab_file="./pretrain-data/gpt2-vocab.json", | |||||
| merge_file="./pretrain-data/gpt2-merges.txt", | |||||
| mode="normal"): | |||||
| """ use the GPT2Tokenizer""" | |||||
| print(" | Tokenizer mode: {}".format(mode)) | |||||
| if mode == "normal": | |||||
| tokenizer = GPT2Tokenizer(vocab_file, merge_file, add_prefix_space=False) | |||||
| elif mode == "cnn_dailymail": | |||||
| tokenizer = CNN_DailyMail_tokenizer(vocab_file, merge_file, add_prefix_space=False) | |||||
| else: | |||||
| raise ValueError("No Such Mode for {} in src.utils.tokenization.Tokenizer()".format(mode)) | |||||
| return tokenizer | |||||
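| # Usage sketch (illustrative; the default paths above must point to real files): | |||||
| #   tokenizer = Tokenizer(vocab_file="./pretrain-data/gpt2-vocab.json", | |||||
| #                         merge_file="./pretrain-data/gpt2-merges.txt") | |||||
| #   ids = tokenizer.encode("GPT-2 is a language model.") | |||||
| #   enc = tokenizer.prepare_for_model(ids, max_length=1024) | |||||
| #   # enc["input_ids"] is padded to 1024 ids; enc["attention_mask"] marks real tokens | |||||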
| @@ -0,0 +1,55 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| init weight | |||||
| """ | |||||
| import math | |||||
| import numpy as np | |||||
| from mindspore.common.tensor import Tensor | |||||
| def _average_units(shape): | |||||
| """Average fan (mean of fan_in and fan_out) of a 1-D or 2-D parameter shape.""" | |||||
| if not shape: | |||||
| return 1 | |||||
| if len(shape) == 1: | |||||
| return float(shape[0]) | |||||
| if len(shape) == 2: | |||||
| return float(shape[0] + shape[1]) / 2. | |||||
| raise RuntimeError("unsupported shape: only 1-D and 2-D shapes are supported.") | |||||
| def weight_variable(shape): | |||||
| """Glorot-style uniform initializer: Uniform(-sqrt(3/avg_units), sqrt(3/avg_units)).""" | |||||
| scale_shape = shape | |||||
| avg_units = _average_units(scale_shape) | |||||
| scale = 1.0 / max(1., avg_units) | |||||
| limit = math.sqrt(3.0 * scale) | |||||
| values = np.random.uniform(-limit, limit, shape).astype(np.float32) | |||||
| return Tensor(values) | |||||
| def one_weight(shape): | |||||
| """All-ones float32 tensor, e.g. for LayerNorm gamma.""" | |||||
| ones = np.ones(shape).astype(np.float32) | |||||
| return Tensor(ones) | |||||
| def zero_weight(shape): | |||||
| """All-zeros float32 tensor, e.g. for biases and LayerNorm beta.""" | |||||
| zeros = np.zeros(shape).astype(np.float32) | |||||
| return Tensor(zeros) | |||||
| def normal_weight(shape, num_units): | |||||
| """Normal(0, num_units ** -0.5) float32 tensor, e.g. for embedding tables.""" | |||||
| norm = np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32) | |||||
| return Tensor(norm) | |||||
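| # Usage sketch (illustrative) for a hidden size of 768: | |||||
| #   w = weight_variable((768, 768))         # Uniform(-sqrt(3/768), sqrt(3/768)) | |||||
| #   gamma = one_weight((768,))              # LayerNorm gamma | |||||
| #   beta = zero_weight((768,))              # LayerNorm beta / biases | |||||
| #   emb = normal_weight((50257, 768), 768)  # Normal(0, 768 ** -0.5) embedding table | |||||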
| @@ -0,0 +1,86 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """dataset preprocess""" | |||||
| import argparse | |||||
| from src.utils.data_preprocess import lambada_dataset_preprocess | |||||
| from src.utils.data_preprocess import cbt_dataset_preprocess | |||||
| from src.utils.data_preprocess import wikitext_dataset_preprocess | |||||
| from src.utils.data_preprocess import ptb_dataset_preprocess | |||||
| from src.utils.data_preprocess import onebw_dataset_preprocess | |||||
| from src.utils.data_preprocess import coqa_dataset_preprocess | |||||
| from src.utils.data_preprocess import wmt14_en_fr_preprocess | |||||
| def main(): | |||||
| parser = argparse.ArgumentParser(description="All Task dataset preprocessing") | |||||
| parser.add_argument("--task", type=str, default="translation", | |||||
| help="The GPT-2 downstream task, including [LanguageModeling, CBT, Translation, Lambada" | |||||
| "Summarization, ReadingComprehension]") | |||||
| parser.add_argument("--input_file", type=str, default="", | |||||
| help="The raw dataset path. ") | |||||
| parser.add_argument("--dataset", type=str, default="onebw", | |||||
| help="The name of dataset which should be processed, only for LanguageModeling task.") | |||||
| parser.add_argument("--output_file", type=str, default="", | |||||
| help="The output dataset path after preprocessing.") | |||||
| parser.add_argument("--condition", type=str, default="test", | |||||
| help="Process train or test dataset, including [train, test], only for 1BW and " | |||||
| "CNN & DailyMail dataset.") | |||||
| args_opt = parser.parse_args() | |||||
| task = args_opt.task | |||||
| condition = args_opt.condition | |||||
| dataset = args_opt.dataset | |||||
| input_file = args_opt.input_file | |||||
| output_file = args_opt.output_file | |||||
| if task.lower() == "languagemodeling": | |||||
| print("Start processing Language Modeling dataset ...") | |||||
| if dataset.lower() == "wikitext2" or dataset.lower() == "wikitext103": | |||||
| wikitext_dataset_preprocess(input_file=input_file, output_file=output_file) | |||||
| elif dataset.lower() == "ptb": | |||||
| ptb_dataset_preprocess(input_file=input_file, output_file=output_file) | |||||
| elif dataset.lower() == "onebw": | |||||
| onebw_dataset_preprocess(condition, input_file=input_file, output_file=output_file) | |||||
| else: | |||||
| raise ValueError("Only support wikitext2, wikitext103, ptb, onebw dataset") | |||||
| elif task.lower() == "lambada": | |||||
| print("Start processing Lambada dataset ...") | |||||
| lambada_dataset_preprocess(input_file=input_file, output_file=output_file) | |||||
| elif task.lower() == "cbt": | |||||
| print("Start processing CBT dataset ...") | |||||
| cbt_dataset_preprocess(input_file=input_file, output_file=output_file) | |||||
| elif task.lower() == "readingcomprehension": | |||||
| print("Start processing ReadingComprehension dataset ...") | |||||
| coqa_dataset_preprocess(input_file=input_file, output_file=output_file) | |||||
| elif task.lower() == "summarization": | |||||
| print("Start processing Summarization dataset ...") | |||||
| elif task.lower() == "translation": | |||||
| print("Start processing Translation dataset ...") | |||||
| wmt14_en_fr_preprocess(input_file=input_file, output_file=output_file) | |||||
| else: | |||||
| raise ValueError("Only support Language Modeling, CBT, Translation, Lambada, " | |||||
| "Summarization, Reading Comprehension task.") | |||||
| if __name__ == "__main__": | |||||
| main() | |||||
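| # Example invocations (a sketch; the script and data file names depend on your layout): | |||||
| #   python preprocess.py --task=LanguageModeling --dataset=ptb \ | |||||
| #       --input_file=./data/ptb.test.txt --output_file=./data/ptb.test.processed.txt | |||||
| #   python preprocess.py --task=Translation \ | |||||
| #       --input_file=./data/wmt14_en_fr.txt --output_file=./data/wmt14_en_fr.processed.txt | |||||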
| @@ -0,0 +1,107 @@ | |||||
| # Copyright 2017 Google Inc. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Python implementation of BLEU and smooth-BLEU. | |||||
| This module provides a Python implementation of BLEU and smooth-BLEU. | |||||
| Smooth BLEU is computed following the method outlined in the paper: | |||||
| Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic | |||||
| evaluation metrics for machine translation. COLING 2004. | |||||
| """ | |||||
| import collections | |||||
| import math | |||||
| def _get_ngrams(segment, max_order): | |||||
| """ | |||||
| Extracts all n-grams up to a given maximum order from an input segment. | |||||
| Args: | |||||
| segment: text segment from which n-grams will be extracted. | |||||
| max_order: maximum length in tokens of the n-grams returned by this | |||||
| method. | |||||
| Returns: | |||||
| The Counter containing all n-grams up to max_order in segment | |||||
| with a count of how many times each n-gram occurred. | |||||
| """ | |||||
| ngram_counts = collections.Counter() | |||||
| for order in range(1, max_order + 1): | |||||
| for i in range(0, len(segment) - order + 1): | |||||
| ngram = tuple(segment[i:i + order]) | |||||
| ngram_counts[ngram] += 1 | |||||
| return ngram_counts | |||||
| def compute_bleu(reference_corpus, translation_corpus, max_order=4, | |||||
| smooth=False): | |||||
| """Computes BLEU score of translated segments against one or more references. | |||||
| Args: | |||||
| reference_corpus: list of lists of references for each translation. Each | |||||
| reference should be tokenized into a list of tokens. | |||||
| translation_corpus: list of translations to score. Each translation | |||||
| should be tokenized into a list of tokens. | |||||
| max_order: Maximum n-gram order to use when computing BLEU score. | |||||
| smooth: Whether or not to apply Lin et al. 2004 smoothing. | |||||
| Returns: | |||||
| 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram | |||||
| precisions and brevity penalty. | |||||
| """ | |||||
| matches_by_order = [0] * max_order | |||||
| possible_matches_by_order = [0] * max_order | |||||
| reference_length = 0 | |||||
| translation_length = 0 | |||||
| for (references, translation) in zip(reference_corpus, translation_corpus): | |||||
| reference_length += min(len(r) for r in references) | |||||
| translation_length += len(translation) | |||||
| merged_ref_ngram_counts = collections.Counter() | |||||
| for reference in references: | |||||
| merged_ref_ngram_counts |= _get_ngrams(reference, max_order) | |||||
| translation_ngram_counts = _get_ngrams(translation, max_order) | |||||
| overlap = translation_ngram_counts & merged_ref_ngram_counts | |||||
| for ngram in overlap: | |||||
| matches_by_order[len(ngram) - 1] += overlap[ngram] | |||||
| for order in range(1, max_order + 1): | |||||
| possible_matches = len(translation) - order + 1 | |||||
| if possible_matches > 0: | |||||
| possible_matches_by_order[order - 1] += possible_matches | |||||
| precisions = [0] * max_order | |||||
| for i in range(0, max_order): | |||||
| if smooth: | |||||
| precisions[i] = ((matches_by_order[i] + 1.) / | |||||
| (possible_matches_by_order[i] + 1.)) | |||||
| else: | |||||
| if possible_matches_by_order[i] > 0: | |||||
| precisions[i] = (float(matches_by_order[i]) / | |||||
| possible_matches_by_order[i]) | |||||
| else: | |||||
| precisions[i] = 0.0 | |||||
| if min(precisions) > 0: | |||||
| p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) | |||||
| geo_mean = math.exp(p_log_sum) | |||||
| else: | |||||
| geo_mean = 0 | |||||
| ratio = float(translation_length) / reference_length | |||||
| if ratio > 1.0: | |||||
| bp = 1. | |||||
| else: | |||||
| bp = math.exp(1 - 1. / ratio) | |||||
| bleu = geo_mean * bp | |||||
| return (bleu, precisions, bp, ratio, translation_length, reference_length) | |||||
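| # Example (illustrative): a perfect 5-token match scores BLEU = 1.0, i.e. | |||||
| #   compute_bleu([[["the", "cat", "sat", "on", "mat"]]], | |||||
| #                [["the", "cat", "sat", "on", "mat"]]) | |||||
| #   -> (1.0, [1.0, 1.0, 1.0, 1.0], 1.0, 1.0, 5, 5) | |||||
| # Segments shorter than max_order leave the highest-order precision at 0, which | |||||
| # drives the unsmoothed geometric mean (and hence BLEU) to 0; pass smooth=True | |||||
| # to avoid that. | |||||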