|
|
|
@@ -34,7 +34,8 @@ from mindspore.context import ParallelMode |
|
|
|
from .layers import _LayerNorm, _Linear, _Dropout, _check_input_shape, \ |
|
|
|
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \ |
|
|
|
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value |
|
|
|
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config |
|
|
|
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config, \ |
|
|
|
MoEParallelConfig |
|
|
|
from .moe import default_moe_config, MoE, _check_moe_config |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -212,6 +213,8 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
according to the data parallel way. Default: 1. |
|
|
|
model_parallel (int): The model parallel way. The parameters of dense layers in MultiheadAttention and |
|
|
|
FeedForward layer will be sliced according to the model parallel way. Default: 1. |
|
|
|
expert_parallel (int): The expert parallel way. This is effective only when MoE (Mixture of Experts) is applied.
This value specifies the number of partitions that the experts are split into. Default: 1.
|
|
|
pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1. |
|
|
|
micro_batch_num (int): The micro size of the batches for the pipeline training. Default: 1. |
|
|
|
optimizer_shard (bool): Whether to enable optimizer sharding. Default: False.
|
|
|
@@ -230,7 +233,7 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, |
|
|
|
def __init__(self, data_parallel=1, model_parallel=1, expert_parallel=1, pipeline_stage=1, micro_batch_num=1, |
|
|
|
recompute=default_transformer_recompute_config, |
|
|
|
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True): |
|
|
|
self.recompute = recompute |
|
|
|
@@ -239,6 +242,8 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel, |
|
|
|
vocab_emb_dp=vocab_emb_dp) |
|
|
|
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num) |
|
|
|
self._moe_config = MoEParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel, |
|
|
|
expert_parallel=expert_parallel) |
|
|
|
|
|
|
|
@property |
|
|
|
def recompute(self): |
|
|
|
@@ -284,6 +289,7 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
@model_parallel.setter |
|
|
|
def model_parallel(self, value): |
|
|
|
self._embed_dp_mp_config.model_parallel = value |
|
|
|
self._moe_config.model_parallel = value |
|
|
|
|
|
|
|
@property |
|
|
|
def data_parallel(self): |
|
|
|
@@ -292,6 +298,15 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
@data_parallel.setter |
|
|
|
def data_parallel(self, value): |
|
|
|
self._embed_dp_mp_config.data_parallel = value |
|
|
|
self._moe_config.data_parallel = value |
|
|
|
|
|
|
|
@property |
|
|
|
def expert_parallel(self): |
|
|
|
return self._moe_config.expert_parallel |
|
|
|
|
|
|
|
@expert_parallel.setter |
|
|
|
def expert_parallel(self, value): |
|
|
|
self._moe_config.expert_parallel = value |
|
|
|
|
|
|
|
@property |
|
|
|
def pipeline_stage(self): |
|
|
|
@@ -342,6 +357,10 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
""" |
|
|
|
return self._embed_dp_mp_config.dp_mp_config |
|
|
|
|
|
|
|
@property |
|
|
|
def moe_parallel_config(self): |
|
|
|
return self._moe_config |
|
|
|
|
|
|
|
|
|
|
|
default_transformer_config = TransformerOpParallelConfig() |
|
|
|
default_embedding_parallel_config = EmbeddingOpParallelConfig() |
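A quick illustration of the field added above: once `expert_parallel` is threaded through the constructor and exposed via the new property and `moe_parallel_config`, it can be read back directly. This is a minimal usage sketch assuming the rest of `TransformerOpParallelConfig` is unchanged; the parallel sizes are placeholders only.

>>> config = TransformerOpParallelConfig(data_parallel=2, model_parallel=1, expert_parallel=2)
>>> config.expert_parallel
2
>>> config.moe_parallel_config.expert_parallel
2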
|
|
|
@@ -370,9 +389,9 @@ class FeedForward(Cell): |
|
|
|
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
|
|
|
param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
Default: dtype.float32.
|
|
|
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`. |
|
|
|
Default `default_dpmp_config`, an instance of `OpParallelConfig` with |
|
|
|
default args. |
|
|
|
parallel_config (OpParallelConfig, MoEParallelConfig): The config of parallel setting, see `OpParallelConfig` or
`MoEParallelConfig`. When MoE is applied, `MoEParallelConfig` takes effect; otherwise `OpParallelConfig` takes
effect. Default: `default_dpmp_config`, an instance of `OpParallelConfig` with default args.
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`. |
|
|
|
@@ -409,7 +428,7 @@ class FeedForward(Cell): |
|
|
|
hidden_act=_valid_type_checks([str], "FeedForward"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"FeedForward"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig], |
|
|
|
"FeedForward")) |
|
|
|
def __init__(self, hidden_size, |
|
|
|
ffn_hidden_size, |
|
|
|
@@ -422,6 +441,10 @@ class FeedForward(Cell): |
|
|
|
_check_config(parallel_config) |
|
|
|
dp = parallel_config.data_parallel |
|
|
|
mp = parallel_config.model_parallel |
|
|
|
if expert_num > 1: |
|
|
|
ep = parallel_config.expert_parallel |
|
|
|
else: |
|
|
|
ep = 1 |
|
|
|
if ffn_hidden_size % mp != 0: |
|
|
|
raise ValueError("For 'FeedForward', the class variable 'ffn_hidden_size' must be a multiple of the num of " |
|
|
|
"model parallel, but got the ffn_hidden_size is {} and the num of model parallel is {}." |
|
|
|
@@ -435,20 +458,20 @@ class FeedForward(Cell): |
|
|
|
"but got the value : {}.".format(dropout_rate)) |
|
|
|
input_size = hidden_size |
|
|
|
output_size = ffn_hidden_size |
|
|
|
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number. |
|
|
|
ep = dp |
|
|
|
|
|
|
|
# Project to ffn_hidden_size |
|
|
|
self.mapping = _Linear(in_channels=input_size, |
|
|
|
out_channels=output_size, |
|
|
|
activation=hidden_act, |
|
|
|
transpose_b=False, |
|
|
|
expert_num=expert_num, |
|
|
|
outer_batch=dp, |
|
|
|
param_init_type=param_init_type) |
|
|
|
|
|
|
|
if expert_num > 1: |
|
|
|
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)), |
|
|
|
strategy_bias=((ep, 1, mp), (mp,)), |
|
|
|
strategy_activation=((ep, 1, mp),)) |
|
|
|
self.mapping.shard(strategy_matmul=((dp, ep, 1, 1), (ep, 1, mp)), |
|
|
|
strategy_bias=((dp, ep, 1, mp), (mp,)), |
|
|
|
strategy_activation=((dp, ep, 1, mp),)) |
|
|
|
else: |
|
|
|
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)), |
|
|
|
strategy_bias=((dp, mp), (mp,)), |
|
|
|
@@ -458,10 +481,11 @@ class FeedForward(Cell): |
|
|
|
out_channels=input_size, |
|
|
|
transpose_b=False, |
|
|
|
expert_num=expert_num, |
|
|
|
outer_batch=dp, |
|
|
|
param_init_type=param_init_type) |
|
|
|
if expert_num > 1: |
|
|
|
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)), |
|
|
|
strategy_bias=((ep, 1, 1), (1,))) |
|
|
|
self.projection.shard(strategy_matmul=((dp, ep, 1, mp), (ep, mp, 1)), |
|
|
|
strategy_bias=((dp, ep, 1, 1), (1,))) |
|
|
|
else: |
|
|
|
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)), |
|
|
|
strategy_bias=((dp, 1), (1,))) |
|
|
|
@@ -470,6 +494,8 @@ class FeedForward(Cell): |
|
|
|
self.dropout.shard(((dp, 1),)) |
|
|
|
self.dropout_3d = _Dropout(1 - dropout_rate) |
|
|
|
self.dropout_3d.shard(((dp, 1, 1),)) |
|
|
|
self.dropout_4d = _Dropout(1 - dropout_rate) |
|
|
|
self.dropout_4d.shard(((dp, ep, 1, 1),)) |
|
|
|
self.cast = P.Cast() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
@@ -482,8 +508,10 @@ class FeedForward(Cell): |
|
|
|
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] |
|
|
|
if len(F.shape(output)) == 3: |
|
|
|
output = self.dropout_3d(output) |
|
|
|
else: |
|
|
|
elif len(F.shape(output)) == 2: |
|
|
|
output = self.dropout(output) |
|
|
|
else: |
|
|
|
output = self.dropout_4d(output) |
|
|
|
return output |
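In the MoE branch (`expert_num > 1`), the shard strategies above now split the first two dimensions of the 4-D activations by `dp` and `ep` respectively, where `ep` comes from `parallel_config.expert_parallel` instead of being tied to `dp`, and the matching `dropout_4d` covers that layout. A minimal construction sketch under that assumption; the sizes and the `moe_cfg` name are placeholders rather than values from this patch, and a semi-auto-parallel context would normally be configured before the shard strategies take effect.

>>> moe_cfg = MoEParallelConfig(data_parallel=2, model_parallel=1, expert_parallel=2)
>>> ffn = FeedForward(hidden_size=16, ffn_hidden_size=64, dropout_rate=0.1,
...                   expert_num=4, parallel_config=moe_cfg)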
|
|
|
|
|
|
|
|
|
|
|
@@ -1164,7 +1192,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
pass the single step's input tensor, and loop it. Default: False.
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with |
|
|
|
default values. Please see `MoEConfig`. |
|
|
|
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, |
|
|
|
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
`MoEParallelConfig` takes effect; otherwise `OpParallelConfig` takes effect. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -1253,7 +1282,7 @@ class TransformerEncoderLayer(Cell): |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig], |
|
|
|
"TransformerEncoderLayer"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, |
|
|
|
@@ -1287,6 +1316,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
"by the 'parallel_config.model_parallel', but got the ffn_hidden_size is {} " |
|
|
|
"and parallel_config. model_parallel is {}." |
|
|
|
.format(ffn_hidden_size, parallel_config.model_parallel)) |
|
|
|
_check_moe_config(moe_config, parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
self.use_past = use_past |
|
|
|
self.seq_length = seq_length |
|
|
|
self.hidden_size = hidden_size |
|
|
|
@@ -1295,20 +1326,30 @@ class TransformerEncoderLayer(Cell): |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
|
|
|
|
self.attention = MultiHeadAttention(batch_size=batch_size, |
|
|
|
src_seq_length=seq_length, |
|
|
|
tgt_seq_length=seq_length, |
|
|
|
hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
parallel_config=parallel_config) |
|
|
|
_check_moe_config(moe_config, parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
if self.use_moe:
|
|
|
self.attention = MultiHeadAttention(batch_size=batch_size, |
|
|
|
src_seq_length=seq_length, |
|
|
|
tgt_seq_length=seq_length, |
|
|
|
hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
parallel_config=parallel_config.dpmp) |
|
|
|
else: |
|
|
|
self.attention = MultiHeadAttention(batch_size=batch_size, |
|
|
|
src_seq_length=seq_length, |
|
|
|
tgt_seq_length=seq_length, |
|
|
|
hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
parallel_config=parallel_config) |
|
|
|
if self.use_moe: |
|
|
|
self.output = MoE(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
@@ -1480,7 +1521,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
'hsigmoid', 'logsigmoid' and so on. Default: gelu. |
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with |
|
|
|
default values. Please see `MoEConfig`. |
|
|
|
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, |
|
|
|
parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configuration. When MoE is applied,
`MoEParallelConfig` takes effect; otherwise `OpParallelConfig` takes effect. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
|
|
|
|
|
|
|
Inputs: |
|
|
|
@@ -1553,7 +1595,7 @@ class TransformerDecoderLayer(Cell): |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
param_init_type=_valid_value_checks([mstype.float32, mstype.float16], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig], |
|
|
|
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig], |
|
|
|
"TransformerDecoderLayer"), |
|
|
|
use_past=Validator.check_bool) |
|
|
|
def __init__(self, hidden_size, |
|
|
|
@@ -1588,6 +1630,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
"divisibled by 'parallel_config.model_parallel', but got the ffn_hidden_size is {} " |
|
|
|
"and parallel_config.model_parallel is {}." |
|
|
|
.format(ffn_hidden_size, parallel_config.model_parallel)) |
|
|
|
_check_moe_config(moe_config, parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
if use_past: |
|
|
|
raise ValueError(f"The {self.cls_name} does not support use_past=True.") |
|
|
|
self.batch_size = batch_size |
|
|
|
@@ -1603,35 +1647,59 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
|
|
|
|
self.attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
use_past=use_past, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
if self.use_moe:
|
|
|
self.attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
use_past=use_past, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config.dpmp) |
|
|
|
else: |
|
|
|
self.attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
use_past=use_past, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
# Cross attention with the output of encoder as memory tensor |
|
|
|
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=src_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
if self.use_moe:
|
|
|
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=src_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config.dpmp) |
|
|
|
else: |
|
|
|
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=src_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float( |
|
|
|
layernorm_compute_type) |
|
|
|
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
_check_moe_config(moe_config, parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
|
|
|
|
if self.use_moe: |
|
|
|
self.output = MoE(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
@@ -2019,21 +2087,38 @@ class TransformerEncoder(Cell): |
|
|
|
self.num_layers = num_layers |
|
|
|
self.blocks = nn.CellList() |
|
|
|
for i in range(num_layers): |
|
|
|
block = TransformerEncoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_act=hidden_act, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
if self.use_moe:
|
|
|
block = TransformerEncoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_act=hidden_act, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.moe_parallel_config) |
|
|
|
else: |
|
|
|
block = TransformerEncoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
num_heads=num_heads, |
|
|
|
hidden_act=hidden_act, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
# If the user doesn't pass the fusion function, use the default one |
|
|
|
if not lambda_func: |
|
|
|
lambda_func = _get_lambda_func() |
|
|
|
@@ -2214,22 +2299,40 @@ class TransformerDecoder(Cell): |
|
|
|
_check_moe_config(moe_config, parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
for i in range(num_layers): |
|
|
|
block = TransformerDecoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
src_seq_length=src_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
num_heads=num_heads, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
if self.use_moe:
|
|
|
block = TransformerDecoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
src_seq_length=src_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
num_heads=num_heads, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.moe_parallel_config) |
|
|
|
else: |
|
|
|
block = TransformerDecoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
src_seq_length=src_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
num_heads=num_heads, |
|
|
|
layernorm_compute_type=layernorm_compute_type, |
|
|
|
softmax_compute_type=softmax_compute_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
# If the user doesn't pass the fusion function, use the default one |
|
|
|
if not lambda_func: |
|
|
|
lambda_func = _get_lambda_func() |
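Taken together with the layer changes, MoE is switched on purely by `moe_config.expert_num`, and the encoder/decoder stacks then hand `parallel_config.moe_parallel_config` (rather than `dp_mp_config`) to every layer. A hedged end-to-end sketch follows; the `MoEConfig` argument list and the sizes below are assumptions for illustration, and `MoEConfig` would need to be imported from the companion `moe` module.

>>> moe_cfg = MoEConfig(expert_num=4)
>>> parallel_cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=1, expert_parallel=2)
>>> encoder = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=16, ffn_hidden_size=64,
...                              seq_length=32, num_heads=2, moe_config=moe_cfg,
...                              parallel_config=parallel_cfg)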
|
|
|
|