@@ -34,7 +34,8 @@ from mindspore.context import ParallelMode
from .layers import _LayerNorm, _Linear, _Dropout, _check_input_shape, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config, \
MoEParallelConfig
from .moe import default_moe_config, MoE, _check_moe_config
__all__ = [
@@ -212,6 +213,8 @@ class TransformerOpParallelConfig(_Config):
according to the data parallel way. Default: 1.
        model_parallel (int): The model parallel way. The parameters of dense layers in MultiHeadAttention and
            FeedForward layer will be sliced according to the model parallel way. Default: 1.
        expert_parallel (int): The expert parallel way. This is effective only when MoE (Mixture of Experts)
            is applied. It specifies the number of partitions that the experts are split into. Default: 1.
pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
micro_batch_num (int): The micro size of the batches for the pipeline training. Default: 1.
        optimizer_shard (bool): Whether to enable optimizer shard. Default: False.
@@ -230,7 +233,7 @@ class TransformerOpParallelConfig(_Config):
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
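        >>> # Illustrative sketch (not part of the original example set): when MoE is used, the expert
        >>> # parallel way can be set alongside data/model parallel; the values below are assumptions.
        >>> config_with_moe = TransformerOpParallelConfig(data_parallel=2, model_parallel=1, expert_parallel=2)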
"""
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1,
recompute=default_transformer_recompute_config,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
self.recompute = recompute
@@ -239,6 +242,8 @@ class TransformerOpParallelConfig(_Config):
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
vocab_emb_dp=vocab_emb_dp)
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
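        # Keep an MoEParallelConfig in sync with the dp/mp settings; expert_parallel adds the
        # expert dimension used when MoE layers shard their experts.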
self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel,
expert_parallel=expert_parallel)
@property
def recompute(self):
@@ -284,6 +289,7 @@ class TransformerOpParallelConfig(_Config):
@model_parallel.setter
def model_parallel(self, value):
self._embed_dp_mp_config.model_parallel = value
self._moe_config.model_parallel = value
@property
def data_parallel(self):
@@ -292,6 +298,15 @@ class TransformerOpParallelConfig(_Config):
@data_parallel.setter
def data_parallel(self, value):
self._embed_dp_mp_config.data_parallel = value
self._moe_config.data_parallel = value
@property
def expert_parallel(self):
return self._moe_config.expert_parallel
@expert_parallel.setter
def expert_parallel(self, value):
self._moe_config.expert_parallel = value
@property
def pipeline_stage(self):
@@ -342,6 +357,10 @@ class TransformerOpParallelConfig(_Config):
"""
return self._embed_dp_mp_config.dp_mp_config
@property
def moe_parallel_config(self):
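        r"""
        To obtain the derived MoEParallelConfig holding data_parallel, model_parallel and
        expert_parallel, as used by MoE layers.
        """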
return self._moe_config
default_transformer_config = TransformerOpParallelConfig()
default_embedding_parallel_config = EmbeddingOpParallelConfig()
@@ -370,9 +389,9 @@ class FeedForward(Cell):
            and the first dimension in BatchMatMul indicates expert_num. Default: 1.
        param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
Default: dtype.float32.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with
default args.
        parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see `OpParallelConfig` or
            `MoEParallelConfig`. When MoE is applied, MoEParallelConfig takes effect; otherwise OpParallelConfig
            takes effect. Default: `default_dpmp_config`, an instance of `OpParallelConfig` with default args.
Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
@@ -409,7 +428,7 @@ class FeedForward(Cell):
hidden_act=_valid_type_checks([str], "FeedForward"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig],
                                parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"FeedForward"))
def __init__(self, hidden_size,
ffn_hidden_size,
@@ -422,6 +441,10 @@ class FeedForward(Cell):
_check_config(parallel_config)
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
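        # The expert parallel way only matters when there is more than one expert; otherwise the
        # expert dimension is degenerate and ep is fixed to 1.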
if expert_num > 1:
ep = parallel_config.expert_parallel
else:
ep = 1
if ffn_hidden_size % mp != 0:
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the num of "
"model parallel, but got the ffn_hidden_size is {} and the num of model parallel is {}."
@@ -435,20 +458,20 @@ class FeedForward(Cell):
"but got the value : {}.".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
ep = dp
# Project to ffn_hidden_size
self.mapping = _Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)
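        # Sharding for the first projection. In the MoE case the activation is assumed to be laid out
        # as [data_parallel, expert, tokens, hidden], so the matmul strategy carries both dp and ep,
        # while the expert weights are split along ep and mp.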
if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, 1, mp),))
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)),
strategy_bias=((dp, ep, 1, mp), (mp,)),
strategy_activation=((dp, ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
@@ -458,10 +481,11 @@ class FeedForward(Cell):
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
outer_batch=dp,
param_init_type=param_init_type)
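        # Sharding for the projection back to hidden_size; mp sits on the contracting dimension here.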
if expert_num > 1:
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)),
strategy_bias=((ep, 1, 1), (1,)))
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)),
strategy_bias=((dp, ep, 1, 1), (1,)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
@@ -470,6 +494,8 @@ class FeedForward(Cell):
self.dropout.shard(((dp, 1),))
self.dropout_3d = _Dropout(1 - dropout_rate)
self.dropout_3d.shard(((dp, 1, 1),))
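        # The MoE path produces 4-d activations, so a dedicated dropout with a matching shard
        # strategy is kept alongside the 2-d and 3-d variants.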
self.dropout_4d = _Dropout(1 - dropout_rate)
self.dropout_4d.shard(((dp, ep, 1, 1),))
self.cast = P.Cast()
def construct(self, x):
@@ -482,8 +508,10 @@ class FeedForward(Cell):
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
if len(F.shape(output)) == 3:
output = self.dropout_3d(output)
else:
        elif len(F.shape(output)) == 2:
output = self.dropout(output)
else:
output = self.dropout_4d(output)
return output
@@ -1164,7 +1192,8 @@ class TransformerEncoderLayer(Cell):
pass the single step's input tensor, and loop it. Default False.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
        parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
            MoEParallelConfig takes effect; otherwise OpParallelConfig takes effect. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
@@ -1253,7 +1282,7 @@ class TransformerEncoderLayer(Cell):
"TransformerEncoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerEncoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
                                parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerEncoderLayer"),
use_past=Validator.check_bool)
def __init__(self,
@@ -1287,6 +1316,8 @@ class TransformerEncoderLayer(Cell):
"by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config. model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
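        # Validate the MoE settings against the parallel configuration up front, and record whether
        # this layer runs in MoE mode so the right sub-config can be handed to the attention below.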
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
self.use_past = use_past
self.seq_length = seq_length
self.hidden_size = hidden_size
@@ -1295,20 +1326,30 @@ class TransformerEncoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
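        # The attention sublayer itself is not expert-parallel, so under MoE it only receives the
        # dp/mp sub-config (parallel_config.dpmp).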
if self.use_moe is True:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
tgt_seq_length=seq_length,
hidden_size=hidden_size,
num_heads=num_heads,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
@@ -1480,7 +1521,8 @@ class TransformerDecoderLayer(Cell):
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts). Default is an instance of MoEConfig with
default values. Please see `MoEConfig`.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
        parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
            MoEParallelConfig takes effect; otherwise OpParallelConfig takes effect. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
@@ -1553,7 +1595,7 @@ class TransformerDecoderLayer(Cell):
"TransformerDecoderLayer"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerDecoderLayer"),
parallel_config=_valid_type_checks([OpParallelConfig],
                                parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
"TransformerDecoderLayer"),
use_past=Validator.check_bool)
def __init__(self, hidden_size,
@@ -1588,6 +1630,8 @@ class TransformerDecoderLayer(Cell):
"divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} "
"and parallel_config.model_parallel is {}."
.format(ffn_hidden_size, parallel_config.model_parallel))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if use_past:
raise ValueError(f"The {self.cls_name} does not support use_past=True.")
self.batch_size = batch_size
@@ -1603,35 +1647,59 @@ class TransformerDecoderLayer(Cell):
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
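        # As in the encoder layer, self-attention and cross attention are not expert-parallel; under
        # MoE they receive only the dp/mp sub-config.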
if self.use_moe is True:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=tgt_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
use_past=use_past,
softmax_compute_type=softmax_compute_type,
param_init_type=param_init_type,
parallel_config=parallel_config)
# Cross attention with the output of encoder as memory tensor
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
if self.use_moe is True:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config.dpmp)
else:
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
batch_size=batch_size,
src_seq_length=tgt_seq_length,
tgt_seq_length=src_seq_length,
hidden_dropout_rate=hidden_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
softmax_compute_type=softmax_compute_type,
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
layernorm_compute_type)
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
@@ -2019,21 +2087,38 @@ class TransformerEncoder(Cell):
self.num_layers = num_layers
self.blocks = nn.CellList()
for i in range(num_layers):
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
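            # Each layer gets the MoE parallel configuration when MoE is enabled, otherwise the plain
            # dp/mp sub-config as before.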
if self.use_moe is True:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerEncoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
seq_length=seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
num_heads=num_heads,
hidden_act=hidden_act,
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()
@@ -2214,22 +2299,40 @@ class TransformerDecoder(Cell):
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
for i in range(num_layers):
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
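            # Mirrors the encoder: MoE layers receive parallel_config.moe_parallel_config, others dp_mp_config.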
if self.use_moe is True:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.moe_parallel_config)
else:
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
ffn_hidden_size=ffn_hidden_size,
src_seq_length=src_seq_length,
tgt_seq_length=tgt_seq_length,
attention_dropout_rate=attention_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
num_heads=num_heads,
layernorm_compute_type=layernorm_compute_type,
softmax_compute_type=softmax_compute_type,
hidden_act=hidden_act,
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
lambda_func = _get_lambda_func()