|
|
|
@@ -26,12 +26,13 @@ from mindspore import context |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore.ops import functional as F |
|
|
|
from mindspore.ops.primitive import constexpr |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from mindspore import log as logger |
|
|
|
from .layers import _LayerNorm, _Linear, _check_input_shape, \ |
|
|
|
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \ |
|
|
|
_args_type_validator_check, _valid_type_checks, _valid_value_checks |
|
|
|
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \ |
|
|
|
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, Router |
|
|
|
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
@@ -44,10 +45,37 @@ __all__ = [ |
|
|
|
"TransformerEncoderLayer", |
|
|
|
"TransformerDecoderLayer", |
|
|
|
"Transformer", |
|
|
|
"MoEConfig", |
|
|
|
"TransformerOpParallelConfig", |
|
|
|
"EmbeddingOpParallelConfig"] |
|
|
|
|
|
|
|
|
|
|
|
class MoEConfig: |
|
|
|
r""" |
|
|
|
The configuration of MoE (Mixture of Experts).
|
|
|
|
|
|
|
Args: |
|
|
|
expert_num (int): The number of experts employed. Default: 1.
|
|
|
capacity_factor (float): The factor used to expand expert capacity,
|
|
|
which must be >= 1.0. Default: 1.1.
|
|
|
aux_loss_factor (float): The factor that scales the load balance loss (produced by the
|
|
|
router) before it is added to the entire model loss; it should be < 1.0. Default: 0.05.
|
|
|
num_experts_chosen (int): The number of experts chosen by each token. Default: 1.
|
|
|
noisy_policy (string): The noisy policy used when routing tokens to experts. Default: None.
|
|
|
noisy_epsilon (float): The noise magnitude used when adding noise while routing tokens to experts. Default: 1e-2.
|
|
|
""" |
|
|
|
def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, |
|
|
|
num_experts_chosen=1, noisy_policy=None, noisy_epsilon=1e-2): |
|
|
|
self.expert_num = expert_num |
|
|
|
self.capacity_factor = capacity_factor |
|
|
|
self.aux_loss_factor = aux_loss_factor |
|
|
|
self.num_experts_chosen = num_experts_chosen |
|
|
|
self.noisy_policy = noisy_policy |
|
|
|
self.noisy_epsilon = noisy_epsilon |
|
|
|
|
|
|
|
default_moe_config = MoEConfig() |
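
# A minimal construction sketch using only the arguments documented above; the
# concrete values are illustrative, not recommendations.
example_moe_config = MoEConfig(expert_num=8,
                               capacity_factor=1.1,
                               aux_loss_factor=0.01,
                               num_experts_chosen=1)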
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingOpParallelConfig(_Config): |
|
|
|
r""" |
|
|
|
EmbeddingOpParallelConfig for setting the data parallel or row slice of the embedding table.
|
|
|
@@ -265,6 +293,8 @@ class FeedForward(Cell): |
|
|
|
hidden_act (str): The activation of the internal feedforward layer. Supports 'relu', |
|
|
|
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', |
|
|
|
'hsigmoid', 'logsigmoid' and so on. Default: gelu. |
|
|
|
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used |
|
|
|
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
|
|
|
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16. |
|
|
|
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`. |
|
|
|
Default `default_dpmp_config`, an instance of `OpParallelConfig` with default
|
|
|
@@ -305,6 +335,7 @@ class FeedForward(Cell): |
|
|
|
ffn_hidden_size, |
|
|
|
dropout_rate, |
|
|
|
hidden_act='gelu', |
|
|
|
expert_num=1, |
|
|
|
param_init_type=mstype.float32, |
|
|
|
parallel_config=default_dpmp_config): |
|
|
|
super(FeedForward, self).__init__() |
|
|
|
@@ -320,22 +351,35 @@ class FeedForward(Cell): |
|
|
|
"but got {}".format(dropout_rate)) |
|
|
|
input_size = hidden_size |
|
|
|
output_size = ffn_hidden_size |
|
|
|
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number. |
|
|
|
ep = dp |
|
|
|
# Project to ffn_hidden_size |
|
|
|
self.mapping = _Linear(in_channels=input_size, |
|
|
|
out_channels=output_size, |
|
|
|
activation=hidden_act, |
|
|
|
transpose_b=False, |
|
|
|
expert_num=expert_num, |
|
|
|
param_init_type=param_init_type) |
|
|
|
self.mapping.shard(strategy_bias=((dp, mp), (mp,)), |
|
|
|
strategy_matmul=((dp, 1), (1, mp)), |
|
|
|
strategy_activation=((dp, 1, mp),)) |
|
|
|
if expert_num > 1: |
|
|
|
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)), |
|
|
|
strategy_bias=((ep, 1, mp), (mp,)), |
|
|
|
strategy_activation=((ep, 1, mp),)) |
|
|
|
else: |
|
|
|
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)), |
|
|
|
strategy_bias=((dp, mp), (mp,)), |
|
|
|
strategy_activation=((dp, 1, mp),)) |
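# In the dense branch the (batch*seq_length, hidden_size) activation is split (dp, 1) and the
# (hidden_size, ffn_hidden_size) weight is split (1, mp); in the expert branch both tensors carry
# a leading expert dimension sharded across ep (= dp) devices, while mp still splits the output
# columns of the weight.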
|
|
|
# Project back to embedding_size |
|
|
|
self.projection = _Linear(in_channels=output_size, |
|
|
|
out_channels=input_size, |
|
|
|
transpose_b=False, |
|
|
|
expert_num=expert_num, |
|
|
|
param_init_type=param_init_type) |
|
|
|
self.projection.shard(strategy_bias=((dp, 1), (1,)), |
|
|
|
strategy_matmul=((dp, mp), (mp, 1))) |
|
|
|
if expert_num > 1: |
|
|
|
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)), |
|
|
|
strategy_bias=((ep, 1, 1), (1,))) |
|
|
|
else: |
|
|
|
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)), |
|
|
|
strategy_bias=((dp, 1), (1,))) |
|
|
|
self.projection.bias.parallel_optimizer = False |
|
|
|
self.dropout = nn.Dropout(1 - dropout_rate) |
|
|
|
self.dropout.dropout.shard(((dp, 1, 1),)) |
|
|
|
@@ -352,6 +396,124 @@ class FeedForward(Cell): |
|
|
|
output = self.dropout(output) |
|
|
|
return output |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim): |
|
|
|
return math.ceil(k * tokens_per_device * capacity_factor / expert_dim) |
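
# Worked example of the formula above (illustrative numbers): with k=1 expert chosen per token,
# 1024 tokens per device, capacity_factor=1.1 and 8 experts, each expert can accept at most
# math.ceil(1 * 1024 * 1.1 / 8) = math.ceil(140.8) = 141 tokens.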
|
|
|
|
|
|
|
|
|
|
|
class MoE(Cell): |
|
|
|
""" |
|
|
|
The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer. |
|
|
|
The router dispatches tokens to the experts in FeedForward, FeedForward performs the computation, and the final output is
|
|
|
obtained by weighting FeedForward's output with the router's combine weights.
|
|
|
|
|
|
|
Args: |
|
|
|
hidden_size (int): The dimension of the inputs. |
|
|
|
ffn_hidden_size (int): The intermediate hidden size. |
|
|
|
dropout_rate (float): The dropout rate for the second linear's output. |
|
|
|
hidden_act (str): The activation of the internal feedforward layer. Supports 'relu', |
|
|
|
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', |
|
|
|
'hsigmoid', 'logsigmoid' and so on. Default: gelu. |
|
|
|
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16. |
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
|
|
|
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`. |
|
|
|
Default `default_dpmp_config`, an instance of `OpParallelConfig` with default
|
|
|
args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the output of this layer after mapping with shape `[batch, seq_length, hidden_size]`, together with the auxiliary load balance loss produced by the router.
|
|
|
""" |
|
|
|
def __init__(self, hidden_size, |
|
|
|
ffn_hidden_size, |
|
|
|
dropout_rate, |
|
|
|
hidden_act='gelu', |
|
|
|
param_init_type=mstype.float32, |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_dpmp_config): |
|
|
|
super(MoE, self).__init__() |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.expert_dim = moe_config.expert_num |
|
|
|
self.capacity_factor = moe_config.capacity_factor |
|
|
|
self.aux_loss_factor = moe_config.aux_loss_factor |
|
|
|
self.num_experts_chosen = moe_config.num_experts_chosen |
|
|
|
self.expert_parallel = parallel_config.data_parallel |
|
|
|
self.dp = parallel_config.data_parallel |
|
|
|
|
|
|
|
self.ffn = FeedForward(hidden_size=hidden_size, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
dropout_rate=dropout_rate, |
|
|
|
hidden_act=hidden_act, |
|
|
|
expert_num=self.expert_dim, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.shape = P.Shape() |
|
|
|
self.transpose = P.Transpose().shard(((self.dp, 1, 1),)) |
|
|
|
self.transpose2 = P.Transpose().shard(((self.dp, 1, 1, 1),)) |
|
|
|
self.transpose3 = P.Transpose().shard(((self.dp, 1, 1, 1),)) |
|
|
|
self.transpose4 = P.Transpose().shard(((self.dp, 1, 1),)) |
|
|
|
self.transpose5 = P.Transpose().shard(((self.dp, 1, 1),)) |
|
|
|
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1))) |
|
|
|
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1))) |
|
|
|
self.mul = P.Mul().shard(((), ())) |
|
|
|
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None, |
|
|
|
training=True, parallel_config=parallel_config) |
|
|
|
self.cast = P.Cast() |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, input_tensor): |
|
|
|
bs = self.shape(input_tensor)[0] |
|
|
|
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size)) |
|
|
|
bs_and_dmodel = self.shape(input_tensor) |
|
|
|
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel |
|
|
|
input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size)) |
|
|
|
|
|
|
|
expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device, |
|
|
|
self.capacity_factor, self.expert_dim) |
|
|
|
# dispatch_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity) |
|
|
|
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity) |
|
|
|
dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor) |
|
|
|
|
|
|
|
# after transpose, input_tensor's shape: (self.expert_parallel, self.hidden_size, tokens_per_device) |
|
|
|
input_tensor = self.transpose(input_tensor, (0, 2, 1)) |
|
|
|
dispatch_tensor = self.reshape(dispatch_tensor, (self.expert_parallel, tokens_per_device, |
|
|
|
self.expert_dim * expert_capacity)) |
|
|
|
dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor)) |
|
|
|
# expert_input's shape: (self.expert_parallel, self.hidden_size, self.expert_dim * expert_capacity) |
|
|
|
expert_input = self.batch_mm(input_tensor, dispatch_tensor) |
|
|
|
expert_input = self.reshape(expert_input, (self.expert_parallel, self.hidden_size, self.expert_dim, |
|
|
|
expert_capacity)) |
|
|
|
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size) |
|
|
|
expert_input = self.transpose2(expert_input, (2, 0, 3, 1)) |
|
|
|
expert_input = self.reshape(expert_input, (self.expert_dim, self.expert_parallel * expert_capacity, |
|
|
|
self.hidden_size)) |
|
|
|
|
|
|
|
# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size) |
|
|
|
expert_output = self.ffn(expert_input) |
|
|
|
expert_output = self.reshape(expert_output, (self.expert_dim, self.expert_parallel, |
|
|
|
expert_capacity, self.hidden_size)) |
|
|
|
# expert_output's shape: (self.expert_parallel, self.hidden_size, self.expert_dim, expert_capacity) |
|
|
|
expert_output = self.transpose3(expert_output, (1, 3, 0, 2)) |
|
|
|
expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size, |
|
|
|
self.expert_dim*expert_capacity)) |
|
|
|
combine_tensor = self.reshape(combine_tensor, (self.expert_parallel, tokens_per_device, |
|
|
|
self.expert_dim*expert_capacity)) |
|
|
|
# combine_tensor's shape: (self.expert_parallel, self.expert_dim*expert_capacity, tokens_per_device) |
|
|
|
combine_tensor = self.transpose4(combine_tensor, (0, 2, 1)) |
|
|
|
combine_tensor = self.cast(combine_tensor, F.dtype(expert_output)) |
|
|
|
|
|
|
|
# combined_output's shape: (self.expert_parallel, self.hidden_size, tokens_per_device) |
|
|
|
combined_output = self.batch_mm2(expert_output, combine_tensor) |
|
|
|
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size) |
|
|
|
combined_output = self.transpose5(combined_output, (0, 2, 1)) |
|
|
|
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1])) |
|
|
|
combined_output = self.reshape(combined_output, (bs, -1, self.hidden_size)) |
|
|
|
|
|
|
|
aux_loss = self.mul(self.aux_loss_factor, aux_loss) |
|
|
|
return combined_output, aux_loss |
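
# Shape-bookkeeping sketch (illustrative only, not used by the MoE cell above): the two
# BatchMatMul calls in MoE.construct implement a dispatch/combine contraction over
# flattened (expert, capacity) slots. Random tensors stand in for the router's outputs
# and the expert FFNs are treated as identity maps.
import numpy as np

ep, tokens, hidden, experts, capacity = 2, 8, 4, 2, 5
x = np.random.randn(ep, tokens, hidden)
dispatch = np.random.rand(ep, tokens, experts * capacity)
combine = np.random.rand(ep, tokens, experts * capacity)
# dispatch: (ep, hidden, tokens) @ (ep, tokens, experts*capacity) -> (ep, hidden, experts*capacity)
expert_in = np.matmul(x.transpose(0, 2, 1), dispatch)
expert_out = expert_in  # identity experts keep the shape unchanged
# combine: (ep, hidden, experts*capacity) @ (ep, experts*capacity, tokens) -> (ep, hidden, tokens)
combined = np.matmul(expert_out, combine.transpose(0, 2, 1)).transpose(0, 2, 1)
assert combined.shape == (ep, tokens, hidden)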
|
|
|
|
|
|
|
class AttentionMask(Cell): |
|
|
|
r""" |
|
|
|
@@ -903,6 +1065,7 @@ class TransformerEncoderLayer(Cell): |
|
|
|
param_init_type(dtype.Number): The parameter initialization type of the module. |
|
|
|
Can be dtype.float32 or dtype.float16. Default dtype.float32. |
|
|
|
use_past(bool): Use the past state to compute, used for incremental prediction. Default False. |
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
|
|
|
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, |
|
|
|
an instance of `OpParallelConfig` with default args.
|
|
|
|
|
|
|
@@ -973,6 +1136,7 @@ class TransformerEncoderLayer(Cell): |
|
|
|
param_init_type=mstype.float32, |
|
|
|
hidden_act='gelu', |
|
|
|
use_past=False, |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_dpmp_config): |
|
|
|
super(TransformerEncoderLayer, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
@@ -1000,13 +1164,23 @@ class TransformerEncoderLayer(Cell): |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
parallel_config=parallel_config) |
|
|
|
# Feed Forward Network, FFN |
|
|
|
self.output = FeedForward(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
param_init_type=param_init_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
if self.use_moe is True: |
|
|
|
self.output = MoE(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
param_init_type=param_init_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config) |
|
|
|
else: |
|
|
|
# Feed Forward Network, FFN |
|
|
|
self.output = FeedForward(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
param_init_type=param_init_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
@@ -1056,7 +1230,11 @@ class TransformerEncoderLayer(Cell): |
|
|
|
|
|
|
|
output_x = self.layernorm2(x) |
|
|
|
output_x = F.cast(output_x, self.dtype) |
|
|
|
mlp_logit = self.output(output_x) |
|
|
|
aux_loss = None |
|
|
|
if self.use_moe is True: |
|
|
|
mlp_logit, aux_loss = self.output(output_x) |
|
|
|
else: |
|
|
|
mlp_logit = self.output(output_x) |
|
|
|
|
|
|
|
value_update = None |
|
|
|
key_update = None |
|
|
|
@@ -1078,6 +1256,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
if self.use_moe is True: |
|
|
|
return output, layer_present, aux_loss |
|
|
|
return output, layer_present |
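
# Usage note (hypothetical caller, illustrative names): because construct now has two possible
# output signatures, callers should branch on `use_moe`:
#     if encoder_layer.use_moe:
#         out, layer_present, aux_loss = encoder_layer(x, input_mask)
#     else:
#         out, layer_present = encoder_layer(x, input_mask)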
|
|
|
|
|
|
|
def _check_input(self, x, input_mask, init_reset, batch_valid_length): |
|
|
|
@@ -1123,6 +1303,7 @@ class TransformerDecoderLayer(Cell): |
|
|
|
param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16. |
|
|
|
Default dtype.float32. |
|
|
|
use_past(bool): Use the past state to compute, used for incremental prediction. Default False. |
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
|
|
|
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`, |
|
|
|
an instance of `OpParallelConfig` with default args.
|
|
|
|
|
|
|
@@ -1204,6 +1385,7 @@ class TransformerDecoderLayer(Cell): |
|
|
|
softmax_comptue_type=mstype.float32, |
|
|
|
param_init_type=mstype.float32, |
|
|
|
hidden_act='gelu', |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_dpmp_config): |
|
|
|
super(TransformerDecoderLayer, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
@@ -1247,14 +1429,23 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float( |
|
|
|
layernorm_compute_type) |
|
|
|
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
|
|
|
|
# Feed Forward Network, FFN |
|
|
|
self.output = FeedForward(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
hidden_act=hidden_act, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
if self.use_moe is True: |
|
|
|
self.output = MoE(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
param_init_type=param_init_type, |
|
|
|
hidden_act=hidden_act, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config) |
|
|
|
else: |
|
|
|
# Feed Forward Network, FFN |
|
|
|
self.output = FeedForward(hidden_size=hidden_size, |
|
|
|
dropout_rate=hidden_dropout_rate, |
|
|
|
ffn_hidden_size=ffn_hidden_size, |
|
|
|
hidden_act=hidden_act, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
@@ -1321,7 +1512,11 @@ class TransformerDecoderLayer(Cell): |
|
|
|
|
|
|
|
output_x = self.layernorm2(x) |
|
|
|
output_x = F.cast(output_x, self.dtype) |
|
|
|
mlp_logit = self.output(output_x) |
|
|
|
aux_loss = None |
|
|
|
if self.use_moe is True: |
|
|
|
mlp_logit, aux_loss = self.output(output_x) |
|
|
|
else: |
|
|
|
mlp_logit = self.output(output_x) |
|
|
|
|
|
|
|
value_update = None |
|
|
|
key_update = None |
|
|
|
@@ -1343,6 +1538,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
if self.use_moe is True: |
|
|
|
return output, layer_present, aux_loss |
|
|
|
return output, layer_present |
|
|
|
|
|
|
|
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length): |
|
|
|
@@ -1447,6 +1644,7 @@ class TransformerEncoder(Cell): |
|
|
|
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. |
|
|
|
offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id so that



they do not overlap with the encoder layers.
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
|
|
|
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`, |
|
|
|
an instance of `TransformerOpParallelConfig` with default args.
|
|
|
|
|
|
|
@@ -1523,10 +1721,14 @@ class TransformerEncoder(Cell): |
|
|
|
lambda_func=None, |
|
|
|
offset=0, |
|
|
|
use_past=False, |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_transformer_config): |
|
|
|
super(TransformerEncoder, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
|
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
self.add = P.TensorAdd().shard(((), ())) |
|
|
|
self.aux_loss = Tensor(0.0, mstype.float32) |
|
|
|
self.num_layers = num_layers |
|
|
|
self.blocks = nn.CellList() |
|
|
|
for i in range(num_layers): |
|
|
|
@@ -1543,6 +1745,7 @@ class TransformerEncoder(Cell): |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
param_init_type=param_init_type, |
|
|
|
use_past=use_past, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
# If the user doesn't pass the fusion function, use the default one |
|
|
|
if not lambda_func: |
|
|
|
@@ -1554,6 +1757,17 @@ class TransformerEncoder(Cell): |
|
|
|
|
|
|
|
def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None): |
|
|
|
present_layer = () |
|
|
|
if self.use_moe is True: |
|
|
|
accum_loss = self.aux_loss |
|
|
|
for i in range(self.num_layers): |
|
|
|
hidden_states, present, aux_loss = self.blocks[i](hidden_states, |
|
|
|
attention_mask, |
|
|
|
init_reset, |
|
|
|
batch_valid_length) |
|
|
|
present_layer = present_layer + (present,) |
|
|
|
accum_loss = self.add(accum_loss, aux_loss) |
|
|
|
return hidden_states, present_layer, accum_loss |
|
|
|
|
|
|
|
for i in range(self.num_layers): |
|
|
|
hidden_states, present = self.blocks[i](hidden_states, |
|
|
|
attention_mask, |
|
|
|
@@ -1597,6 +1811,7 @@ class TransformerDecoder(Cell): |
|
|
|
zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The
|
|
|
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`. |
|
|
|
Default: None |
|
|
|
moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
|
|
|
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`, |
|
|
|
an instance of `TransformerOpParallelConfig` with default args.
|
|
|
|
|
|
|
@@ -1686,12 +1901,16 @@ class TransformerDecoder(Cell): |
|
|
|
lambda_func=None, |
|
|
|
use_past=False, |
|
|
|
offset=0, |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_transformer_config): |
|
|
|
super(TransformerDecoder, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
|
|
|
|
self.add = P.TensorAdd().shard(((), ())) |
|
|
|
self.aux_loss = Tensor(0.0, mstype.float32) |
|
|
|
self.num_layers = num_layers |
|
|
|
self.blocks = nn.CellList() |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
for i in range(num_layers): |
|
|
|
block = TransformerDecoderLayer(hidden_size=hidden_size, |
|
|
|
batch_size=batch_size, |
|
|
|
@@ -1707,6 +1926,7 @@ class TransformerDecoder(Cell): |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
post_layernorm_residual=post_layernorm_residual, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config.dp_mp_config) |
|
|
|
# If the user doesn't pass the fusion function, use the default one |
|
|
|
if not lambda_func: |
|
|
|
@@ -1720,6 +1940,19 @@ class TransformerDecoder(Cell): |
|
|
|
def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None, |
|
|
|
init_reset=True, batch_valid_length=None): |
|
|
|
present_layer = () |
|
|
|
if self.use_moe is True: |
|
|
|
accum_loss = self.aux_loss |
|
|
|
for i in range(self.num_layers): |
|
|
|
hidden_states, present, aux_loss = self.blocks[i](hidden_states, |
|
|
|
attention_mask, |
|
|
|
encoder_output, |
|
|
|
memory_mask, |
|
|
|
init_reset, |
|
|
|
batch_valid_length) |
|
|
|
present_layer = present_layer + (present,) |
|
|
|
accum_loss = self.add(accum_loss, aux_loss) |
|
|
|
return hidden_states, present_layer, accum_loss |
|
|
|
|
|
|
|
# Loop through each self-attention layer |
|
|
|
for i in range(self.num_layers): |
|
|
|
hidden_states, present = self.blocks[i](hidden_states, |
|
|
|
@@ -1871,6 +2104,7 @@ class Transformer(Cell): |
|
|
|
param_init_type=mstype.float32, |
|
|
|
lambda_func=None, |
|
|
|
use_past=False, |
|
|
|
moe_config=default_moe_config, |
|
|
|
parallel_config=default_transformer_config): |
|
|
|
super(Transformer, self).__init__() |
|
|
|
_check_config(parallel_config) |
|
|
|
@@ -1888,6 +2122,9 @@ class Transformer(Cell): |
|
|
|
if not lambda_func: |
|
|
|
lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers) |
|
|
|
|
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
self.add = P.TensorAdd().shard(((), ())) |
|
|
|
self.aux_loss = Tensor(0.0, mstype.float32) |
|
|
|
if encoder_layers > 0: |
|
|
|
self.encoder = TransformerEncoder(num_layers=encoder_layers, |
|
|
|
batch_size=batch_size, |
|
|
|
@@ -1904,6 +2141,7 @@ class Transformer(Cell): |
|
|
|
param_init_type=param_init_type, |
|
|
|
lambda_func=lambda_func, |
|
|
|
use_past=use_past, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config) |
|
|
|
else: |
|
|
|
self.encoder = None |
|
|
|
@@ -1929,6 +2167,7 @@ class Transformer(Cell): |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
offset=encoder_layers, |
|
|
|
moe_config=moe_config, |
|
|
|
parallel_config=parallel_config) |
|
|
|
|
|
|
|
def construct(self, encoder_inputs, |
|
|
|
@@ -1943,17 +2182,31 @@ class Transformer(Cell): |
|
|
|
output = None |
|
|
|
encoder_layer_present = None |
|
|
|
decoder_layer_present = None |
|
|
|
accum_loss = self.aux_loss |
|
|
|
if self.encoder is not None: |
|
|
|
encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset, |
|
|
|
batch_valid_length) |
|
|
|
if self.use_moe is True: |
|
|
|
encoder_output, encoder_layer_present, encoder_aux_loss = self.encoder(encoder_inputs, encoder_masks, |
|
|
|
init_reset, batch_valid_length) |
|
|
|
accum_loss = self.add(accum_loss, encoder_aux_loss) |
|
|
|
else: |
|
|
|
encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset, |
|
|
|
batch_valid_length) |
|
|
|
output = encoder_output |
|
|
|
|
|
|
|
if self.decoder is not None: |
|
|
|
# decoder mask can be created outside of the model |
|
|
|
decoder_output, decoder_layer_present = self.decoder(decoder_inputs, |
|
|
|
decoder_masks, |
|
|
|
encoder_output, |
|
|
|
memory_mask, init_reset, |
|
|
|
batch_valid_length) |
|
|
|
if self.use_moe is True: |
|
|
|
decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks, |
|
|
|
encoder_output, memory_mask, |
|
|
|
init_reset, batch_valid_length) |
|
|
|
accum_loss = self.add(accum_loss, decoder_aux_loss) |
|
|
|
else: |
|
|
|
decoder_output, decoder_layer_present = self.decoder(decoder_inputs, |
|
|
|
decoder_masks, |
|
|
|
encoder_output, |
|
|
|
memory_mask, init_reset, |
|
|
|
batch_valid_length) |
|
|
|
output = decoder_output |
|
|
|
if self.use_moe is True: |
|
|
|
return output, encoder_layer_present, decoder_layer_present, accum_loss |
|
|
|
return output, encoder_layer_present, decoder_layer_present |
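
# Usage sketch (hypothetical names; only the four-element return when expert_num > 1 comes from
# the code above): a Transformer built with moe_config.expert_num > 1 also returns the
# accumulated auxiliary load balance loss, which is typically added to the task loss.
#     output, enc_present, dec_present, moe_aux_loss = model(encoder_inputs, encoder_masks,
#                                                            decoder_inputs, decoder_masks,
#                                                            memory_mask)
#     total_loss = task_loss + moe_aux_loss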