
!15681 Add PanGu-Alpha model

From: @huangxinjing
Reviewed-by: @stsuteng, @zhunaipan
Signed-off-by: @lilongfei15
pull/15681/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
06b5835cd7
8 changed files with 69 additions and 64 deletions
  1. +17 -13   model_zoo/official/nlp/pangu_alpha/README.md
  2. BIN       model_zoo/official/nlp/pangu_alpha/docs/model.png
  3. +0 -0     model_zoo/official/nlp/pangu_alpha/scripts/run_distribute_train.sh
  4. +0 -0     model_zoo/official/nlp/pangu_alpha/src/dataset.py
  5. +30 -30   model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
  6. +3 -3     model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
  7. +6 -6     model_zoo/official/nlp/pangu_alpha/src/utils.py
  8. +13 -12   model_zoo/official/nlp/pangu_alpha/train.py

model_zoo/official/nlp/pangu1/README.md → model_zoo/official/nlp/pangu_alpha/README.md

@@ -4,7 +4,7 @@
# Contents

- [Contents](#contents)
- [PanGu1 Description](#bert-description)
- [PanGu-Alpha Description](#pangu-alpha-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
@@ -13,30 +13,34 @@
- [Script and Sample Code](#script-and-sample-code)
- [ModelZoo Homepage](#modelzoo-homepage)

# [PanGu1 Description](#contents)
# [PanGu-Alpha Description](#contents)

We release this code to explore the frontier of training large models with billions or even trillions of parameters.
Using MindSpore's parallel features, we adopt efficient model-parallel and data-parallel technologies, such as operator-level parallelism,
to minimize communication cost and maximize computation efficiency.
The code scales to thousands of NPUs and trillions of parameters with few modifications.

In the mean while, we run our parallel training upon a language model, named PanGu1, to demonstrate the large model can be trained easily
Meanwhile, we run our parallel training on a language model, named PanGu-Alpha, to demonstrate that large models can be trained easily
with our parallel setting. We summarize the training techniques as follows:

1. Data Parallel
2. Model Parallel
3. Optimizer Parallel
1. Op-level Model Parallelism
2. Pipeline Model Parallelism
3. Optimizer Model Parallelism

The above features can be found [here](https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.1/auto_parallel.html).
The above features can be found [here](https://www.mindspore.cn/doc/programming_guide/en/r1.2/auto_parallel.html).
More features are still under development.

The technical report and checkpoint file can be found [here](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-AIpha).
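As a rough, hypothetical illustration of the difference between the parallelism modes listed above (this is not MindSpore's API, which expresses sharding via operator strategies rather than explicit loops): op-level model parallelism splits a single operator's weight across devices, while data parallelism splits the batch. A pure-Python sketch of column-wise sharding of one matmul:

```python
# Toy sketch of op-level model parallelism: shard one matmul's weight
# column-wise across `mp` "devices", then concatenate the partial outputs.
# Hypothetical stand-in for illustration only; MindSpore implements this
# through operator sharding strategies, not Python loops.

def matmul(a, b):
    """Naive matrix multiply for lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def split_cols(m, parts):
    """Shard a weight matrix column-wise across `parts` devices."""
    n = len(m[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in m] for i in range(parts)]

def op_level_parallel_matmul(x, w, mp):
    """Each device computes x @ w_shard; concatenating shards recovers x @ w."""
    shards = split_cols(w, mp)
    partials = [matmul(x, s) for s in shards]  # one partial result per device
    return [sum((p[r] for p in partials), []) for r in range(len(x))]

x = [[1, 2], [3, 4]]
w = [[5, 6, 7, 8], [9, 10, 11, 12]]
# Sharded computation agrees with the unsharded matmul.
assert op_level_parallel_matmul(x, w, mp=2) == matmul(x, w)
```

In the same picture, data parallelism would split the rows of `x` (the batch) across devices instead, and optimizer parallelism would shard the optimizer states; the hybrid setting in this repository composes these modes.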

# [Model Architecture](#contents)

The PanGu1 model stacks many transformer decoder layers with a few modifications. Hybrid parallelism is applied to maximize device utilization
and reduce communication cost. The code demonstrates the training procedure on 32 Ascend cards with 4-way data parallelism on top of 8-way model parallelism.
Both the 4-way data parallelism and the 8-way model parallelism can be modified in the configuration to fit the variable scale of the cluster.
![](./docs/model.png)

The architecture of PanGu-α is based on Transformer, which has been extensively used as the backbone of a variety of
pretrained language models such as BERT and GPT. Different from them, we develop an additional query layer on top of
Transformer layers to predict the next token. The diagram of the model is shown in Figure 1.
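The query layer described above can be sketched as a final attention step in which a learned query embedding, rather than the previous hidden state, attends over the top transformer outputs; the attended vector would then be projected to next-token logits. A minimal single-head sketch (names and shapes are illustrative, not the repository's actual code):

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def top_query_attention(query, hiddens):
    """Hypothetical sketch of the top query layer: a learned query vector
    attends over the final transformer-layer hidden states and returns
    their weighted combination (which a head would map to vocab logits)."""
    scores = [sum(q * h for q, h in zip(query, hidden)) for hidden in hiddens]
    weights = softmax(scores)
    dim = len(query)
    return [sum(w * hidden[d] for w, hidden in zip(weights, hiddens))
            for d in range(dim)]

hiddens = [[1.0, 0.0], [0.0, 1.0]]  # toy hidden states from the last layer
query = [2.0, 0.0]                  # hypothetical learned query embedding
ctx = top_query_attention(query, hiddens)
assert abs(sum(ctx) - 1.0) < 1e-9   # a convex combination of these states
```

The design choice is that the query input to the last attention step is a dedicated position embedding, decoupling next-token prediction from the final token's own hidden state.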

# [Dataset](#contents)
# [Dataset](#contents)

- Open Source Dataset.

@@ -80,8 +84,8 @@ https:gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
│ └── run_distribute_train.sh
├── src
│ ├── dataset.py
│ ├── pangu1.py
│ ├── pangu1_wrapcell.py
│ ├── pangu_alpha.py
│ ├── pangu_alpha_wrapcell.py
│ └── utils.py
└── train.py
```

BIN
model_zoo/official/nlp/pangu_alpha/docs/model.png

Width: 1214  |  Height: 630  |  Size: 133 kB

model_zoo/official/nlp/pangu1/scripts/run_distribute_train.sh → model_zoo/official/nlp/pangu_alpha/scripts/run_distribute_train.sh


model_zoo/official/nlp/pangu1/src/dataset.py → model_zoo/official/nlp/pangu_alpha/src/dataset.py


model_zoo/official/nlp/pangu1/src/pangu1.py → model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PanGu1 model"""
"""PanguAlpha model"""
import math
import numpy as np
import mindspore.nn as nn
@@ -189,7 +189,7 @@ class Output(nn.Cell):
"""
The output mapping module for each layer
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
scale: scale factor for initialization
Inputs:
x: output of the self-attention module
@@ -219,7 +219,7 @@ class AttentionMask(nn.Cell):
r"""
Get the attention matrix for self-attention module
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
Inputs:
input_mask: the mask indicating whether each position is a valid input
Returns:
@@ -255,7 +255,7 @@ class EmbeddingLookup(nn.Cell):
"""
The embedding lookup table for vocabulary
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
Returns:
@@ -289,7 +289,7 @@ class Attention(nn.Cell):
Self-Attention module for each layer

Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
scale: scale factor for initialization
layer_idx: current layer index
"""
@@ -477,9 +477,9 @@ class Attention(nn.Cell):

class Block(nn.Cell):
"""
The basic block of PanGu1 network
The basic block of PanguAlpha network
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
layer_idx: current layer index
Inputs:
x: the output of previous layer(input_ids for the first layer)
@@ -623,11 +623,11 @@ class QueryLayer(nn.Cell):
output = self.last_add(x, mlp_logit)
return output, layer_present

class PanGu1_Model(nn.Cell):
class PanguAlpha_Model(nn.Cell):
"""
The backbone of PanGu1 network
The backbone of PanguAlpha network
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
input_mask: the mask indicating whether each position is a valid input
@@ -638,7 +638,7 @@ class PanGu1_Model(nn.Cell):
embedding_table: Tensor, the embedding table for the vocabulary
"""
def __init__(self, config):
super(PanGu1_Model, self).__init__()
super(PanguAlpha_Model, self).__init__()
self.get_attention_mask = AttentionMask(config)
self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
self.position_embedding = nn.Embedding(
@@ -714,7 +714,7 @@ class PanGu1_Model(nn.Cell):


def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
"""PanGu1 model"""
"""PanguAlpha model"""
if not self.use_past:
layer_past = self.past

@@ -748,11 +748,11 @@ class PanGu1_Model(nn.Cell):
return output_state, present_layer, embedding_table


class PanGu1_Head(nn.Cell):
class PanguAlpha_Head(nn.Cell):
"""
Head for PanGu1 to get the logits of each token in the vocab
Head for PanguAlpha to get the logits of each token in the vocab
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
Inputs:
state: the output of the backbone
embedding_table: the embedding table of the vocabulary
@@ -760,7 +760,7 @@ class PanGu1_Head(nn.Cell):
logits: Tensor, the logits of the corresponding inputs
"""
def __init__(self, config):
super(PanGu1_Head, self).__init__()
super(PanguAlpha_Head, self).__init__()
if config.word_emb_dp:
self.matmul = P.MatMul(transpose_b=True).shard(((config.dp, 1), (1, 1)))
else:
@@ -776,11 +776,11 @@ class PanGu1_Head(nn.Cell):
return logits


class PanGu1(nn.Cell):
class PanguAlpha(nn.Cell):
"""
The PanGu1 network consisting of two parts the backbone and the head
The PanguAlpha network consisting of two parts the backbone and the head
Args:
config(PanGu1Config): the config of network
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs
input_mask: the mask indicating whether each position is a valid input
@@ -789,9 +789,9 @@ class PanGu1(nn.Cell):
logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
def __init__(self, config):
super(PanGu1, self).__init__()
self.backbone = PanGu1_Model(config)
self.head = PanGu1_Head(config)
super(PanguAlpha, self).__init__()
self.backbone = PanguAlpha_Model(config)
self.head = PanguAlpha_Head(config)

def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
output_states, _, embedding_table = self.backbone(
@@ -804,7 +804,7 @@ class CrossEntropyLoss(nn.Cell):
"""
Calculate the cross entropy loss
Args:
config(PanGu1Config): the config of the network
config(PanguAlphaConfig): the config of the network
Inputs:
logits: the output logits of the backbone
label: the ground truth label of the sample
@@ -865,11 +865,11 @@ class CrossEntropyLoss(nn.Cell):
return loss


class PanGu1WithLoss(nn.Cell):
class PanguAlphaWithLoss(nn.Cell):
"""
PanGu1 training loss
PanguAlpha training loss
Args:
network: backbone network of PanGu1
network: backbone network of PanguAlpha
loss: loss function, e.g., crossentropy
eos_token: the end_of_sentence token
Inputs:
@@ -879,7 +879,7 @@ class PanGu1WithLoss(nn.Cell):
output: Tensor, the loss of the network
"""
def __init__(self, config, network, loss, eos_token=6):
super(PanGu1WithLoss, self).__init__(auto_prefix=False)
super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
self.network = network
self.loss = loss
self.eos_token = eos_token
@@ -893,7 +893,7 @@ class PanGu1WithLoss(nn.Cell):

def construct(self, input_ids, input_position=None, attention_mask=None):
r"""
PanGu1WithLoss
PanguAlphaWithLoss
"""
tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))

@@ -914,9 +914,9 @@ class PanGu1WithLoss(nn.Cell):

class EvalNet(nn.Cell):
"""
PanGu1 evaluation net
PanguAlpha evaluation net
Args:
backbone: backbone network of PanGu1
backbone: backbone network of PanguAlpha
generate: enable generate mode
Inputs:
input_ids: the tokenized inputs

model_zoo/official/nlp/pangu1/src/pangu1_wrapcell.py → model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py

@@ -75,9 +75,9 @@ class VirtualDatasetOneInputCell(nn.Cell):
data_ = self._virtual_dataset(*data)
return self._backbone(*data_)

class PanGu1TrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
"""
Encapsulation class of PanGu1 network training.
Encapsulation class of PanguAlpha network training.

Append an optimizer to the training network after that the construct
function can be called to create the backward graph.
@@ -93,7 +93,7 @@ class PanGu1TrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
scale_update_cell=None,
enable_global_norm=False,
config=None):
super(PanGu1TrainOneStepWithLossScaleCell,
super(PanguAlphaTrainOneStepWithLossScaleCell,
self).__init__(network, optimizer, scale_update_cell)
self.network = network
self.config = config

model_zoo/official/nlp/pangu1/src/utils.py → model_zoo/official/nlp/pangu_alpha/src/utils.py

@@ -28,9 +28,9 @@ from mindspore.nn.learning_rate_schedule import LearningRateSchedule, Polynomial
from mindspore.parallel._utils import _get_global_rank
from mindspore.communication.management import get_group_size

class PanGu1Config:
class PanguAlphaConfig:
"""
PanGu1 config class which defines the model size
PanguAlpha config class which defines the model size
"""
def __init__(self,
data_parallel_num,
@@ -77,7 +77,7 @@ class PanGu1Config:
self.use_top_query_attention = use_top_query_attention

def __str__(self):
info = "[PanGu1 Config]" + '===' * 10 + '\n'
info = "[PanguAlpha Config]" + '===' * 10 + '\n'
for k, v in self.__dict__.items():
var_info = "{}:{}\n".format(k, v)
info += var_info
@@ -162,7 +162,7 @@ def _get_model_parallel_group(dp, mp):

class LearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for PanGu1 network.
Warmup-decay learning rate for PanguAlpha network.
"""
def __init__(self,
learning_rate,
@@ -206,8 +206,8 @@ class LearningRate(LearningRateSchedule):


def get_args():
"""train function for PanGu1"""
parser = argparse.ArgumentParser(description="PanGu1 training")
"""train function for PanguAlpha"""
parser = argparse.ArgumentParser(description="PanguAlpha training")
parser.add_argument('--device_id',
type=int,
default=0,

model_zoo/official/nlp/pangu1/train.py → model_zoo/official/nlp/pangu_alpha/train.py

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
PanGu1 train script
PanguAlpha train script
"""

import os
@@ -31,9 +31,9 @@ from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from src.dataset import create_dataset
from src.pangu1 import PanGu1, PanGu1WithLoss, CrossEntropyLoss
from src.pangu1_wrapcell import PanGu1TrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.utils import PanGu1Config, LearningRate, get_args
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.utils import PanguAlphaConfig, LearningRate, get_args


class LossCallBack(Callback):
@@ -110,7 +110,7 @@ def run_train():
model_parallel_num = args_opt.mp
data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * device_num
config = PanGu1Config(
config = PanguAlphaConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
batch_size=batch_size,
@@ -124,10 +124,10 @@ def run_train():
compute_dtype=mstype.float16,
eod_reset=bool(args_opt.eod_reset))
print("===config is: ", config, flush=True)
pangu1 = PanGu1(config)
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
pangu1_with_loss = PanGu1WithLoss(config, pangu1, loss)
pangu1_with_loss = VirtualDatasetOneInputCell(pangu1_with_loss)
pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

print("=====args_opt is: ", args_opt, flush=True)
lr = LearningRate(learning_rate=args_opt.start_lr,
@@ -137,7 +137,7 @@ def run_train():
lr_scale=1)

decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
params = pangu1.trainable_params()
params = pangu_alpha.trainable_params()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{
@@ -171,10 +171,11 @@ def run_train():
LossCallBack(callback_size, rank, has_trained_epoch, has_trained_step)
]
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
pangu1_with_grads = PanGu1TrainOneStepWithLossScaleCell(
pangu1_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, config=config)
pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
config=config)

model = Model(pangu1_with_grads)
model = Model(pangu_alpha_with_grads)
print("=====dataset size: ", ds.get_dataset_size(), flush=True)
print("=====actual_epoch_num: ", actual_epoch_num, flush=True)
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
