| @@ -0,0 +1,140 @@ | |||
> **Note**: This project is still under development.
| # Contents | |||
| - [Contents](#contents) | |||
| - [PanGu-Alpha Description](#pangu-description) | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Quick Start](#quick-start) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| - [Requirements](#requirements) | |||
| # [PanGu-Alpha Description](#pangu-description) | |||
We release the code to explore the frontier of training large models with billions or even trillions of parameters.
Using MindSpore's parallel features, we adopt efficient model-parallel and data-parallel techniques, such as operator-level parallelism,
to minimize communication cost and maximize computation efficiency.
The code scales to thousands of NPUs and trillions of parameters with only minor modifications.
Meanwhile, we apply our parallel training to a language model, named PanGu-Alpha, to demonstrate that large models can be trained easily
with our parallel setting. We summarize the training techniques as follows:
| 1. Op-level Model Parallelism | |||
| 2. Pipeline Model Parallelism | |||
| 3. Optimizer Model Parallelism | |||
| The above features can be found [here](https://www.mindspore.cn/doc/programming_guide/en/r1.2/auto_parallel.html). | |||
More features are under development.
| The technical report and checkpoint file can be found [here](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-AIpha). | |||
| # [Model Architecture](#contents) | |||
|  | |||
The architecture of PanGu-α is based on the Transformer, which has been extensively used as the backbone of a variety of
pretrained language models such as BERT and GPT. Different from them, we develop an additional query layer on top of the
Transformer layers to predict the next token. The diagram of the model is shown in Figure 1.
| # [Dataset](#dataset) | |||
| - Open Source Dataset. | |||
The above dataset is preprocessed into examples of 1024 tokens each. The default column key in dataset.py is `input_ids`.
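As a rough illustration, the preprocessing can be thought of as chunking a flat token-id stream into fixed-length examples. The sketch below is a minimal stand-in, not the actual `src/preprocess.py` logic; `chunk_tokens` is a hypothetical helper name.

```python
import numpy as np

SEQ_LENGTH = 1024  # tokens per example, matching the preprocessing above

def chunk_tokens(token_ids, seq_length=SEQ_LENGTH):
    """Split a flat token-id stream into fixed-length examples, dropping the tail remainder."""
    n_examples = len(token_ids) // seq_length
    return np.array(token_ids[:n_examples * seq_length]).reshape(n_examples, seq_length)

# e.g. a stream of 2500 token ids yields two 1024-token examples
examples = chunk_tokens(list(range(2500)))
```

Each row of `examples` would then be written as one MindRecord sample under the `input_ids` column.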
| # [Environment Requirements](#contents) | |||
| - Hardware(Ascend) | |||
| - Prepare hardware environment with Ascend processor. | |||
| - Framework | |||
| - [MindSpore](https://gitee.com/mindspore/mindspore) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| # [Quick Start](#contents) | |||
| ## Generate Dataset | |||
Suppose the text data is under ./data and each text file ends with '.txt'. Run the following command to generate MindRecord files with seq_length=1024; the feature column is `input_ids`. The output files are written to
`output`.
| ```bash | |||
| python src/preprocess.py --input_glob data/*.txt | |||
| ``` | |||
| ## Run Training | |||
| After installing MindSpore via the official website, you can start training as follows: | |||
| ```bash | |||
| # run distributed training example | |||
bash scripts/run_distribute_train.sh /path/dataset /path/hccl.json 8
| ``` | |||
We recommend running the code on 32 Ascend cards.
For distributed training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
| ## Prediction | |||
| ### Download Checkpoint | |||
| Please refer to the [website](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha) to download the following parts: | |||
| - tokenizer: vocab.txt and vocab.model | |||
- checkpoint file: \*.part\[0-4\] and *.npy under the same parameter size
- strategy file: a file describing how the parameters are sliced across different devices.
| ### Run Prediction | |||
| ```bash | |||
FILE_PATH=/home/your_path
| bash scripts/run_distribute_predict.sh 8 /home/config/rank_table_8p.json ${FILE_PATH}/strategy_load_ckpt/strategy.ckpt \ | |||
| ${FILE_PATH}/tokenizer/ ${FILE_PATH}/checkpoint_file filitered 2.6B | |||
| ``` | |||
| # [Script Description](#contents) | |||
| ## [Script and Sample Code](#contents) | |||
| ```bash | |||
| . | |||
| ├── docs | |||
| │ └── model.png | |||
| ├── predict.py | |||
| ├── README.md | |||
| ├── scripts | |||
| │ ├── run_distribute_predict.sh | |||
| │ └── run_distribute_train.sh | |||
| ├── src | |||
| │ ├── dataset.py | |||
| │ ├── generate.py | |||
| │ ├── pangu_alpha_config.py | |||
| │ ├── pangu_alpha.py | |||
| │ ├── pangu_alpha_wrapcell.py | |||
| │ ├── preprocess.py | |||
| │ ├── tokenization_jieba.py | |||
| │ └── utils.py | |||
| └── train.py | |||
| ``` | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| # [Requirements](#contents) | |||
| - mindspore 1.2 | |||
| - jieba 0.42.1 | |||
| - sentencepiece 0.1.94 | |||
| @@ -0,0 +1,133 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| PanGu predict run | |||
| """ | |||
| import os | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.model import Model | |||
| import mindspore.communication.management as D | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.serialization import load_distributed_checkpoint | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.parallel._cost_model_context import _set_multi_subgraphs | |||
| from mindspore.parallel import set_algo_parameters | |||
| from src.pangu_alpha import PanguAlpha, EvalNet | |||
| from src.pangu_alpha_config import PANGUALPHAConfig, set_parse | |||
| from src.utils import get_args | |||
| def run_predict(args_opt): | |||
| r""" | |||
| The main function for running prediction | |||
| """ | |||
device_id = int(os.getenv("DEVICE_ID"))
rank_id_str = os.getenv('RANK_ID', '0')
# RANK_ID may be prefixed like "job-x-N"; keep only the trailing number
rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
local_rank = rank_id
print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
| context.set_context(save_graphs=False, | |||
| mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| device_id=device_id) | |||
| context.set_context(variable_memory_max_size="30GB") | |||
| if args_opt.distribute == "true": | |||
| D.init() | |||
| device_num = D.get_group_size() | |||
| rank = D.get_rank() | |||
| print("device_id is {}, rank_id is {}, device_num is {}".format( | |||
| device_id, rank, device_num)) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context( | |||
| parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, | |||
| gradients_mean=False, | |||
| device_num=device_num, | |||
| full_batch=True, | |||
| loss_repeated_mean=True, | |||
| enable_parallel_optimizer=False, | |||
| strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path, | |||
| pipeline_stages=args_opt.stage_num) | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| _set_multi_subgraphs() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| model_parallel_num = args_opt.tensor_model_parallel_num | |||
| data_parallel_num = int(device_num / model_parallel_num) | |||
| per_batch_size = args_opt.per_batch_size | |||
| batch_size = per_batch_size * data_parallel_num | |||
| config = PANGUALPHAConfig( | |||
| data_parallel_num=data_parallel_num, | |||
| model_parallel_num=model_parallel_num, | |||
| batch_size=batch_size, | |||
| seq_length=args_opt.seq_length, | |||
| vocab_size=args_opt.vocab_size, | |||
| embedding_size=args_opt.embedding_size, | |||
| num_layers=args_opt.num_layers, | |||
| num_heads=args_opt.num_heads, | |||
| expand_ratio=4, | |||
| post_layernorm_residual=False, | |||
| dropout_rate=0.0, | |||
| compute_dtype=mstype.float16, | |||
| use_past=False, | |||
| self_layernorm=True, | |||
| stage_num=args_opt.stage_num, | |||
| micro_size=args_opt.micro_size, | |||
| eod_reset=False, | |||
| word_emb_dp=True, | |||
| load_ckpt_path=args_opt.load_ckpt_path) | |||
| print("===config is: ", config, flush=True) | |||
| print("=====args_opt is: ", args_opt, flush=True) | |||
| ckpt_name = args_opt.load_ckpt_name | |||
| pangu_alpha = PanguAlpha(config) | |||
| eval_net = EvalNet(pangu_alpha) | |||
| eval_net.set_train(False) | |||
| model_predict = Model(eval_net) | |||
| inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) | |||
| predict_layout = model_predict.infer_predict_layout(inputs_np) | |||
| print("======start load_distributed checkpoint", flush=True) | |||
# For 2.6B and 13B models, the number of ckpt files is 512.
# NOTE: this hard-coded name overrides args_opt.load_ckpt_name set above
ckpt_name = 'filerted'
ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt")
for ckpt_rank in range(0, 512)]
| print(f"Loading from path {ckpt_file_list[0]}", flush=True) | |||
| load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout) | |||
| print("================load param ok=================", flush=True) | |||
| from src.tokenization_jieba import JIEBATokenizer | |||
| from src.generate import generate | |||
| tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'), | |||
| os.path.join(args_opt.tokenizer_path, 'vocab10.model')) | |||
| sample = "今天是一个好天气" | |||
| tokenized_token = tokenizer.tokenize(sample) | |||
| start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token) | |||
| input_ids = np.array(start_sentence).reshape(1, -1) | |||
| output_ids = generate(model_predict, input_ids, config.seq_length, 9) | |||
| output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist()) | |||
| print('Output is:', output_samples, flush=True) | |||
| if __name__ == "__main__": | |||
| opt = get_args() | |||
| set_parse(opt) | |||
| run_predict(opt) | |||
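The `RANK_ID` environment variable parsed at the top of `run_predict` may carry a scheduler-assigned prefix; the snippet below isolates that parsing step (a minimal sketch; the `"job-2-7"` value is a hypothetical example, and `parse_rank_id` is an illustrative helper name).

```python
def parse_rank_id(rank_id_str):
    """Keep only the digits after the last '-'; plain numeric values pass through unchanged."""
    return int(rank_id_str[rank_id_str.rfind('-') + 1:])

prefixed = parse_rank_id('job-2-7')  # prefixed form from a scheduler
plain = parse_rank_id('0')           # rfind returns -1, so the whole string is used
```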
| @@ -0,0 +1,22 @@ | |||
| #!/bin/bash | |||
| execute_path=$(pwd) | |||
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| export RANK_SIZE=$1 | |||
| export RANK_TABLE_FILE=$2 | |||
| export STRATEGY=$3 | |||
| export TOKENIZER=$4 | |||
| export CKPT_PATH=$5 | |||
| export CKPT_NAME=$6 | |||
| export MODE=$7 | |||
| for((i=0;i<$RANK_SIZE;i++)); | |||
| do | |||
| rm -rf ${execute_path}/device_$i/ | |||
| mkdir ${execute_path}/device_$i/ | |||
| cd ${execute_path}/device_$i/ || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../predict.py --strategy_load_ckpt_path=$STRATEGY --tokenizer_path=$TOKENIZER --load_ckpt_path=$CKPT_PATH \ | |||
| --load_ckpt_name=$CKPT_NAME --mode=$MODE --run_type=predict >train_deep$i.log 2>&1 & | |||
| done | |||
| @@ -0,0 +1,38 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| echo "==============================================================================================================" | |||
| echo "Please run the script as: " | |||
echo "bash run_distribute_train.sh DATA_DIR RANK_TABLE_FILE DEVICE_NUM"
echo "for example: bash run_distribute_train.sh /path/dataset /path/hccl.json 8"
| echo "It is better to use absolute path." | |||
| echo "==============================================================================================================" | |||
| ROOT_PATH=`pwd` | |||
| DATA_DIR=$1 | |||
| export RANK_TABLE_FILE=$2 | |||
| RANK_SIZE=$3 | |||
| for((i=0;i<${RANK_SIZE};i++)); | |||
| do | |||
| rm ${ROOT_PATH}/device$i/ -rf | |||
| mkdir ${ROOT_PATH}/device$i | |||
| cd ${ROOT_PATH}/device$i || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python ${ROOT_PATH}/train.py --distribute=true --device_num=$RANK_SIZE --data_url=$DATA_DIR --run_type=train >log$i.log 2>&1 & | |||
| done | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Create dataset for training and evaluating | |||
| """ | |||
| import os | |||
| import numpy as np | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| import mindspore.common.dtype as mstype | |||
| def get_input_data(input_ids, eod_id, rank, dis): | |||
| """ | |||
| Generate position_id and attention_mask according to input_ids considering eod reset | |||
| Inputs: | |||
| input_ids: the input token ids | |||
eod_id: the id for <EOD>
rank: the current rank id
dis: the number of samples assigned to each rank
returns:
input_ids: the input token ids
position_id: the position ids considering eod reset
| attention_mask: the attention mask considering eod reset | |||
| """ | |||
| rank = int(rank) | |||
| input_ids = input_ids[rank*dis: (rank+1)*dis] | |||
| seq_length = input_ids.shape[1] - 1 | |||
| batch_input_ids = input_ids | |||
| batch_position_ids = np.ones((dis, seq_length)) | |||
| batch_attention_mask = np.ones((dis, seq_length, seq_length)) | |||
for bs_i in range(len(input_ids)):
| local_ids = input_ids[bs_i] | |||
| batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length))) | |||
| batch_position_ids[bs_i] = np.arange(seq_length) | |||
| eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32) | |||
| prev_index = 0 | |||
| for i in range(eod_index.size): | |||
| index = eod_index[i] | |||
| batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0 | |||
| batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index) | |||
| prev_index = index + 1 | |||
| return batch_input_ids, batch_position_ids, batch_attention_mask | |||
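The eod-reset logic above can be exercised in isolation with plain NumPy. The sketch below re-implements the per-sample loop for a single short sequence (a minimal sketch using a toy sequence and `eod_id=9`, the same default as `create_dataset`):

```python
import numpy as np

def eod_reset_single(ids, eod_id):
    """Rebuild position ids and the causal mask so attention cannot cross an <EOD> token."""
    seq_length = len(ids) - 1
    position_id = np.arange(seq_length)
    attention_mask = np.tril(np.ones((seq_length, seq_length)))
    prev_index = 0
    for index in np.where(ids[:-1] == eod_id)[0]:
        attention_mask[index + 1:, :index + 1] = 0          # block attention across the document boundary
        position_id[index + 1:] -= index + 1 - prev_index   # restart positions after <EOD>
        prev_index = index + 1
    return position_id, attention_mask

ids = np.array([3, 4, 9, 5, 6, 7, 0])  # <EOD> (id 9) at index 2
pos, mask = eod_reset_single(ids, eod_id=9)
```

After the `<EOD>` token, positions restart from 0 and later tokens can no longer attend to the earlier document.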
| def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_start_index=0, | |||
| eod_reset=False, eod_id=9, column_name='input_ids', epoch=1): | |||
| """ | |||
| Create dataset | |||
| Inputs: | |||
| batch_size: batch size | |||
| data_path: path of your MindRecord files | |||
| device_num: total device number | |||
| rank: current rank id | |||
| drop: whether drop remainder | |||
| eod_reset: whether enable position reset and attention mask reset | |||
| eod_id: the id for <EOD> | |||
| column_name: the column name of the mindrecord file. Default is input_ids | |||
| epoch: The repeat times of the dataset | |||
| Returns: | |||
| dataset_restore: the dataset for training or evaluating | |||
| """ | |||
| ds.config.set_seed(1) | |||
| home_path = os.path.join(os.getcwd(), data_path) | |||
| files = os.listdir(data_path) | |||
| dis = int(batch_size / device_num) | |||
if dis <= 0:
raise ValueError(
"batch size {} should be no less than device number {}.".format(batch_size,
device_num))
| data = [ | |||
| os.path.join(home_path, name) for name in files | |||
| if not name.endswith(".db") | |||
| ] | |||
| dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False) | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| type_cast_op_float = C.TypeCast(mstype.float16) | |||
| if eod_reset: | |||
| map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis)) | |||
| dataset = dataset.batch(batch_size, drop_remainder=drop) | |||
| dataset = dataset.map(operations=map_func, input_columns=[column_name], | |||
| output_columns=["input_ids", "position_id", "attention_mask"], | |||
| column_order=["input_ids", "position_id", "attention_mask"]) | |||
| dataset = dataset.map(input_columns="position_id", operations=type_cast_op) | |||
| dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float) | |||
| else: | |||
raise ValueError("eod_reset=False is not supported here")
| dataset = dataset.map(input_columns="input_ids", operations=type_cast_op) | |||
| dataset = dataset.repeat(epoch) | |||
| return dataset | |||
| @@ -0,0 +1,58 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| TopK for text generation | |||
| """ | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.tensor import Tensor | |||
| def generate(model, origin_inputs, seq_length, end_token=50256): | |||
| """ | |||
| TopK for text generation | |||
| Inputs: | |||
| model: the model for inferencing | |||
| origin_inputs: the original inputs based on which the model will continue writing | |||
| seq_length: seq_length for the model | |||
| end_token: end of sentence token id | |||
| Returns: | |||
| outputs: the ids for the generated text | |||
| """ | |||
_, valid_length = origin_inputs.shape
pad_length = seq_length - origin_inputs.shape[-1]
input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
print("input_ids is ", input_ids)
# Ensure outputs is defined even if the loop below exits immediately
outputs = input_ids
| while valid_length < seq_length: | |||
| inputs = Tensor(input_ids, mstype.int32) | |||
| probs, p_args = model.predict(inputs) | |||
| probs = probs.asnumpy()[valid_length-1, :] | |||
| p_args = p_args.asnumpy()[valid_length-1, :] | |||
| p = probs | |||
| p = p / sum(p) | |||
| target_index = np.random.choice(len(p), p=p) | |||
| if p_args[target_index] == end_token or valid_length == seq_length-1: | |||
| outputs = input_ids | |||
| break | |||
| input_ids[0][valid_length] = p_args[target_index] | |||
| valid_length += 1 | |||
| length = np.sum(outputs != 0) | |||
| outputs = outputs[0][:length] | |||
| return outputs | |||
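The sampling step inside the loop above draws the next token from the re-normalized probability vector returned by the model. The snippet below isolates just that step (a minimal sketch over a hypothetical 4-token vocabulary; `sample_next` is an illustrative helper name):

```python
import numpy as np

def sample_next(probs):
    """Re-normalize and draw one token index, as in the body of the generation loop."""
    p = np.asarray(probs, dtype=np.float64)
    p = p / p.sum()
    return np.random.choice(len(p), p=p)

# With all probability mass on one token the draw is deterministic
next_id = sample_next([0.0, 0.0, 1.0, 0.0])
any_id = sample_next([0.25, 0.25, 0.25, 0.25])
```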
| @@ -0,0 +1,965 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """PanguAlpha model""" | |||
| import math | |||
| import os | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.common.parameter import Parameter | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.initializer import initializer, Normal, TruncatedNormal | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore import context | |||
| from mindspore.common.seed import _get_graph_seed | |||
| from mindspore._checkparam import Validator | |||
| class Dropout(nn.Cell): | |||
| r""" | |||
| A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training. | |||
| """ | |||
| def __init__(self, keep_prob=0.5, dtype=mstype.float32): | |||
| super(Dropout, self).__init__() | |||
if keep_prob <= 0 or keep_prob > 1:
raise ValueError(
"keep_prob should be a number in range (0, 1], but got {}".format(
keep_prob))
| Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||
| Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name) | |||
| self.keep_prob = keep_prob | |||
| self.is_ascend = context.get_context('device_target') in ["Ascend"] | |||
| if self.is_ascend: | |||
| seed0, seed1 = _get_graph_seed(0, "dropout") | |||
| self.seed0 = seed0 | |||
| self.seed1 = seed1 | |||
| self.dtype = dtype | |||
| self.get_shape = P.Shape() | |||
| self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1) | |||
| self.dropout_do_mask = P.DropoutDoMask() | |||
| self.cast = P.Cast() | |||
| else: | |||
| self.dropout = P.Dropout(keep_prob) | |||
| def construct(self, x): | |||
| r""" | |||
| Input: a tensor | |||
| Returns: a tensor | |||
| """ | |||
| if not self.training: | |||
| return x | |||
| if not self.is_ascend: | |||
| out, _ = self.dropout(x) | |||
| return out | |||
| if self.keep_prob == 1: | |||
| return x | |||
| shape = self.get_shape(x) | |||
| dtype = P.DType()(x) | |||
| keep_prob = self.cast(self.keep_prob, dtype) | |||
| output = self.dropout_gen_mask(shape, keep_prob) | |||
| return self.dropout_do_mask(x, output, keep_prob) | |||
| def extend_repr(self): | |||
| return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype) | |||
| class LayerNorm(nn.Cell): | |||
| r""" | |||
| A self-defined layer norm operation using reduce sum and reduce mean | |||
| """ | |||
| def __init__(self, normalized_shape, dp=4, eps=1e-5, scale=1e-3): | |||
| super(LayerNorm, self).__init__() | |||
| self.gamma = Parameter(initializer('ones', normalized_shape), name="gamma") | |||
| self.beta = Parameter(initializer('zeros', normalized_shape), name="beta") | |||
| self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),)) | |||
| self.square = P.Square().shard(((dp, 1, 1),)) | |||
| self.sqrt = P.Sqrt().shard(((dp, 1, 1),)) | |||
| self.sub1 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.sub2 = P.Sub().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.add = P.TensorAdd().shard(((dp, 1, 1), ())) | |||
| self.eps = eps | |||
| self.mul = P.Mul().shard(((dp, 1, 1), (1,))) | |||
| self.add2 = P.TensorAdd().shard(((dp, 1, 1), (1,))) | |||
| self.real_div = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1))) | |||
| self.scale_div = P.RealDiv().shard(((dp, 1, 1), ())) | |||
| self.scale_mul = P.Mul().shard(((dp, 1, 1), ())) | |||
| self.scale = scale | |||
| def construct(self, x): | |||
| mean = self.mean(x, -1) | |||
| diff = self.sub1(x, mean) | |||
| variance = self.mean(self.square(diff), -1) | |||
| variance_eps = self.sqrt(self.add(variance, self.eps)) | |||
| output = self.real_div(diff, variance_eps) | |||
| output = self.add2(self.mul(output, self.gamma), self.beta) | |||
| return output | |||
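The custom `LayerNorm` above expresses the standard layer-norm formula with sharded reduce ops; the same math in plain NumPy (a reference sketch for checking the construct, not part of the model) is:

```python
import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    """Normalize over the last axis, then apply the learned scale (gamma) and shift (beta)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

x = np.array([[1.0, 2.0, 3.0]])
out = layer_norm_ref(x, gamma=np.ones(3), beta=np.zeros(3))
```

With identity `gamma`/`beta`, the output is zero-mean and symmetric around the middle element.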
| class Mapping(nn.Cell): | |||
| """ | |||
| A mapping function with a 3d input | |||
| Args: | |||
| input_size: the size of the last dimension of the input tensor | |||
| output_size: the desired size of the last dimension of the output tensor | |||
| dtype: the compute datatype | |||
| scale: the scale factor for initialization | |||
| Inputs: | |||
| x: the 3d input | |||
| Returns: | |||
| output: Tensor, a 3d tensor after projection | |||
| """ | |||
# TODO: optimize matmul, dtype, mapping_output
| def __init__(self, config, input_size, output_size, scale=1.0): | |||
| super(Mapping, self).__init__() | |||
| self.output_size = output_size | |||
| self.input_size = input_size | |||
| self.weight = Parameter(initializer(Normal(sigma=0.02 * scale), | |||
| [input_size, output_size]), | |||
| name="mapping_weight") | |||
| self.bias = Parameter(initializer("zeros", [ | |||
| output_size, | |||
| ]), | |||
| name="mapping_bias", | |||
| parallel_optimizer=False) | |||
| self.dtype = config.compute_dtype | |||
| self.cast = P.Cast() | |||
| self.add = P.TensorAdd().shard(((config.dp, 1), (1,))) | |||
| self.matmul = P.MatMul().shard( | |||
| ((config.dp, config.mp), (config.mp, 1))) | |||
| def construct(self, x): | |||
| out_shape = P.Shape()(x)[:-1] + (self.output_size,) | |||
| x = P.Reshape()(x, (-1, self.input_size)) | |||
| weight = self.cast(self.weight, self.dtype) | |||
| x = self.matmul(x, weight) | |||
| x = self.add(x, self.cast(self.bias, self.dtype)) | |||
| output = P.Reshape()(x, out_shape) | |||
| return output | |||
| class Mapping_output(nn.Cell): | |||
| """ | |||
| A mapping function with a 3d input | |||
| Args: | |||
| input_size: the size of the last dimension of the input tensor | |||
| output_size: the desired size of the last dimension of the output tensor | |||
| dtype: the compute datatype | |||
| scale: the scale factor for initialization | |||
| Inputs: | |||
| x: the 3d input | |||
| Returns: | |||
| output: Tensor, a 3d tensor after projection | |||
| """ | |||
| def __init__(self, config, input_size, output_size, scale=1.0): | |||
| super(Mapping_output, self).__init__() | |||
| self.output_size = output_size | |||
| self.input_size = input_size | |||
| self.weight = Parameter(initializer(Normal(sigma=0.02 * scale), | |||
| [input_size, output_size]), | |||
| name="mapping_weight") | |||
| self.bias = Parameter(initializer("zeros", [ | |||
| output_size, | |||
| ]), | |||
| name="mapping_bias") | |||
| self.dtype = config.compute_dtype | |||
| self.cast = P.Cast() | |||
| self.add = P.TensorAdd().shard(((config.dp, config.mp), (config.mp,))) | |||
| self.matmul = P.MatMul().shard(((config.dp, 1), (1, config.mp))) | |||
| def construct(self, x): | |||
| out_shape = P.Shape()(x)[:-1] + (self.output_size,) | |||
| x = P.Reshape()(x, (-1, self.input_size)) | |||
| weight = self.cast(self.weight, self.dtype) | |||
| x = self.matmul(x, weight) | |||
| x = self.add(x, self.cast(self.bias, self.dtype)) | |||
| output = P.Reshape()(x, out_shape) | |||
| return output | |||
| class Output(nn.Cell): | |||
| """ | |||
| The output mapping module for each layer | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| scale: scale factor for initialization | |||
| Inputs: | |||
| x: output of the self-attention module | |||
| Returns: | |||
| output: Tensor, the output of this layer after mapping | |||
| """ | |||
| def __init__(self, config, scale=1.0): | |||
| super(Output, self).__init__() | |||
| input_size = config.embedding_size | |||
| output_size = config.embedding_size * config.expand_ratio | |||
| self.mapping = Mapping_output(config, input_size, output_size) | |||
| self.projection = Mapping(config, output_size, input_size, scale) | |||
| self.activation = nn.GELU() | |||
| self.activation.gelu.shard(((config.dp, 1, config.mp),)) | |||
| self.dropout = Dropout(1 - config.dropout_rate) | |||
| self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) | |||
| self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) | |||
| def construct(self, x): | |||
| hidden = self.activation(self.mapping(x)) | |||
| output = self.projection(hidden) | |||
| output = self.dropout(output) | |||
| return output | |||
| class AttentionMask(nn.Cell): | |||
| r""" | |||
| Get the attention matrix for self-attention module | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| Returns: | |||
| attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) | |||
| """ | |||
| def __init__(self, config): | |||
| super(AttentionMask, self).__init__() | |||
| self.reshape = P.Reshape() | |||
self.mul = P.BatchMatMul().shard(
((config.dp, 1, 1), (config.dp, 1, 1)))
| self.expand_dim = P.ExpandDims().shard(((1, 1),)) | |||
| ones = np.ones(shape=(config.seq_length, config.seq_length)) | |||
| self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32) | |||
| self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1))) | |||
| def construct(self, input_mask): | |||
| r""" | |||
| Generate the attention mask matrix. | |||
| """ | |||
| input_shape = P.Shape()(input_mask) | |||
| shape_right = (input_shape[0], 1, input_shape[1]) | |||
| shape_left = input_shape + (1,) | |||
| mask_left = self.reshape(input_mask, shape_left) | |||
| mask_right = self.reshape(input_mask, shape_right) | |||
| attention_mask = self.mul(mask_left, mask_right) | |||
lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
attention_mask = self.multiply(
attention_mask, lower_triangle)  # bs, seq_length, seq_length
| return attention_mask | |||
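`AttentionMask` combines a padding (validity) mask with a lower-triangular causal mask; the equivalent NumPy computation (a minimal sketch of what `construct` does, without the sharding) is:

```python
import numpy as np

def attention_mask_ref(input_mask):
    """(bs, seq) validity mask -> (bs, seq, seq) causal attention mask."""
    bs, seq = input_mask.shape
    pairwise = input_mask[:, :, None] * input_mask[:, None, :]  # valid-to-valid position pairs
    return pairwise * np.tril(np.ones((seq, seq)))              # keep only current-and-past positions

m = attention_mask_ref(np.array([[1, 1, 1, 0]]))  # last position is padding
```

Rows for padded positions are zeroed, and every position can attend only to itself and earlier valid tokens.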
| class EmbeddingLookup(nn.Cell): | |||
| """ | |||
| The embedding lookup table for vocabulary | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs with datatype int32 | |||
| Returns: | |||
| output: Tensor, the embedding vector for the input with shape (batch_size, | |||
| seq_length, embedding_size) | |||
| self.embedding_table: Tensor, the embedding table for the vocabulary | |||
| """ | |||
| def __init__(self, config): | |||
| super(EmbeddingLookup, self).__init__() | |||
| self.vocab_size = config.vocab_size | |||
| self.embedding_size = config.embedding_size | |||
| if config.load_ckpt_path: | |||
| # Loading the embedding table from the ckpt path: | |||
| embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy') | |||
| if os.path.exists(embedding_path): | |||
| e_table = np.load(embedding_path) | |||
| e_table = Tensor(e_table, mstype.float32) | |||
| self.embedding_table = Parameter(e_table, name="embedding_table") | |||
| else: | |||
| raise ValueError(f"{embedding_path} does not exist, please check whether the word_embedding file exists.") | |||
| else: | |||
| self.embedding_table = Parameter(initializer( | |||
| Normal(0.02), [self.vocab_size, self.embedding_size]), | |||
| name="embedding_table") | |||
| if config.word_emb_dp: | |||
| self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1))) | |||
| else: | |||
| self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1))) | |||
| self.shape = (-1, config.seq_length, config.embedding_size) | |||
| def construct(self, input_ids): | |||
| output = self.gather(self.embedding_table, input_ids, 0) | |||
| return output, self.embedding_table | |||
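The lookup itself is a plain gather along the vocabulary axis; the table is returned alongside the output so the head can reuse it for weight-tied logits. A NumPy sketch with a hypothetical 5-token vocabulary:

```python
import numpy as np

# Hypothetical table: vocab_size=5, embedding_size=4
table = np.arange(20, dtype=np.float32).reshape(5, 4)
input_ids = np.array([[2, 0, 4]])   # (batch_size=1, seq_length=3)
output = table[input_ids]           # gather along axis 0
# output shape: (1, 3, 4); the table itself is also returned so the
# head can reuse it as the output projection (weight tying)
```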
| class Attention(nn.Cell): | |||
| """ | |||
| Self-Attention module for each layer | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| scale: scale factor for initialization | |||
| layer_idx: current layer index | |||
| """ | |||
| def __init__(self, config, scale=1.0, layer_idx=None): | |||
| super(Attention, self).__init__() | |||
| self.get_attention_mask = AttentionMask(config) | |||
| self.projection = Mapping(config, config.embedding_size, | |||
| config.embedding_size, scale) | |||
| self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),)) | |||
| self.merger_head_transpose = P.Transpose().shard( | |||
| ((config.dp, config.mp, 1, 1),)) | |||
| self.reshape = P.Reshape() | |||
| self.n_head = config.num_heads | |||
| self.size_per_head = config.embedding_size // self.n_head | |||
| self.concat_k = P.Concat(axis=3) | |||
| self.concat_v = P.Concat(axis=2) | |||
| self.multiply_data = Tensor([ | |||
| -10000.0, | |||
| ], dtype=mstype.float32) | |||
| self.batch_matmul = P.BatchMatMul().shard( | |||
| ((config.dp, config.mp, 1, 1), (config.dp, config.mp, 1, 1))) | |||
| self.scale = scale | |||
| self.real_div = P.RealDiv().shard(((config.dp, config.mp, 1, 1), ())) | |||
| self.sub = P.Sub().shard( | |||
| ((1,), (config.dp, 1, 1, 1))) | |||
| self.mul = P.Mul().shard( | |||
| ((config.dp, 1, 1, 1), (1,))) | |||
| self.add = P.TensorAdd().shard( | |||
| ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1))) | |||
| if self.scale: | |||
| self.scale_factor = Tensor(math.sqrt(self.size_per_head)) | |||
| if layer_idx is not None: | |||
| self.coeff = math.sqrt(layer_idx * math.sqrt(self.size_per_head)) | |||
| self.coeff = Tensor(self.coeff) | |||
| self.use_past = config.use_past | |||
| self.dropout = Dropout(1 - config.dropout_rate) | |||
| self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) | |||
| self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) | |||
| self.prob_dropout = Dropout(1 - config.dropout_rate) | |||
| self.prob_dropout.dropout_gen_mask.shard( | |||
| ((config.dp, config.mp, 1, 1),)) | |||
| self.prob_dropout.dropout_do_mask.shard( | |||
| ((config.dp, config.mp, 1, 1),)) | |||
| self.softmax = nn.Softmax() | |||
| self.softmax.softmax.shard(((config.dp, config.mp, 1),)) | |||
| self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) | |||
| self.dense1 = nn.Dense(config.embedding_size, | |||
| config.embedding_size).to_float( | |||
| config.compute_dtype) | |||
| self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1))) | |||
| self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,))) | |||
| self.dense2 = nn.Dense(config.embedding_size, | |||
| config.embedding_size).to_float( | |||
| config.compute_dtype) | |||
| self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1))) | |||
| self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,))) | |||
| self.dense3 = nn.Dense(config.embedding_size, | |||
| config.embedding_size).to_float( | |||
| config.compute_dtype) | |||
| self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1))) | |||
| self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,))) | |||
| def construct(self, x, attention_mask, layer_past=None): | |||
| """ | |||
| self-attention | |||
| Inputs: | |||
| x: output of previous layer | |||
| attention_mask: the attention mask matrix with shape (batch_size, 1, | |||
| seq_length, seq_length) | |||
| layer_past: the previous feature map | |||
| Returns: | |||
| output: Tensor, the output logit of this layer | |||
| layer_present: Tensor, the feature map of current layer | |||
| """ | |||
| original_shape = F.shape(x) | |||
| x = F.reshape(x, (-1, original_shape[-1])) | |||
| query = self.dense1(x) | |||
| key = self.dense2(x) | |||
| value = self.dense3(x) | |||
| query = self.transpose( | |||
| F.reshape( | |||
| query, | |||
| (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 1, 3)) | |||
| key = self.transpose( | |||
| F.reshape( | |||
| key, (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 3, 1)) | |||
| value = self.transpose( | |||
| F.reshape( | |||
| value, | |||
| (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 1, 3)) | |||
| if self.use_past: | |||
| past_value = layer_past[1] | |||
| past_key = self.transpose(layer_past[0], (0, 1, 3, 2)) | |||
| key = self.concat_k((past_key, key)) | |||
| value = self.concat_v((past_value, value)) | |||
| layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) | |||
| attention = self._attn(query, key, value, attention_mask) | |||
| attention_merge = self.merge_heads(attention) | |||
| output = self.projection(attention_merge) | |||
| output = self.dropout(output) | |||
| return output, layer_present | |||
| def split_heads(self, x, transpose): | |||
| """ | |||
| split 3d tensor to 4d and switch certain axes | |||
| Inputs: | |||
| x: input tensor | |||
| transpose: tuple, the transpose sequence | |||
| Returns: | |||
| x_transpose: the 4d output | |||
| """ | |||
| x_size = P.Shape()(x) | |||
| new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head) | |||
| x = self.reshape(x, new_x_shape) | |||
| x_transpose = self.transpose(x, transpose) | |||
| return x_transpose | |||
| def merge_heads(self, x): | |||
| """ | |||
| convert a 4d input to a 3d output | |||
| Inputs: | |||
| x: input tensor | |||
| Returns: | |||
| x_merge: the 3d output | |||
| """ | |||
| x = self.merger_head_transpose( | |||
| x, (0, 2, 1, 3)) #bs, seq_length, head, size_per_head | |||
| x_shape = P.Shape()(x) | |||
| new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],) | |||
| x_merge = self.reshape(x, new_shape) | |||
| return x_merge | |||
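`split_heads` and `merge_heads` are inverse reshape/transpose pairs; a NumPy sketch with assumed toy dimensions shows the round trip:

```python
import numpy as np

bs, seq, n_head, size_per_head = 2, 8, 4, 16
x = np.random.randn(bs, seq, n_head * size_per_head)

# split_heads: (bs, seq, hidden) -> (bs, n_head, seq, size_per_head)
split = x.reshape(bs, seq, n_head, size_per_head).transpose(0, 2, 1, 3)

# merge_heads: the inverse transform back to (bs, seq, hidden)
merged = split.transpose(0, 2, 1, 3).reshape(bs, seq, n_head * size_per_head)
```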
| def _attn(self, query, key, value, attention_mask): | |||
| """ | |||
| Get the weighted score along the seq_length | |||
| Inputs: | |||
| query: the query matrix | |||
| key: the key matrix | |||
| value: the value matrix | |||
| attention_mask: the attention mask matrix with shape (batch_size, | |||
| 1, seq_length, seq_length) | |||
| Returns: | |||
| weighted_values: Tensor, the weighted sum scores | |||
| """ | |||
| if not self.scale: | |||
| query = query / F.cast(self.coeff, F.dtype(query)) | |||
| key = key / F.cast(self.coeff, F.dtype(key)) | |||
| score = self.batch_matmul(query, key) | |||
| if self.scale: | |||
| score = self.real_div( | |||
| score, | |||
| P.Cast()(self.scale_factor, P.DType()(score))) | |||
| ori_dtype = P.DType()(score) | |||
| score = P.Cast()(score, mstype.float32) | |||
| multiply_out = self.sub( | |||
| P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)), | |||
| P.Cast()(attention_mask, P.DType()(score))) | |||
| adder = self.mul(multiply_out, self.multiply_data) | |||
| attention_scores = self.add(adder, score) | |||
| shape = F.shape(attention_scores) | |||
| attention_probs = self.softmax( | |||
| F.reshape(attention_scores, | |||
| (shape[0], -1, shape[-1]))) | |||
| attention_probs = P.Cast()(attention_probs, ori_dtype) | |||
| attention_probs = F.reshape(attention_probs, shape) | |||
| attention_probs = self.prob_dropout(attention_probs) | |||
| weighted_values = self.batch_matmul(attention_probs, value) | |||
| return weighted_values | |||
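`_attn` computes a scaled dot-product, adds a `-10000` bias on masked positions so they vanish under softmax, and takes the weighted sum over values. A NumPy sketch of the same computation (sharding, dropout, and dtype casts omitted):

```python
import numpy as np

def attn(query, key, value, attention_mask):
    """Scaled dot-product attention with additive -10000 masking (NumPy sketch)."""
    d = query.shape[-1]
    # key arrives pre-transposed to (bs, n_head, size_per_head, seq), as in the cell
    score = np.matmul(query, key) / np.sqrt(d)
    # (1 - mask) * -10000 drives masked positions toward zero probability
    score = score + (1.0 - attention_mask) * -10000.0
    score = score - score.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(score)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.matmul(probs, value)
```

With a causal mask, the first query position can only attend to the first key, so its output equals the first value row.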
| class Block(nn.Cell): | |||
| """ | |||
| The basic block of PanguAlpha network | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| layer_idx: current layer index | |||
| Inputs: | |||
| x: the output of the previous layer (input_ids for the first layer) | |||
| attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) | |||
| layer_past: the previous feature map | |||
| Returns: | |||
| output: Tensor, the output logit of this layer | |||
| layer_present: Tensor, the feature map of current layer | |||
| """ | |||
| def __init__(self, config, layer_idx): | |||
| super(Block, self).__init__() | |||
| scale = 1 / math.sqrt(2.0 * config.num_layers) | |||
| if config.self_layernorm: | |||
| self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| else: | |||
| self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32) | |||
| self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32) | |||
| self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm1.gamma.parallel_optimizer = False | |||
| self.layernorm1.beta.parallel_optimizer = False | |||
| self.attention = Attention(config, scale, layer_idx) | |||
| self.layernorm2.gamma.parallel_optimizer = False | |||
| self.layernorm2.beta.parallel_optimizer = False | |||
| self.output = Output(config, scale) | |||
| self.post_layernorm_residual = config.post_layernorm_residual | |||
| self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) | |||
| self.last_add = P.TensorAdd().shard( | |||
| ((config.dp, 1, 1), (config.dp, 1, 1))) | |||
| self.last_add.recompute(False) | |||
| self.dtype = config.compute_dtype | |||
| def construct(self, x, input_mask, layer_past=None): | |||
| r""" | |||
| The forward process of the block. | |||
| """ | |||
| input_x = self.layernorm1(x) | |||
| input_x = F.cast(input_x, self.dtype) | |||
| attention, layer_present = self.attention(input_x, input_mask, | |||
| layer_past) | |||
| if self.post_layernorm_residual: | |||
| x = self.add(input_x, attention) | |||
| else: | |||
| x = self.add(x, attention) | |||
| output_x = self.layernorm2(x) | |||
| output_x = F.cast(output_x, self.dtype) | |||
| mlp_logit = self.output(output_x) | |||
| if self.post_layernorm_residual: | |||
| output = self.last_add(output_x, mlp_logit) | |||
| else: | |||
| output = self.last_add(x, mlp_logit) | |||
| return output, layer_present | |||
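The `post_layernorm_residual` flag switches between taking the residual from the raw input (pre-LN, the default) and from the normalized input (post-LN). A NumPy sketch of the control flow, with hypothetical `attention_fn`/`mlp_fn` stand-ins for the sublayers:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def transformer_block(x, attention_fn, mlp_fn, post_layernorm_residual=False):
    input_x = layer_norm(x)
    attention = attention_fn(input_x)
    # Pre-LN (False): residual from the raw input; post-LN (True): from the normalized one
    x = (input_x + attention) if post_layernorm_residual else (x + attention)
    output_x = layer_norm(x)
    mlp = mlp_fn(output_x)
    return (output_x + mlp) if post_layernorm_residual else (x + mlp)
```

With zero sublayers, the pre-LN path reduces to the identity, which makes the residual wiring easy to check.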
| class QueryLayerAttention(Attention): | |||
| r""" | |||
| Self-Attention module using input query vector. | |||
| """ | |||
| def construct(self, x, query_hidden_state, attention_mask, layer_past=None): | |||
| original_shape = F.shape(x) | |||
| x = F.reshape(x, (-1, original_shape[-1])) | |||
| query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1])) | |||
| query = self.dense1(query_hidden_state) | |||
| key = self.dense2(x) | |||
| value = self.dense3(x) | |||
| query = self.transpose( | |||
| F.reshape( | |||
| query, | |||
| (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 1, 3)) | |||
| key = self.transpose( | |||
| F.reshape( | |||
| key, (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 3, 1)) | |||
| value = self.transpose( | |||
| F.reshape( | |||
| value, | |||
| (-1, original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 1, 3)) | |||
| if self.use_past: | |||
| past_value = layer_past[1] | |||
| past_key = self.transpose(layer_past[0], (0, 1, 3, 2)) | |||
| key = self.concat_k((past_key, key)) | |||
| value = self.concat_v((past_value, value)) | |||
| layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) | |||
| attention = self._attn(query, key, value, attention_mask) | |||
| attention_merge = self.merge_heads(attention) | |||
| output = self.projection(attention_merge) | |||
| output = self.dropout(output) | |||
| return output, layer_present | |||
| class QueryLayer(nn.Cell): | |||
| r""" | |||
| A block using the looked-up position embedding as the query vector. | |||
| This is used as the final block. | |||
| """ | |||
| def __init__(self, config): | |||
| super(QueryLayer, self).__init__() | |||
| scale = 1 / math.sqrt(2.0 * config.num_layers) | |||
| self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) | |||
| self.layernorm1.gamma.parallel_optimizer = False | |||
| self.layernorm1.beta.parallel_optimizer = False | |||
| self.attention = QueryLayerAttention(config, scale) | |||
| self.layernorm2.gamma.parallel_optimizer = False | |||
| self.layernorm2.beta.parallel_optimizer = False | |||
| self.output = Output(config, scale) | |||
| self.post_layernorm_residual = config.post_layernorm_residual | |||
| self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) | |||
| self.last_add = P.TensorAdd().shard( | |||
| ((config.dp, 1, 1), (config.dp, 1, | |||
| 1))).add_prim_attr("recompute", False) | |||
| self.dtype = config.compute_dtype | |||
| def construct(self, x, query_hidden_state, input_mask, layer_past=None): | |||
| r""" | |||
| Query Layer. | |||
| """ | |||
| input_x = self.layernorm1(x) | |||
| input_x = F.cast(input_x, self.dtype) | |||
| attention, layer_present = self.attention(input_x, | |||
| query_hidden_state, | |||
| input_mask, | |||
| layer_past) | |||
| if self.post_layernorm_residual: | |||
| x = self.add(input_x, attention) | |||
| else: | |||
| x = self.add(x, attention) | |||
| output_x = self.layernorm2(x) | |||
| output_x = F.cast(output_x, self.dtype) | |||
| mlp_logit = self.output(output_x) | |||
| if self.post_layernorm_residual: | |||
| output = self.last_add(output_x, mlp_logit) | |||
| else: | |||
| output = self.last_add(x, mlp_logit) | |||
| return output, layer_present | |||
| class PanguAlpha_Model(nn.Cell): | |||
| """ | |||
| The backbone of PanguAlpha network | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs with datatype int32 | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| layer_past: the previous feature map | |||
| Returns: | |||
| output_state: Tensor, the output logit of backbone | |||
| present_layer: Tensor, the current feature map | |||
| embedding_table: Tensor, the embedding table for the vocabulary | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanguAlpha_Model, self).__init__() | |||
| self.get_attention_mask = AttentionMask(config) | |||
| self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1) | |||
| if config.load_ckpt_path: | |||
| # Loading the embedding table from the ckpt path: | |||
| embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy') | |||
| if os.path.exists(embedding_path): | |||
| p_table = np.load(embedding_path) | |||
| position_table_param = Tensor(p_table, mstype.float32) | |||
| else: | |||
| raise ValueError(f"{embedding_path} does not exist, please check whether the position_embedding file exists.") | |||
| else: | |||
| position_table_param = TruncatedNormal(0.02) | |||
| self.position_embedding = nn.Embedding( | |||
| config.seq_length, | |||
| config.embedding_size, | |||
| embedding_table=position_table_param).set_comm_fusion(1) | |||
| self.word_embedding.embedding_table.parallel_optimizer = False | |||
| self.position_embedding.embedding_table.parallel_optimizer = False | |||
| self.position_embedding.gather.shard(((1, 1), (config.dp,))) | |||
| self.position_embedding.expand.shard(((config.dp, 1),)) | |||
| self.blocks = nn.CellList() | |||
| fusion_group_num = 4 | |||
| fusion_group_size = config.num_layers // fusion_group_num | |||
| fusion_group_size = max(fusion_group_size, 1) | |||
| num_layers = config.num_layers | |||
| if config.use_top_query_attention: | |||
| num_layers -= 1 | |||
| self.num_layers = num_layers | |||
| print("Number of transformer layers after configuration:", num_layers, flush=True) | |||
| for i in range(num_layers): | |||
| per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2) | |||
| per_block.recompute() | |||
| per_block.attention.dropout.dropout_gen_mask.recompute(False) | |||
| per_block.attention.prob_dropout.dropout_gen_mask.recompute(False) | |||
| per_block.output.dropout.dropout_gen_mask.recompute(False) | |||
| self.blocks.append(per_block) | |||
| if config.self_layernorm: | |||
| self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float( | |||
| mstype.float32).set_comm_fusion( | |||
| int((num_layers - 1) / fusion_group_size) + 2) | |||
| else: | |||
| self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float( | |||
| mstype.float32).set_comm_fusion( | |||
| int((num_layers - 1) / fusion_group_size) + 2) | |||
| self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) | |||
| self.layernorm.gamma.parallel_optimizer = False | |||
| self.layernorm.beta.parallel_optimizer = False | |||
| self.use_past = config.use_past | |||
| self.past = tuple([None] * config.num_layers) | |||
| self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) | |||
| self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) | |||
| self.dtype = config.compute_dtype | |||
| self.dropout = Dropout(1 - config.dropout_rate) | |||
| self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) | |||
| self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) | |||
| self.eod_reset = config.eod_reset | |||
| if config.use_top_query_attention: | |||
| if config.load_ckpt_path: | |||
| # Loading the embedding table from the ckpt path: | |||
| embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy') | |||
| if os.path.exists(embedding_path): | |||
| top_query_table = np.load(embedding_path) | |||
| top_query_table_param = Tensor(top_query_table, mstype.float32) | |||
| else: | |||
| raise ValueError( | |||
| f"{embedding_path} does not exist, please check whether the top_query_embedding file exists.") | |||
| else: | |||
| top_query_table_param = TruncatedNormal(0.02) | |||
| self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size, | |||
| embedding_table=top_query_table_param) | |||
| self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2) | |||
| self.top_query_embedding.embedding_table.parallel_optimizer = False | |||
| self.top_query_embedding.gather.shard(((1, 1), (config.dp,))) | |||
| self.top_query_embedding.expand.shard(((config.dp, 1),)) | |||
| self.top_query_layer = QueryLayer(config) | |||
| if config.use_recompute: | |||
| self.top_query_layer.recompute() | |||
| self.top_query_layer.output.dropout.dropout_gen_mask.recompute(False) | |||
| self.top_query_layer.attention.dropout.dropout_gen_mask.recompute(False) | |||
| self.top_query_layer.attention.prob_dropout.dropout_gen_mask.recompute(False) | |||
| self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_num) + 2) | |||
| self.use_top_query_attention = config.use_top_query_attention | |||
| def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None): | |||
| """PanguAlpha model""" | |||
| if not self.use_past: | |||
| layer_past = self.past | |||
| input_embedding, embedding_table = self.word_embedding(input_ids) | |||
| if not self.eod_reset: | |||
| batch_size, seq_length = F.shape(input_ids) | |||
| input_position = F.tuple_to_array(F.make_range(seq_length)) | |||
| input_position = P.Tile()(input_position, (batch_size, 1)) | |||
| attention_mask = self.get_attention_mask(input_mask) | |||
| position_embedding = self.position_embedding(input_position) | |||
| hidden_states = self.add(input_embedding, position_embedding) | |||
| hidden_states = self.dropout(hidden_states) | |||
| hidden_states = P.Cast()(hidden_states, mstype.float16) | |||
| attention_mask = self.expand_dims(attention_mask, 1) | |||
| present_layer = () | |||
| for i in range(self.num_layers): | |||
| hidden_states, present = self.blocks[i](hidden_states, | |||
| attention_mask, layer_past) | |||
| present_layer = present_layer + (present,) | |||
| output_state = self.layernorm(hidden_states) | |||
| output_state = F.cast(output_state, self.dtype) | |||
| if self.use_top_query_attention: | |||
| top_query_hidden_states = self.top_query_embedding(input_position) | |||
| output_state, present = self.top_query_layer(output_state, top_query_hidden_states, | |||
| attention_mask, layer_past) | |||
| present_layer = present_layer + (present,) | |||
| return output_state, present_layer, embedding_table | |||
| class PanguAlpha_Head(nn.Cell): | |||
| """ | |||
| Head for PanguAlpha to get the logits of each token in the vocab | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| state: the output of the backbone | |||
| embedding_table: the embedding table of the vocabulary | |||
| Returns: | |||
| logits: Tensor, the logits of the corresponding inputs | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanguAlpha_Head, self).__init__() | |||
| if config.word_emb_dp: | |||
| self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (1, 1))) | |||
| else: | |||
| self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (config.mp, 1))) | |||
| self.embedding_size = config.embedding_size | |||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||
| self.dtype = config.compute_dtype | |||
| self.cast = P.Cast() | |||
| def construct(self, state, embedding_table): | |||
| state = P.Reshape()(state, (-1, self.embedding_size)) | |||
| logits = self.matmul(state, self.cast(embedding_table, self.dtype)) | |||
| return logits | |||
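Because the head reuses the word-embedding table as the output projection (weight tying), the logits are simply the flattened hidden states multiplied by the transposed table. A NumPy sketch with assumed toy sizes:

```python
import numpy as np

vocab_size, embedding_size = 5, 4
embedding_table = np.random.randn(vocab_size, embedding_size).astype(np.float32)
state = np.random.randn(2, 3, embedding_size).astype(np.float32)  # (bs, seq, hidden)

flat = state.reshape(-1, embedding_size)   # (bs * seq, hidden)
logits = flat @ embedding_table.T          # (bs * seq, vocab_size)
```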
| class PanguAlpha(nn.Cell): | |||
| """ | |||
| The PanguAlpha network consisting of two parts: the backbone and the head | |||
| Args: | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| past: the previous feature map | |||
| Returns: | |||
| logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size) | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanguAlpha, self).__init__() | |||
| self.backbone = PanguAlpha_Model(config) | |||
| self.head = PanguAlpha_Head(config) | |||
| def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None): | |||
| output_states, _, embedding_table = self.backbone( | |||
| input_ids, input_mask, input_position, attention_mask, past) | |||
| logits = self.head(output_states, embedding_table) | |||
| return logits | |||
| class CrossEntropyLoss(nn.Cell): | |||
| """ | |||
| Calculate the cross entropy loss | |||
| Args: | |||
| config(PanguAlphaConfig): the config of the network | |||
| Inputs: | |||
| logits: the output logits of the backbone | |||
| label: the ground truth label of the sample | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| Returns: | |||
| loss: Tensor, the corresponding cross entropy loss | |||
| """ | |||
| def __init__(self, config): | |||
| super(CrossEntropyLoss, self).__init__() | |||
| self.mean = P.ReduceMean() | |||
| self.sum = P.ReduceSum().shard(((config.dp, config.mp),)) | |||
| self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ())) | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.vocab_size = config.vocab_size | |||
| self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard( | |||
| ((config.dp, config.mp),)) | |||
| self.eps_const = Tensor(1e-24, mstype.float32) | |||
| self.sub = P.Sub().shard(((config.dp, config.mp), (config.dp, 1))) | |||
| self.exp = P.Exp().shard(((config.dp, config.mp),)) | |||
| self.div = P.RealDiv().shard(((config.dp, config.mp), (config.dp, 1))) | |||
| self.log = P.Log().shard(((config.dp, config.mp),)) | |||
| self.add = P.TensorAdd().shard(((config.dp, config.mp), ())) | |||
| self.mul = P.Mul().shard( | |||
| ((config.dp, config.mp), (config.dp, config.mp))) | |||
| self.neg = P.Neg().shard(((config.dp, config.mp),)) | |||
| self.sum2 = P.ReduceSum().shard(((1,),)) | |||
| self.mul2 = P.Mul().shard(((1,), (1,))) | |||
| self.add2 = P.TensorAdd() | |||
| self.div2 = P.RealDiv() | |||
| def construct(self, logits, label, input_mask): | |||
| r""" | |||
| Compute loss using logits, label and input mask | |||
| """ | |||
| logits = F.cast(logits, mstype.float32) | |||
| _, logit_max = self.max(logits) | |||
| logit_sub = self.sub(logits, logit_max) | |||
| logit_exp = self.exp(logit_sub) | |||
| exp_sum = self.sum(logit_exp, -1) | |||
| exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1)) | |||
| softmax_result = self.div(logit_exp, exp_sum) | |||
| log_softmax_result = self.log(self.add(softmax_result, self.eps_const)) | |||
| label = P.Reshape()(label, (-1,)) | |||
| one_hot_label = self.onehot(label, self.vocab_size, self.on_value, | |||
| self.off_value) | |||
| loss = self.mul(log_softmax_result, one_hot_label) | |||
| loss_unsum = self.neg(loss) | |||
| loss_reduce = self.sum(loss_unsum, -1) | |||
| input_mask = P.Reshape()(input_mask, (-1,)) | |||
| numerator = self.sum2(self.mul2(loss_reduce, input_mask)) | |||
| denominator = self.add2( | |||
| self.sum2(input_mask), | |||
| P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32)) | |||
| loss = self.div2(numerator, denominator) | |||
| return loss | |||
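The loss is a per-token negative log-likelihood with a numerically stable softmax (max subtraction plus a small epsilon before the log), averaged only over positions where the input mask is 1. A NumPy sketch of the same reduction:

```python
import numpy as np

def masked_cross_entropy(logits, labels, input_mask, eps=1e-24):
    # logits: (N, vocab), labels: (N,) int, input_mask: (N,) of 0.0/1.0
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # eps guards against log(0), mirroring eps_const in the cell
    nll = -np.log(softmax[np.arange(len(labels)), labels] + eps)
    # Average only over valid (unmasked) positions; 1e-5 avoids division by zero
    return (nll * input_mask).sum() / (input_mask.sum() + 1e-5)
```

For uniform logits over a vocabulary of 4, the per-token loss is `log(4)`.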
| class PanguAlphaWithLoss(nn.Cell): | |||
| """ | |||
| PanguAlpha training loss | |||
| Args: | |||
| network: backbone network of PanguAlpha | |||
| loss: loss function, e.g., crossentropy | |||
| eos_token: the end_of_sentence token | |||
| Inputs: | |||
| input_ids: the tokenized inputs | |||
| past: the previous feature map | |||
| Returns: | |||
| output: Tensor, the loss of the network | |||
| """ | |||
| def __init__(self, config, network, loss, eos_token=6): | |||
| super(PanguAlphaWithLoss, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.loss = loss | |||
| self.eos_token = eos_token | |||
| self.slice = P.StridedSlice().shard(((config.dp, 1),)) | |||
| self.not_equal = P.NotEqual().shard(((config.dp, 1), ())) | |||
| self.batch_size = config.batch_size | |||
| self.len = config.seq_length | |||
| self.eod_reset = config.eod_reset | |||
| if self.eod_reset: | |||
| self.slice_mask = P.StridedSlice().shard(((config.dp, 1, 1),)) | |||
| def construct(self, input_ids, input_position=None, attention_mask=None): | |||
| r""" | |||
| PanguAlphaWithLoss | |||
| """ | |||
| tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) | |||
| if self.eod_reset: | |||
| input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1)) | |||
| attention_mask = self.slice_mask(attention_mask, (0, 0, 0), | |||
| (self.batch_size, self.len, self.len), | |||
| (1, 1, 1)) | |||
| input_mask = F.cast(self.not_equal(tokens, self.eos_token), | |||
| mstype.float32) | |||
| logits = self.network(tokens, input_mask, input_position, attention_mask) | |||
| labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1), | |||
| (1, 1)) | |||
| output = self.loss(logits, labels, input_mask) | |||
| return output | |||
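`PanguAlphaWithLoss` expects `input_ids` of length `seq_length + 1` per sample: the first `seq_length` tokens feed the model, while the sequence shifted by one position supplies the next-token labels. A NumPy sketch of the slicing:

```python
import numpy as np

seq_length = 4
# input_ids carries seq_length + 1 tokens per sample
input_ids = np.array([[11, 12, 13, 14, 15]])
tokens = input_ids[:, :-1]   # model inputs
labels = input_ids[:, 1:]    # next-token targets, shifted by one
```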
| class EvalNet(nn.Cell): | |||
| """ | |||
| PanguAlpha evaluation net | |||
| Args: | |||
| backbone: backbone network of PanguAlpha | |||
| generate: enable generate mode | |||
| Inputs: | |||
| input_ids: the tokenized inputs | |||
| Returns: | |||
| outputs: Tensor, corresponding output for different tasks | |||
| """ | |||
| def __init__(self, backbone, generate=False): | |||
| super(EvalNet, self).__init__(auto_prefix=False) | |||
| self.backbone = backbone | |||
| self.argmax = P.Argmax() | |||
| self.generate = generate | |||
| self.topk = P.TopK(sorted=True).shard(((1, 1),)) | |||
| def construct(self, input_ids): | |||
| """evaluation net""" | |||
| input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32) | |||
| logits = self.backbone(input_ids, input_mask) | |||
| value, index = self.topk(logits, 5) | |||
| return value, index | |||
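`EvalNet` returns the top-5 logits and their token indices for generation. A NumPy sketch of the equivalent top-k selection (illustrative; `P.TopK` performs this on-device):

```python
import numpy as np

# Hypothetical logits for one position over a vocabulary of 6 tokens
logits = np.array([0.1, 2.5, -1.0, 3.2, 0.0, 1.7])
k = 5
# np.argsort is ascending; take the last k entries and reverse for descending order
index = np.argsort(logits)[-k:][::-1]
value = logits[index]
```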
| @@ -0,0 +1,135 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting | |||
| """ | |||
| import mindspore.common.dtype as mstype | |||
| class PANGUALPHAConfig: | |||
| """ | |||
| PANGUALPHA config class which defines the model size | |||
| """ | |||
| def __init__(self, | |||
| data_parallel_num, | |||
| model_parallel_num, | |||
| batch_size=32, | |||
| seq_length=1024, | |||
| vocab_size=50257, | |||
| embedding_size=768, | |||
| num_layers=12, | |||
| num_heads=12, | |||
| expand_ratio=4, | |||
| post_layernorm_residual=False, | |||
| dropout_rate=0.1, | |||
                 compute_dtype=mstype.float16,
                 use_past=False,
                 self_layernorm=True,
                 word_emb_dp=True,
                 stage_num=16,
                 eod_reset=True,
                 micro_size=32,
                 load_ckpt_path=None,
                 use_top_query_attention=True,
                 use_recompute=True):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio
        self.post_layernorm_residual = post_layernorm_residual
        self.dropout_rate = dropout_rate
        self.compute_dtype = compute_dtype
        self.use_past = use_past
        self.dp = data_parallel_num
        self.mp = model_parallel_num
        self.self_layernorm = self_layernorm
        self.stage_num = stage_num
        self.micro_size = micro_size
        self.word_emb_dp = word_emb_dp
        self.eod_reset = eod_reset
        # Used for loading embedding tables
        self.load_ckpt_path = load_ckpt_path
        self.use_top_query_attention = use_top_query_attention
        self.use_recompute = use_recompute

    def __str__(self):
        info = "[PANGUALPHAConfig]" + '===' * 10 + '\n'
        for k, v in self.__dict__.items():
            var_info = "{}:{}\n".format(k, v)
            info += var_info
        info += '=' * 10
        return info


def set_parse(args_opt):
    r"""
    Set the config according to the selected mode.
    """
    if args_opt.mode == "200B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 16384
        args_opt.num_layers = 64
        args_opt.num_heads = 128
        if args_opt.run_type == "train":
            args_opt.start_lr = 6e-5
            args_opt.end_lr = 6e-6
            args_opt.optimizer_shard = False
            args_opt.stage_num = 16
            args_opt.micro_size = 32
            args_opt.tensor_model_parallel_num = 16
            args_opt.per_batch_size = 1
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 4
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
    elif args_opt.mode == "13B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 5120
        args_opt.num_layers = 40
        args_opt.num_heads = 40
        args_opt.tensor_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 5e-5
            args_opt.end_lr = 1e-6
            args_opt.optimizer_shard = True
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 16
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
    elif args_opt.mode == "2.6B":
        args_opt.seq_length = 1024
        args_opt.vocab_size = 40000
        args_opt.embedding_size = 2560
        args_opt.num_layers = 32
        args_opt.num_heads = 32
        args_opt.tensor_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 1e-4
            args_opt.end_lr = 1e-6
            args_opt.optimizer_shard = True
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 2
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.per_batch_size = 1
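The mode presets above can be exercised without MindSpore. A minimal sketch, assuming a bare `SimpleNamespace` stands in for the argparse namespace and modeling only the "2.6B" train branch (the helper name `apply_mode_presets` is ours, not part of the repository):

```python
from types import SimpleNamespace

def apply_mode_presets(args):
    """Mirror of the '2.6B' train branch of set_parse above (sketch)."""
    if args.mode == "2.6B":
        args.seq_length = 1024
        args.vocab_size = 40000
        args.embedding_size = 2560
        args.num_layers = 32
        args.num_heads = 32
        args.tensor_model_parallel_num = 8
        if args.run_type == "train":
            args.start_lr = 1e-4
            args.end_lr = 1e-6
            args.optimizer_shard = True
            args.stage_num = 1
            args.micro_size = 1
            args.per_batch_size = 2
    return args

# Presets overwrite whatever the command line supplied for these fields.
args = apply_mode_presets(SimpleNamespace(mode="2.6B", run_type="train"))
```

Note the mutation style: `set_parse` deliberately overrides user-supplied hyperparameters so that a mode name alone pins down a reproducible configuration.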
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PanguAlpha training wrapper"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
from src.utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
            F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad,
                                   F.cast(F.tuple_to_array((clip_value,)),
                                          dt))
    return new_grad
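The two clipping modes of `_clip_grad` can be illustrated outside the MindSpore graph. A pure-Python sketch (our own mirror, not the graph ops themselves):

```python
import math

def clip_grad_py(clip_type, clip_value, grad):
    # Pure-Python mirror of _clip_grad above: mode 0 clips each element
    # to [-clip_value, clip_value]; mode 1 rescales the whole vector so
    # its L2 norm does not exceed clip_value.
    if clip_type not in (0, 1):
        return grad
    if clip_type == 0:
        return [max(-clip_value, min(clip_value, g)) for g in grad]
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= clip_value:
        return grad
    return [g * clip_value / norm for g in grad]

by_value = clip_grad_py(0, 0.5, [3.0, -4.0])  # each element clamped to ±0.5
by_norm = clip_grad_py(1, 1.0, [3.0, 4.0])    # whole vector rescaled (norm was 5)
```

Norm clipping preserves the gradient direction while value clipping does not, which is why the training cell defaults to `GRADIENT_CLIP_TYPE = 1`.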
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Unscale a gradient by multiplying with the reciprocal of the loss scale."""
    return grad * reciprocal(scale)


class VirtualDatasetOneInputCell(nn.Cell):
    """Insert a _VirtualDataset node in front of the backbone network."""

    def __init__(self, backbone):
        super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._virtual_dataset = _VirtualDataset()

    def construct(self, *data):
        data_ = self._virtual_dataset(*data)
        return self._backbone(*data_)


class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of PanguAlpha network training.

    Appends an optimizer to the training network so that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """

    def __init__(self,
                 network,
                 optimizer,
                 scale_update_cell=None,
                 enable_global_norm=False,
                 config=None):
        super(PanguAlphaTrainOneStepWithLossScaleCell,
              self).__init__(network, optimizer, scale_update_cell)
        self.network = network
        self.config = config
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.default_lr = Tensor([0.0], dtype=mstype.float32)
        self.enable_global_norm = enable_global_norm
        self.clip = ClipByGlobalNorm(self.weights)
        self.cast = P.Cast()

    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        scaling_sens = self.scale_sense
        # Allocate the overflow status and clear it right before the grad operation
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network,
                          weights)(input_ids,
                                   input_position, attention_mask,
                                   scaling_sens_filled)
        # Apply the gradient reducer on the gradients
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(
            F.partial(grad_scale, scaling_sens), grads)
        if self.enable_global_norm:
            grads, _ = self.clip(grads)
        else:
            grads = self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                grads)
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        return F.depend(loss, succ), cond, scaling_sens
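The cell above follows the standard dynamic loss-scaling pattern: scale the loss up before the backward pass, unscale the gradients with `grad_scale`, and skip the optimizer step when an overflow is detected. The scale-update policy itself lives in MindSpore's `DynamicLossScaleUpdateCell`; a simplified stand-alone mirror of that policy (our own sketch, using the `scale_factor=2, scale_window=1000` values configured in train.py) looks like this:

```python
class DynamicLossScaler:
    """Simplified mirror of dynamic loss scaling: halve the scale on
    overflow, double it after `scale_window` consecutive clean steps."""

    def __init__(self, init_scale=2.0 ** 32, scale_factor=2, scale_window=1000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Overflow: shrink the scale and restart the clean-step counter.
            self.scale = max(1.0, self.scale / self.scale_factor)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.scale_window == 0:
                self.scale *= self.scale_factor
        return self.scale

scaler = DynamicLossScaler(init_scale=1024.0)
after_overflow = scaler.update(True)     # scale halves
for _ in range(1000):                    # a full clean window doubles it back
    scaler.update(False)
```

Because the optimizer step is skipped on overflow, shrinking the scale is the only side effect of a bad step; training simply retries with a safer scale.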
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Transform the wikitext-2, wikitext-103, lambada and openwebtext datasets to MindRecord.
"""
import argparse
import glob
import json
import os
import re
from multiprocessing import Pool, current_process

import numpy as np

try:
    from transformers import GPT2Tokenizer
except ModuleNotFoundError:
    print("module 'transformers' not installed.")
from mindspore.mindrecord import FileWriter

EOT = 50256  # token id of endoftext
SEQ_LEN = 1025  # the length of a sample
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def chunks(lst, n):
    """Yield n-sized chunks from a list."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def package_file(it, n):
    """Package multiple files into batches of up to n."""
    stop = False
    while not stop:
        batch = []
        for _ in range(n):
            try:
                batch.append(next(it))
            except StopIteration:
                stop = True
        if not batch:
            break
        yield batch
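The two generators above drive the whole pipeline: `chunks` slices a token stream into fixed-length samples and `package_file` batches file names for the worker pool. A self-contained usage sketch (the functions restated verbatim so it runs on its own):

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst (the last one may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def package_file(it, n):
    """Group an iterator of items into batches of up to n."""
    stop = False
    while not stop:
        batch = []
        for _ in range(n):
            try:
                batch.append(next(it))
            except StopIteration:
                stop = True
        if not batch:
            break
        yield batch

parts = list(chunks(list(range(7)), 3))         # trailing short chunk is kept
batches = list(package_file(iter("abcde"), 2))  # trailing short batch is kept
```

Downstream, the tokenizers discard the short trailing chunk (`if len(chunk) == SEQ_LEN`), so only full 1025-token samples reach the MindRecord writer.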
def clean_wikitext(string):
    """Clean the wikitext dataset."""
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
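The cleaning mostly undoes WikiText's tokenized formatting, where `@` marks protect number separators and punctuation is space-detached. A minimal subset of the rules above, enough to show the effect (the helper name `clean_wikitext_mini` is ours):

```python
import re

def clean_wikitext_mini(s):
    """Minimal subset of clean_wikitext above: WikiText '@' number
    separators and space-padded parentheses."""
    s = s.replace(" @-@ ", "-").replace(" @,@ ", ",").replace(" @.@ ", ".")
    s = s.replace(" , ", ", ").replace(" . ", ". ")
    # collapse the padding inside parentheses: "( x )" -> "(x)"
    s = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", s)
    return s

cleaned = clean_wikitext_mini("It sold 1 @,@ 500 copies ( roughly ) in 1 @.@ 5 days .")
```

Restoring natural text before GPT-2 tokenization matters because the BPE vocabulary was trained on untokenized prose; leaving `1 @,@ 500` as-is would fragment it into unusual subwords.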
def tokenize_openwebtext(iterator):
    """Tokenize the openwebtext dataset."""
    for file_path in iterator:
        if os.path.getsize(file_path) == 0:
            continue
        content = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for para in f.read().split("\n\n"):
                if para:
                    tokenized_text = tokenizer.tokenize(para)
                    content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
                        EOT]
        for chunk in chunks(content, SEQ_LEN):
            sample = {}
            if len(chunk) == SEQ_LEN:
                sample['input_ids'] = np.array(chunk, dtype=np.int32)
                yield sample


def tokenize_wiki(file_path):
    """Tokenize the wikitext-2/wikitext-103 dataset."""
    content = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for para in clean_wikitext(f.read()).split("\n\n"):
            if para and not para.strip().startswith('='):
                tokenized_text = tokenizer.tokenize(para)
                content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
                    EOT]
    for chunk in chunks(content, SEQ_LEN):
        sample = {}
        if len(chunk) == SEQ_LEN:
            sample['input_ids'] = np.array(chunk, dtype=np.int32)
            yield sample


def tokenize_lambada(file_path):
    """Tokenize the lambada dataset."""
    content = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            para = json.loads(line)['text'].replace(
                "“", '"').replace("”", '"').strip().strip(".")
            tokenized_text = tokenizer.tokenize(para)
            content += tokenizer.convert_tokens_to_ids(tokenized_text) + [EOT]
    for chunk in chunks(content, SEQ_LEN):
        sample = {}
        if len(chunk) == SEQ_LEN:
            sample['input_ids'] = np.array(chunk, dtype=np.int32)
            yield sample


def task_unit(iterator, parallel_writer=True):
    """Task for each worker process: tokenize a batch of files and write the records."""
    # `writer` is the module-level FileWriter created in the __main__ block;
    # worker processes inherit it when the Pool forks.
    p = current_process()
    index = p.pid if p.pid else 0
    item_iter = tokenize_openwebtext(iterator)
    batch_size = 1024  # size of a write batch
    count = 0
    while True:
        data_batch = []
        try:
            for _ in range(batch_size):
                data_batch.append(next(item_iter))
                count += 1
            writer.write_raw_data(data_batch, parallel_writer=parallel_writer)
            print("Process {} transformed {} records.".format(
                index, count))
        except StopIteration:
            if data_batch:
                writer.write_raw_data(data_batch,
                                      parallel_writer=parallel_writer)
                print("Process {} transformed {} records.".format(
                    index, count))
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_type', type=str, default='openwebtext')
    parser.add_argument('--input_glob', type=str, default='*.txt')
    parser.add_argument('--output_file', type=str,
                        default='./output/transfered_mindrecord')
    parser.add_argument('--file_partition', type=int, default=1)
    parser.add_argument('--file_batch_size', type=int, default=1024)
    parser.add_argument('--num_process', type=int, default=64)
    args = parser.parse_args()

    out_dir, out_file = os.path.split(os.path.abspath(args.output_file))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    schema = {"input_ids": {"type": "int32", "shape": [-1]},}
    writer = FileWriter(file_name=args.output_file,
                        shard_num=args.file_partition)
    writer.add_schema(schema, args.dataset_type)
    writer.open_and_set_header()

    transforms_count = 0
    if args.dataset_type == 'wiki':
        for x in tokenize_wiki(args.input_glob):
            transforms_count += 1
            writer.write_raw_data([x])
        print("Transformed {} records.".format(transforms_count))
    elif args.dataset_type == 'lambada':
        for x in tokenize_lambada(args.input_glob):
            transforms_count += 1
            writer.write_raw_data([x])
        print("Transformed {} records.".format(transforms_count))
    elif args.dataset_type == 'openwebtext':
        file_iter = glob.iglob(args.input_glob)
        with Pool(processes=args.num_process) as pool:
            pool.map(task_unit, package_file(file_iter, args.file_batch_size))
    else:
        raise ValueError(
            "Unsupported dataset type: {}".format(args.dataset_type))
    writer.commit()
    out_file = args.output_file
    if args.file_partition > 1:
        out_file += '0'
    print("Transform finished, output files: {}".format(out_file))
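A detail worth noting in the script above: `SEQ_LEN` is 1025 rather than 1024. A plausible reading (our assumption; the split itself happens later in the training data pipeline, not in this script) is that one extra token lets a single stored sample provide both a 1024-token input window and its 1024 next-token targets:

```python
SEQ_LEN = 1025  # as in the script above: 1024-token context plus one extra token

def to_input_and_target(sample_ids):
    """Split a stored SEQ_LEN sample into model inputs and shifted
    next-token targets. This is the usual causal-LM convention; a
    sketch, not the MindSpore data pipeline itself."""
    assert len(sample_ids) == SEQ_LEN
    return sample_ids[:-1], sample_ids[1:]

ids = list(range(SEQ_LEN))
inputs, targets = to_input_and_target(ids)
```

Storing the extra token once is cheaper than storing inputs and labels separately, since the two sequences overlap in all but their first and last positions.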
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
from io import open

import sentencepiece as spm
import jieba


class JIEBATokenizer():
    r"""
    Jieba tokenizer: jieba word segmentation followed by SentencePiece encoding.
    """

    def __init__(self, vocab_file, model_file, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        with open(vocab_file, 'r') as f:
            lines = f.readlines()
        self.encoder = {}
        for i, line in enumerate(lines):
            key = line.split('\t')[0]
            self.encoder[key] = i
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.sp = spm.SentencePieceProcessor(model_file=model_file)
        # Protect spaces and newlines with sentinel glyphs before SentencePiece
        self.translator = str.maketrans(" \n", "\u2582\u2583")
        self.eod_id = self.encoder['<eod>']
        self.eot_id = self.encoder['<eot>']
        self.pad_id = self.encoder['<pad>']

    @property
    def vocab_size(self):
        return len(self.encoder)

    def __len__(self):
        return len(self.encoder)

    @property
    def eod(self):
        return self.eod_id

    def tokenize(self, text):
        """Tokenize a string."""
        seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
        new_seg = " ".join(seg_list)
        return self.sp.encode(new_seg)

    def convert_tokens_to_ids(self, tokens):
        return tokens

    def convert_ids_to_tokens(self, ids):
        return self.decode(ids)

    def encode(self, text):
        res = self.tokenize(text)
        return res

    def decode(self, tokens):
        text = self.sp.decode(tokens)
        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
        return text
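The sentinel trick in the tokenizer above is easy to miss: real spaces and newlines are rewritten to `\u2582`/`\u2583` before SentencePiece sees the text, and `decode` reverses the mapping after stripping the separator spaces SentencePiece introduced. A self-contained sketch of just that round trip:

```python
# Same mapping as JIEBATokenizer.translator above.
translator = str.maketrans(" \n", "\u2582\u2583")

def protect_whitespace(text):
    """Replace spaces/newlines with sentinel glyphs before encoding."""
    return text.translate(translator)

def restore_whitespace(text):
    """Inverse mapping, as used in JIEBATokenizer.decode."""
    return text.replace("\u2582", " ").replace("\u2583", "\n")

round_trip = restore_whitespace(protect_whitespace("hello world\nbye"))
```

Without the sentinels, `decode`'s blanket `replace(' ', '')` (needed to drop the spaces jieba inserted between segmented words) would also destroy the document's genuine whitespace.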
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config setting, gradient clip function and dynamic learning rate function.
"""
import argparse

import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
from mindspore.parallel._utils import _get_global_rank
from mindspore.communication.management import get_group_size

get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor", "Tensor")
def _get_square_sum(grad, value):
    norm = P.ReduceSum(False)(F.square(grad) / value, ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(nn.Cell):
    """
    Calculate the global norm value of the given tensors.
    """

    def __init__(self, params):
        super(GlobalNorm, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()
        self.allreduce_filter = tuple(
            "projection.bias" not in x.name and "layernorm" not in x.name and "embedding_table"
            not in x.name for x in params)
        self.length = len(params)
        self.values = []
        self.group_size = get_group_size()
        for item in self.allreduce_filter:
            if item:
                self.values.append(Tensor([1.0], mstype.float32))
            else:
                self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
        self.values = tuple(self.values)

    def construct(self, grads):
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms


class ClipByGlobalNorm(nn.Cell):
    """
    Clip the gradients by global norm.
    """

    def __init__(self, params, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm(params)
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm_value = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm_value, self.clip_norm)
        global_norm = F.select(cond, global_norm_value, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads, global_norm_value
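Stripped of the AllReduce and the duplication-correction weights, `ClipByGlobalNorm` computes one L2 norm over all gradients together and rescales them only when that norm exceeds the threshold. A single-process sketch (our own mirror over plain lists, under the assumption of one device so no cross-rank reduction is needed):

```python
import math

def clip_by_global_norm(grads, clip_norm=1.0):
    """Single-process sketch of ClipByGlobalNorm: one L2 norm over all
    gradient tensors, rescaling only when it exceeds clip_norm."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if global_norm <= clip_norm:
        return grads, global_norm
    scale = clip_norm / global_norm
    return [[g * scale for g in grad] for grad in grads], global_norm

# Two "tensors" with joint norm 5.0 get rescaled to joint norm 1.0.
clipped, norm = clip_by_global_norm([[3.0], [4.0]], clip_norm=1.0)
```

Clipping on the joint norm (rather than per-tensor) preserves the relative magnitudes of the layers' gradients, which is why it is the default for large-model training here (`enable_global_norm=True` in train.py).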
def _get_model_parallel_group(dp, mp):
    rank = _get_global_rank()
    group = range(0, mp)
    index = rank // dp
    return [x + index * mp for x in group]


class LearningRate(LearningRateSchedule):
    """
    Warmup-decay learning rate for the PanguAlpha network.
    """

    def __init__(self,
                 learning_rate,
                 end_learning_rate,
                 warmup_steps,
                 decay_steps,
                 power=1.0,
                 use_cosine=True,
                 lr_scale=0.125):
        super(LearningRate, self).__init__()
        self.warmup_flag = False
        if warmup_steps > 0:
            self.warmup_flag = True
            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate,
                                          decay_steps, power)
        self.cosine_decay_lr = CosineDecayLR(end_learning_rate, learning_rate,
                                             decay_steps)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()
        self.use_cosine = use_cosine
        self.lr_scale = lr_scale

    def construct(self, global_step):
        """dynamic learning rate"""
        if not self.use_cosine:
            decay_lr = self.decay_lr(global_step)
        else:
            decay_lr = self.cosine_decay_lr(global_step)
        if self.warmup_flag:
            is_warmup = self.cast(self.greater(self.warmup_steps, global_step),
                                  mstype.float32)
            warmup_lr = self.warmup_lr(global_step)
            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        else:
            lr = decay_lr
        return lr * self.lr_scale
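The shape of this schedule is linear warmup followed by a decay toward the end learning rate. A pure-Python sketch of the cosine branch (our simplification; the exact MindSpore `WarmUpLR`/`CosineDecayLR` formulas may differ in off-by-one details, and the branch selection here mirrors the `is_warmup` blend above):

```python
import math

def lr_at(step, base_lr, end_lr, warmup_steps, decay_steps, lr_scale=1.0):
    """Sketch of LearningRate above: linear warmup to base_lr, then
    cosine decay toward end_lr. Simplified, not the MindSpore cells."""
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        t = min(step - warmup_steps, decay_steps) / decay_steps
        lr = end_lr + 0.5 * (base_lr - end_lr) * (1 + math.cos(math.pi * t))
    return lr * lr_scale

# Values from the 2.6B train preset: start_lr=1e-4, end_lr=1e-6, warmup_step=2000.
start = lr_at(0, 1e-4, 1e-6, 2000, 200000)
peak = lr_at(2000, 1e-4, 1e-6, 2000, 200000)
final = lr_at(202000, 1e-4, 1e-6, 2000, 200000)
```

The warmup phase protects the randomly scaled loss statistics early in training; the cosine tail lets the model settle near `end_lr` over the 200000 decay steps configured in train.py.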
def add_training_params(opt):
    """Add training params"""
    opt.add_argument("--seq_length",
                     type=int,
                     default=1024,
                     help="sequence length, default is 1024.")
    opt.add_argument("--vocab_size",
                     type=int,
                     default=40000,
                     help="vocabulary size, default is 40000.")
    opt.add_argument("--embedding_size",
                     type=int,
                     default=16384,
                     help="embedding table size, default is 16384.")
    opt.add_argument("--num_layers",
                     type=int,
                     default=64,
                     help="total layers, default is 64.")
    opt.add_argument("--num_heads",
                     type=int,
                     default=128,
                     help="head size, default is 128.")
    opt.add_argument("--stage_num",
                     type=int,
                     default=4,
                     help="Pipeline stage num, default is 4.")
    opt.add_argument("--micro_size",
                     type=int,
                     default=1,
                     help="Pipeline micro_size, default is 1.")
    opt.add_argument("--eod_reset",
                     type=int,
                     default=1,
                     help="Enable eod mask, default is 1.")
    opt.add_argument("--warmup_step",
                     type=int,
                     default=2000,
                     help="Warmup step, default is 2000.")
    opt.add_argument("--optimizer",
                     type=str,
                     default="adam",
                     choices=["adam", "lamb"],
                     help="select which optimizer to be used, default adam")
    opt.add_argument("--eod_id",
                     type=int,
                     default=6,
                     help="The id of end of document")
    opt.add_argument("--epoch_size",
                     type=int,
                     default=1,
                     help="The training epoch")
    opt.add_argument("--sink_size",
                     type=int,
                     default=2,
                     help="The sink size of the training")


def get_args():
    """Parse the command-line arguments for PanguAlpha training."""
    parser = argparse.ArgumentParser(description="PanguAlpha training")
    parser.add_argument('--device_id',
                        type=int,
                        default=0,
                        help="Device id, default is 0.")
    parser.add_argument("--device_num",
                        type=int,
                        default=128,
                        help="Number of devices to use, default is 128.")
    parser.add_argument("--distribute",
                        type=str,
                        default="true",
                        choices=["true", "false"],
                        help="Run distributed training, default is true.")
    parser.add_argument("--load_ckpt_name",
                        type=str,
                        default='PANGUALPHA3.ckpt',
                        help="checkpoint file name.")
    parser.add_argument("--load_ckpt_path",
                        type=str,
                        default=None,
                        help="The path of the checkpoint file to load.")
    parser.add_argument('--data_url',
                        required=False,
                        default=None,
                        help='Location of data.')
    parser.add_argument('--train_url',
                        required=False,
                        default=None,
                        help='Location of training outputs.')
    parser.add_argument("--run_type",
                        type=str,
                        default="predict",
                        choices=["train", "predict"],
                        help="The run type")
    parser.add_argument("--mode",
                        type=str,
                        default="2.6B",
                        choices=["200B", "13B", "2.6B", "self_define"],
                        help="The train/eval mode")
    parser.add_argument("--strategy_load_ckpt_path",
                        type=str,
                        default="",
                        help="The training parallel strategy for the model.")
    parser.add_argument("--tokenizer_path",
                        type=str,
                        default="./tokenizer_path",
                        help="The path where the vocab and vocab model files are stored")
    add_training_params(parser)
    args_opt = parser.parse_args()
    return args_opt
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanguAlpha train script
"""
import os
import math
import time

from mindspore import context
from mindspore.train.model import Model
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, Callback
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._auto_parallel_context import auto_parallel_context

from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args


class LossCallBack(Callback):
    """
    Monitor the loss during training.

    If the loss is NAN or INF, terminate training.
    """

    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size
        self.local_rank = local_rank
        self.has_trained_epoch = has_trained_epoch
        self.has_trained_step = has_trained_step
        print("load has trained epoch: {} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)

    def step_end(self, run_context):
        """
        Print the loss after each step.
        """
        cb_params = run_context.original_args()
        if self._dataset_size > 0 and self.local_rank % 8 == 0:
            percent, epoch_num = math.modf(cb_params.cur_step_num /
                                           self._dataset_size)
            if percent == 0:
                epoch_num -= 1
            date = time.asctime(time.localtime(time.time()))
            print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
                  format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
                         cb_params.cur_step_num + int(self.has_trained_step), cb_params.net_outputs[0].asnumpy(),
                         cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
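The `math.modf` bookkeeping in `step_end` above derives the epoch index and in-epoch progress from a single global step counter; the `percent == 0` branch keeps a step that lands exactly on an epoch boundary attributed to the epoch it just finished. A standalone sketch of that arithmetic (the helper name is ours):

```python
import math

def epoch_and_progress(cur_step_num, dataset_size):
    """Mirror of the math.modf bookkeeping in LossCallBack.step_end:
    returns (zero-based epoch index, fraction completed in that epoch)."""
    percent, epoch_num = math.modf(cur_step_num / dataset_size)
    if percent == 0:
        # A step exactly at an epoch boundary belongs to the epoch it finished.
        epoch_num -= 1
    return int(epoch_num), percent

mid = epoch_and_progress(250, 100)       # two full epochs done, halfway through the third
boundary = epoch_and_progress(200, 100)  # exactly at the end of epoch index 1
```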
| project_root = os.path.abspath( | |||
| os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..") | |||
| print('project_root:', project_root) | |||
| def run_train(args_opt): | |||
| r""" | |||
| The main training process. | |||
| """ | |||
| device_id = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", | |||
| device_id=device_id) | |||
| context.set_context(variable_memory_max_size="30GB") | |||
| if args_opt.distribute == "true": | |||
| D.init() | |||
| device_num = D.get_group_size() | |||
| rank = D.get_rank() | |||
| print("device_id is {}, rank_id is {}, device_num is {}".format( | |||
| device_id, rank, device_num)) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context( | |||
| parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, | |||
| gradients_mean=False, | |||
| device_num=device_num, | |||
| full_batch=True, | |||
| enable_parallel_optimizer=True) | |||
| auto_parallel_context().set_loss_repeated_mean(True) | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| _set_multi_subgraphs() | |||
| else: | |||
| rank = 0 | |||
| device_num = 1 | |||
| model_parallel_num = args_opt.tensor_model_parallel_num | |||
| data_parallel_num = int(device_num / model_parallel_num) | |||
| batch_size = args_opt.per_batch_size * device_num | |||
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        eod_reset=bool(args_opt.eod_reset),
        word_emb_dp=True)
    print("===config is: ", config, flush=True)
    # Build the network, wrap it with the loss, and add a virtual dataset cell
    # so inputs are sliced correctly under semi-auto parallel.
    pangu_alpha = PanguAlpha(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
| print("=====args_opt is: ", args_opt, flush=True) | |||
| lr = LearningRate(learning_rate=args_opt.start_lr, | |||
| end_learning_rate=args_opt.end_lr, | |||
| warmup_steps=args_opt.warmup_step, | |||
| decay_steps=200000, | |||
| lr_scale=1) | |||
| decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() | |||
| params = pangu_alpha.trainable_params() | |||
| decay_params = list(filter(decay_filter, params)) | |||
| other_params = list(filter(lambda x: not decay_filter(x), params)) | |||
| group_params = [{ | |||
| 'params': decay_params, | |||
| 'weight_decay': 1e-1 | |||
| }, { | |||
| 'params': other_params, | |||
| 'weight_decay': 0.0 | |||
| }, { | |||
| 'order_params': params | |||
| }] | |||
| if args_opt.optimizer == "lamb": | |||
| optimizer = nn.Lamb(group_params, learning_rate=lr) | |||
| else: | |||
| optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) | |||
    # Dynamic loss scaling for float16 training: start high and let the
    # update cell halve the scale on overflow.
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url,
                        data_start_index=0, eod_reset=config.eod_reset,
                        eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)
    step_per_epoch = ds.get_dataset_size()
    # With dataset sink mode, Model.train runs in chunks of sink_size steps,
    # so the epoch count passed to it must be rescaled accordingly.
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, rank, 0, 0)
    ]
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
        config=config)
    model = Model(pangu_alpha_with_grads)
    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
    model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
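The sink-mode epoch arithmetic (`actual_epoch_num = epoch_num * step_per_epoch / callback_size`) can be made explicit as a small standalone helper; a sketch with hypothetical names, not called by the training script:

```python
def sink_epochs(epoch_num, steps_per_epoch, sink_size):
    """Number of times Model.train iterates when each 'epoch' sinks
    sink_size steps to the device instead of a full dataset pass."""
    total_steps = epoch_num * steps_per_epoch
    return total_steps // sink_size

# e.g. 1 epoch of 100000 steps with sink_size=2000 -> 50 sink epochs
```

Callbacks such as `TimeMonitor` and `LossCallBack` then fire once per `sink_size` steps rather than once per dataset pass.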
| if __name__ == "__main__": | |||
| opt = get_args() | |||
| set_parse(opt) | |||
| run_train(opt) | |||
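For reference, the weight-decay grouping rule used in `run_train` can be exercised on plain parameter names; an illustrative sketch only (the real `decay_filter` operates on MindSpore `Parameter` objects):

```python
def split_decay_params(param_names):
    # Mirrors decay_filter: LayerNorm scales/offsets and biases are exempt
    # from weight decay; every other parameter receives decay.
    decay = [n for n in param_names
             if 'layernorm' not in n.lower() and 'bias' not in n.lower()]
    no_decay = [n for n in param_names if n not in decay]
    return decay, no_decay
```

The `order_params` entry in `group_params` then restores the original parameter order after the two groups are concatenated.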