From: @huangxinjing Reviewed-by: @stsuteng, @zhunaipan Signed-off-by: @lilongfei15 (pull/15681/MERGE)
| @@ -4,7 +4,7 @@ | |||
| # Contents | |||
| - [Contents](#contents) | |||
| - [PanGu1 Description](#bert-description) | |||
| - [PanGu-Alpha Description](#pangu-description) | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Environment Requirements](#environment-requirements) | |||
| @@ -13,30 +13,34 @@ | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| # [PanGu1 Description](#contents) | |||
| # [PanGu-Alpha Description](#pangu-description) | |||
| We release this code to explore the new frontier of training large models with billions or even trillions of parameters. | |||
| With MindSpore's parallel features, we adopt efficient model-parallel and data-parallel technologies such as operator-level parallelism | |||
| to minimize the communication cost and maximize computation efficiency. | |||
| The code is easy to scale to thousands of NPUs and trillions of parameters with few modifications. | |||
| Meanwhile, we run our parallel training on a language model, named PanGu1, to demonstrate that large models can be trained easily | |||
| Meanwhile, we run our parallel training on a language model, named PanGu-Alpha, to demonstrate that large models can be trained easily | |||
| with our parallel setting. We summarize the training tricks as follows: | |||
| 1. Data Parallel | |||
| 2. Model Parallel | |||
| 3. Optimizer Parallel | |||
| 1. Op-level Model Parallelism | |||
| 2. Pipeline Model Parallelism | |||
| 3. Optimizer Model Parallelism | |||
| The above features can be found [here](https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.1/auto_parallel.html). | |||
| The above features can be found [here](https://www.mindspore.cn/doc/programming_guide/en/r1.2/auto_parallel.html). | |||
| More amazing features are still under development. | |||
| The technical report and checkpoint file can be found [here](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-AIpha). | |||
| # [Model Architecture](#contents) | |||
| The PanGu1 model stacks many layers of transformer decoders with a few modifications. Hybrid parallelism is applied to maximize device utilization | |||
| and reduce communication cost. The code demonstrates the training procedure on 32 Ascend cards with 4-way data parallelism on top of 8-way model parallelism. | |||
| Both the 4-way data parallelism and the 8-way model parallelism can be modified in the configuration to fit clusters of various scales. | |||
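The 4 × 8 = 32 layout above can be sketched with the same arithmetic that `run_train` uses to derive `data_parallel_num` from the device count. This is a hypothetical standalone sketch, not the repository's code:

```python
def parallel_layout(device_num, model_parallel_num):
    """Split a cluster into data-parallel groups of model-parallel devices."""
    assert device_num % model_parallel_num == 0, "mp must divide the device count"
    data_parallel_num = device_num // model_parallel_num
    return data_parallel_num, model_parallel_num

# 32 Ascend cards with 8-way model parallelism -> 4-way data parallelism
dp, mp = parallel_layout(32, 8)
print(dp, mp)  # 4 8
```

Changing either factor in the configuration simply re-partitions the same device count, which is why the script recomputes `data_parallel_num` rather than taking it as an argument.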
|  | |||
| The architecture of PanGu-α is based on the Transformer, which has been extensively used as the backbone of a variety of | |||
| pretrained language models such as BERT and GPT. Unlike them, we develop an additional query layer on top of the | |||
| Transformer layers to predict the next token. The diagram of the model is shown in Figure 1. | |||
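The query-layer idea can be illustrated roughly in numpy: the queries come from a learned position-wise "top query" embedding rather than from the hidden states, while keys and values still come from the last Transformer layer. This is our own illustration with made-up shapes, not the repository's `QueryLayer` implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def top_query_attention(hidden, top_query, wq, wk, wv):
    """Attend from a learned top-query embedding over the final hidden states.

    hidden:    (seq, dim) output of the last transformer layer
    top_query: (seq, dim) learned position-wise query embedding
    """
    q = top_query @ wq          # queries come from the query embedding,
    k = hidden @ wk             # keys and values from the transformer output
    v = hidden @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # causal mask: position i may only attend to positions <= i
    mask = np.triu(np.ones_like(scores), k=1).astype(bool)
    scores[mask] = -1e9
    return softmax(scores) @ v

rng = np.random.default_rng(0)
seq, dim = 4, 8
out = top_query_attention(rng.normal(size=(seq, dim)),
                          rng.normal(size=(seq, dim)),
                          *(rng.normal(size=(dim, dim)) for _ in range(3)))
print(out.shape)  # (4, 8)
```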
| # [Dataset](#contents) | |||
| # [Dataset](#dataset) | |||
| - Open Source Dataset. | |||
| @@ -80,8 +84,8 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools. | |||
| │ └── run_distribute_train.sh | |||
| ├── src | |||
| │ ├── dataset.py | |||
| │ ├── pangu1.py | |||
| │ ├── pangu1_wrapcell.py | |||
| │ ├── pangu_alpha.py | |||
| │ ├── pangu_alpha_wrapcell.py | |||
| │ └── utils.py | |||
| └── train.py | |||
| ``` | |||
| @@ -12,7 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """PanGu1 model""" | |||
| """PanguAlpha model""" | |||
| import math | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| @@ -189,7 +189,7 @@ class Output(nn.Cell): | |||
| """ | |||
| The output mapping module for each layer | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| scale: scale factor for initialization | |||
| Inputs: | |||
| x: output of the self-attention module | |||
| @@ -219,7 +219,7 @@ class AttentionMask(nn.Cell): | |||
| r""" | |||
| Get the attention matrix for self-attention module | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| Returns: | |||
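The mask the `AttentionMask` cell produces combines the per-position validity mask with a lower-triangular causal mask, so a position can only attend to earlier valid positions. A hedged numpy sketch of that combination (our own illustration, not the cell's MindSpore ops):

```python
import numpy as np

def attention_mask(input_mask):
    """Combine a validity mask with a lower-triangular causal mask."""
    seq = input_mask.shape[-1]
    # pairwise validity: position i may see j only if both positions are valid
    pairwise = input_mask[:, None] * input_mask[None, :]
    causal = np.tril(np.ones((seq, seq), dtype=input_mask.dtype))  # keep j <= i
    return pairwise * causal

# positions 0-2 are real tokens, position 3 is padding
m = attention_mask(np.array([1, 1, 1, 0], dtype=np.float32))
print(m)
```

Row `i` of the result selects which positions the self-attention at position `i` may read; the padding row and column are zeroed out entirely.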
| @@ -255,7 +255,7 @@ class EmbeddingLookup(nn.Cell): | |||
| """ | |||
| The embedding lookup table for vocabulary | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs with datatype int32 | |||
| Returns: | |||
| @@ -289,7 +289,7 @@ class Attention(nn.Cell): | |||
| Self-Attention module for each layer | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| scale: scale factor for initialization | |||
| layer_idx: current layer index | |||
| """ | |||
| @@ -477,9 +477,9 @@ class Attention(nn.Cell): | |||
| class Block(nn.Cell): | |||
| """ | |||
| The basic block of PanGu1 network | |||
| The basic block of PanguAlpha network | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| layer_idx: current layer index | |||
| Inputs: | |||
| x: the output of the previous layer (input_ids for the first layer) | |||
| @@ -623,11 +623,11 @@ class QueryLayer(nn.Cell): | |||
| output = self.last_add(x, mlp_logit) | |||
| return output, layer_present | |||
| class PanGu1_Model(nn.Cell): | |||
| class PanguAlpha_Model(nn.Cell): | |||
| """ | |||
| The backbone of PanGu1 network | |||
| The backbone of PanguAlpha network | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs with datatype int32 | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| @@ -638,7 +638,7 @@ class PanGu1_Model(nn.Cell): | |||
| embedding_table: Tensor, the embedding table for the vocabulary | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanGu1_Model, self).__init__() | |||
| super(PanguAlpha_Model, self).__init__() | |||
| self.get_attention_mask = AttentionMask(config) | |||
| self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1) | |||
| self.position_embedding = nn.Embedding( | |||
| @@ -714,7 +714,7 @@ class PanGu1_Model(nn.Cell): | |||
| def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None): | |||
| """PanGu1 model""" | |||
| """PanguAlpha model""" | |||
| if not self.use_past: | |||
| layer_past = self.past | |||
| @@ -748,11 +748,11 @@ class PanGu1_Model(nn.Cell): | |||
| return output_state, present_layer, embedding_table | |||
| class PanGu1_Head(nn.Cell): | |||
| class PanguAlpha_Head(nn.Cell): | |||
| """ | |||
| Head for PanGu1 to get the logits of each token in the vocab | |||
| Head for PanguAlpha to get the logits of each token in the vocab | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| state: the output of the backbone | |||
| embedding_table: the embedding table of the vocabulary | |||
| @@ -760,7 +760,7 @@ class PanGu1_Head(nn.Cell): | |||
| logits: Tensor, the logits of the corresponding inputs | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanGu1_Head, self).__init__() | |||
| super(PanguAlpha_Head, self).__init__() | |||
| if config.word_emb_dp: | |||
| self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (1, 1))) | |||
| else: | |||
| @@ -776,11 +776,11 @@ class PanGu1_Head(nn.Cell): | |||
| return logits | |||
| class PanGu1(nn.Cell): | |||
| class PanguAlpha(nn.Cell): | |||
| """ | |||
| The PanGu1 network, consisting of two parts: the backbone and the head | |||
| The PanguAlpha network, consisting of two parts: the backbone and the head | |||
| Args: | |||
| config(PanGu1Config): the config of network | |||
| config(PanguAlphaConfig): the config of network | |||
| Inputs: | |||
| input_ids: the tokenized inputs | |||
| input_mask: the mask indicating whether each position is a valid input | |||
| @@ -789,9 +789,9 @@ class PanGu1(nn.Cell): | |||
| logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size) | |||
| """ | |||
| def __init__(self, config): | |||
| super(PanGu1, self).__init__() | |||
| self.backbone = PanGu1_Model(config) | |||
| self.head = PanGu1_Head(config) | |||
| super(PanguAlpha, self).__init__() | |||
| self.backbone = PanguAlpha_Model(config) | |||
| self.head = PanguAlpha_Head(config) | |||
| def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None): | |||
| output_states, _, embedding_table = self.backbone( | |||
| @@ -804,7 +804,7 @@ class CrossEntropyLoss(nn.Cell): | |||
| """ | |||
| Calculate the cross entropy loss | |||
| Args: | |||
| config(PanGu1Config): the config of the network | |||
| config(PanguAlphaConfig): the config of the network | |||
| Inputs: | |||
| logits: the output logits of the backbone | |||
| label: the ground truth label of the sample | |||
| @@ -865,11 +865,11 @@ class CrossEntropyLoss(nn.Cell): | |||
| return loss | |||
| class PanGu1WithLoss(nn.Cell): | |||
| class PanguAlphaWithLoss(nn.Cell): | |||
| """ | |||
| PanGu1 training loss | |||
| PanguAlpha training loss | |||
| Args: | |||
| network: backbone network of PanGu1 | |||
| network: backbone network of PanguAlpha | |||
| loss: loss function, e.g., crossentropy | |||
| eos_token: the end_of_sentence token | |||
| Inputs: | |||
| @@ -879,7 +879,7 @@ class PanGu1WithLoss(nn.Cell): | |||
| output: Tensor, the loss of the network | |||
| """ | |||
| def __init__(self, config, network, loss, eos_token=6): | |||
| super(PanGu1WithLoss, self).__init__(auto_prefix=False) | |||
| super(PanguAlphaWithLoss, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.loss = loss | |||
| self.eos_token = eos_token | |||
| @@ -893,7 +893,7 @@ class PanGu1WithLoss(nn.Cell): | |||
| def construct(self, input_ids, input_position=None, attention_mask=None): | |||
| r""" | |||
| PanGu1WithLoss | |||
| PanguAlphaWithLoss | |||
| """ | |||
| tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1)) | |||
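The slicing in `construct` sets up standard next-token training: the network reads all but the last token and is asked to predict the sequence shifted by one. A minimal sketch of that split, using numpy stand-ins rather than the `P.StridedSlice` ops above:

```python
import numpy as np

def split_tokens_labels(input_ids):
    """The model sees input_ids[:, :-1] and predicts input_ids[:, 1:]."""
    tokens = input_ids[:, :-1]   # what the network is fed
    labels = input_ids[:, 1:]    # the next token at each position
    return tokens, labels

ids = np.array([[11, 12, 13, 6]])   # 6 is the eos_token default in the script
tokens, labels = split_tokens_labels(ids)
print(tokens.tolist(), labels.tolist())  # [[11, 12, 13]] [[12, 13, 6]]
```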
| @@ -914,9 +914,9 @@ class PanGu1WithLoss(nn.Cell): | |||
| class EvalNet(nn.Cell): | |||
| """ | |||
| PanGu1 evaluation net | |||
| PanguAlpha evaluation net | |||
| Args: | |||
| backbone: backbone network of PanGu1 | |||
| backbone: backbone network of PanguAlpha | |||
| generate: enable generate mode | |||
| Inputs: | |||
| input_ids: the tokenized inputs | |||
| @@ -75,9 +75,9 @@ class VirtualDatasetOneInputCell(nn.Cell): | |||
| data_ = self._virtual_dataset(*data) | |||
| return self._backbone(*data_) | |||
| class PanGu1TrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): | |||
| class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): | |||
| """ | |||
| Encapsulation class of PanGu1 network training. | |||
| Encapsulation class of PanguAlpha network training. | |||
| Append an optimizer to the training network. After that, the construct | |||
| function can be called to create the backward graph. | |||
| @@ -93,7 +93,7 @@ class PanGu1TrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): | |||
| scale_update_cell=None, | |||
| enable_global_norm=False, | |||
| config=None): | |||
| super(PanGu1TrainOneStepWithLossScaleCell, | |||
| super(PanguAlphaTrainOneStepWithLossScaleCell, | |||
| self).__init__(network, optimizer, scale_update_cell) | |||
| self.network = network | |||
| self.config = config | |||
| @@ -28,9 +28,9 @@ from mindspore.nn.learning_rate_schedule import LearningRateSchedule, Polynomial | |||
| from mindspore.parallel._utils import _get_global_rank | |||
| from mindspore.communication.management import get_group_size | |||
| class PanGu1Config: | |||
| class PanguAlphaConfig: | |||
| """ | |||
| PanGu1 config class which defines the model size | |||
| PanguAlpha config class which defines the model size | |||
| """ | |||
| def __init__(self, | |||
| data_parallel_num, | |||
| @@ -77,7 +77,7 @@ class PanGu1Config: | |||
| self.use_top_query_attention = use_top_query_attention | |||
| def __str__(self): | |||
| info = "[PanGu1 Config]" + '===' * 10 + '\n' | |||
| info = "[PanguAlpha Config]" + '===' * 10 + '\n' | |||
| for k, v in self.__dict__.items(): | |||
| var_info = "{}:{}\n".format(k, v) | |||
| info += var_info | |||
| @@ -162,7 +162,7 @@ def _get_model_parallel_group(dp, mp): | |||
| class LearningRate(LearningRateSchedule): | |||
| """ | |||
| Warmup-decay learning rate for PanGu1 network. | |||
| Warmup-decay learning rate for PanguAlpha network. | |||
| """ | |||
| def __init__(self, | |||
| learning_rate, | |||
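The schedule combines linear warmup with polynomial decay (the imports above pull in `PolynomialDecayLR`). A hypothetical plain-Python sketch of that shape, with made-up default values rather than the class's actual arguments:

```python
def warmup_decay_lr(step, base_lr, warmup_steps, total_steps,
                    end_lr=1e-6, power=1.0):
    """Linear warmup to base_lr, then polynomial decay toward end_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps          # linear warmup
    decay_steps = total_steps - warmup_steps
    frac = min(step - warmup_steps, decay_steps) / decay_steps
    return (base_lr - end_lr) * (1.0 - frac) ** power + end_lr  # poly decay

print(warmup_decay_lr(0, 1e-4, 10, 100))    # partway through warmup
print(warmup_decay_lr(10, 1e-4, 10, 100))   # peak, equals base_lr
print(warmup_decay_lr(100, 1e-4, 10, 100))  # fully decayed to end_lr
```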
| @@ -206,8 +206,8 @@ class LearningRate(LearningRateSchedule): | |||
| def get_args(): | |||
| """train function for PanGu1""" | |||
| parser = argparse.ArgumentParser(description="PanGu1 training") | |||
| """train function for PanguAlpha""" | |||
| parser = argparse.ArgumentParser(description="PanguAlpha training") | |||
| parser.add_argument('--device_id', | |||
| type=int, | |||
| default=0, | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| PanGu1 train script | |||
| PanguAlpha train script | |||
| """ | |||
| import os | |||
| @@ -31,9 +31,9 @@ from mindspore.parallel._cost_model_context import _set_multi_subgraphs | |||
| from mindspore.parallel import set_algo_parameters | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from src.dataset import create_dataset | |||
| from src.pangu1 import PanGu1, PanGu1WithLoss, CrossEntropyLoss | |||
| from src.pangu1_wrapcell import PanGu1TrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell | |||
| from src.utils import PanGu1Config, LearningRate, get_args | |||
| from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss | |||
| from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell | |||
| from src.utils import PanguAlphaConfig, LearningRate, get_args | |||
| class LossCallBack(Callback): | |||
| @@ -110,7 +110,7 @@ def run_train(): | |||
| model_parallel_num = args_opt.mp | |||
| data_parallel_num = int(device_num / model_parallel_num) | |||
| batch_size = args_opt.per_batch_size * device_num | |||
| config = PanGu1Config( | |||
| config = PanguAlphaConfig( | |||
| data_parallel_num=data_parallel_num, | |||
| model_parallel_num=model_parallel_num, | |||
| batch_size=batch_size, | |||
| @@ -124,10 +124,10 @@ def run_train(): | |||
| compute_dtype=mstype.float16, | |||
| eod_reset=bool(args_opt.eod_reset)) | |||
| print("===config is: ", config, flush=True) | |||
| pangu1 = PanGu1(config) | |||
| pangu_alpha = PanguAlpha(config) | |||
| loss = CrossEntropyLoss(config) | |||
| pangu1_with_loss = PanGu1WithLoss(config, pangu1, loss) | |||
| pangu1_with_loss = VirtualDatasetOneInputCell(pangu1_with_loss) | |||
| pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss) | |||
| pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss) | |||
| print("=====args_opt is: ", args_opt, flush=True) | |||
| lr = LearningRate(learning_rate=args_opt.start_lr, | |||
| @@ -137,7 +137,7 @@ def run_train(): | |||
| lr_scale=1) | |||
| decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() | |||
| params = pangu1.trainable_params() | |||
| params = pangu_alpha.trainable_params() | |||
| decay_params = list(filter(decay_filter, params)) | |||
| other_params = list(filter(lambda x: not decay_filter(x), params)) | |||
| group_params = [{ | |||
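The `decay_filter` lambda above excludes LayerNorm and bias parameters from weight decay, a common practice for transformer training. The grouping can be sketched with stand-in `Param` objects (hypothetical names, not the network's real parameters):

```python
class Param:
    """Stand-in for a trainable parameter with a .name attribute."""
    def __init__(self, name):
        self.name = name

decay_filter = lambda x: ('layernorm' not in x.name.lower()
                          and 'bias' not in x.name.lower())

params = [Param('backbone.dense.weight'),
          Param('backbone.dense.bias'),
          Param('backbone.layernorm.gamma')]
decay_params = [p for p in params if decay_filter(p)]      # weight decay applied
other_params = [p for p in params if not decay_filter(p)]  # no weight decay
print([p.name for p in decay_params])  # ['backbone.dense.weight']
```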
| @@ -171,10 +171,11 @@ def run_train(): | |||
| LossCallBack(callback_size, rank, has_trained_epoch, has_trained_step) | |||
| ] | |||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000) | |||
| pangu1_with_grads = PanGu1TrainOneStepWithLossScaleCell( | |||
| pangu1_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, config=config) | |||
| pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell( | |||
| pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, | |||
| config=config) | |||
| model = Model(pangu1_with_grads) | |||
| model = Model(pangu_alpha_with_grads) | |||
| print("=====dataset size: ", ds.get_dataset_size(), flush=True) | |||
| print("=====actual_epoch_num: ", actual_epoch_num, flush=True) | |||
| model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True) | |||
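The `DynamicLossScaleUpdateCell(scale_factor=2, scale_window=1000)` configured above adapts the fp16 loss scale during training: the scale is reduced when gradients overflow and increased again after a window of clean steps. A rough sketch of that update rule (our own illustration, not MindSpore's internal logic):

```python
def update_loss_scale(scale, overflow, scale_factor=2, scale_window=1000,
                      good_steps=0):
    """Halve the scale on overflow; double it after scale_window clean steps.

    Returns the new (scale, good_steps) pair.
    """
    if overflow:
        return max(scale / scale_factor, 1), 0   # back off and reset the window
    good_steps += 1
    if good_steps >= scale_window:
        return scale * scale_factor, 0           # grow after a clean window
    return scale, good_steps

scale, good = update_loss_scale(65536.0, overflow=True)
print(scale)  # 32768.0
```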