Browse Source

!22434 [MoE] Adding a MoE implementation

Merge pull request !22434 from Xiaoda/84-add-moe-layer
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
e37b7374ae
3 changed files with 655 additions and 44 deletions
  1. +276
    -13
      mindspore/parallel/nn/layers.py
  2. +284
    -31
      mindspore/parallel/nn/transformer.py
  3. +95
    -0
      tests/ut/python/parallel/test_parallel_moe.py

+ 276
- 13
mindspore/parallel/nn/layers.py View File

@@ -28,9 +28,10 @@ from mindspore._extends import cell_attr_register
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
import mindspore.nn as nn
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr
from mindspore.ops.primitive import constexpr, Primitive
from .op_parallel_config import default_dpmp_config, OpParallelConfig

__all__ = [
@@ -199,7 +200,7 @@ class _LayerNorm(Cell):
return self


class _Linear(Dense):
class _Linear(Cell):
r"""
The dense connected layer. Once the parallel mode is enabled, the input shape should be
3-D tensor.
@@ -224,6 +225,8 @@ class _Linear(Dense):
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): activate function applied to the output of the fully connected layer,
eg. 'ReLU'.Default: None.
expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
used and the first dimension in BatchMatMul indicate expert_num. Default: 1.
compute_dtype (mstype): The computation type. Default: mstype.float16
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@@ -254,14 +257,12 @@ class _Linear(Dense):
has_bias=True,
activation=None,
transpose_b=True,
expert_num=1,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(_Linear, self).__init__(in_channels=in_channels,
out_channels=out_channels,
weight_init=weight_init,
bias_init=bias_init,
has_bias=has_bias,
activation=activation)
super(_Linear, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
if param_init_type not in [mstype.float32, mstype.float16]:
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
if activation and not isinstance(activation, str):
@@ -274,26 +275,40 @@ class _Linear(Dense):
weight_shape = [out_channels, in_channels]
else:
weight_shape = [in_channels, out_channels]
self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
self.matmul = P.MatMul(transpose_b=transpose_b)
self.expert_num = expert_num
if self.expert_num > 1:
self.expert_flag = True
self.weight = Parameter(initializer(weight_init, [self.expert_num] + weight_shape), name="weight")
self.matmul = P.BatchMatMul(transpose_b=transpose_b)
else:
self.expert_flag = False
self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
self.matmul = P.MatMul(transpose_b=transpose_b)
self.bias = None
self.has_bias = has_bias
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
self.bias_add = P.BiasAdd()
self.bias_add = P.TensorAdd()
self.act_name = activation
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
self.activation_flag = self.activation is not None
self.dtype = compute_dtype
self.cast = P.Cast()
self.has_bias = self.has_bias

def construct(self, x):
out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
x = P.Reshape()(x, (-1, self.in_channels))
if self.expert_flag is True:
x = P.Reshape()(x, (self.expert_num, -1, self.in_channels))
weight = self.cast(self.weight, self.dtype)
x = self.matmul(x, weight)
x = self.bias_add(x, self.cast(self.bias, self.dtype))
if self.has_bias:
x = self.bias_add(x, self.cast(self.bias, self.dtype))
output = P.Reshape()(x, out_shape)
if self.activation_flag:
output = self.activation(output)
@@ -549,3 +564,251 @@ class FixedSparseAttention(nn.Cell):
(-1, self.seq_length, self.size_per_head * self.num_heads))

return attention_merge


class _CumSum(Cell):
r"""
A layer used to calculate cumulative summation of a tensor along a dimension.

Inputs:
- **expert_mask** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
expert\_dim)`.

Outputs:
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
"""
def __init__(self, config):
super(_CumSum, self).__init__()
dp = config.data_parallel
self.range = P.Range().shard(((1,),))
self.reshape = P.Reshape()
self.matmul = P.MatMul().shard(((dp, 1), (1, 1)))
self.shape = P.Shape()
self.cast = P.Cast()

self.transpose = P.Transpose().shard(((dp, 1, 1),))
self.transpose2 = P.Transpose().shard(((1, 1),))
self.transpose3 = P.Transpose().shard(((dp, 1, 1),))
self.expand = P.ExpandDims().shard(((1,),))
self.greater = P.Greater().shard(((1, 1), (1, 1)))

self.start = Tensor(0, mstype.int32)
self.limit = Tensor(0, mstype.int32)
self.delta = Tensor(1, mstype.int32)
self.add = P.TensorAdd().shard(((1,), ()))


def construct(self, expert_mask):
# origin_shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
origin_shape = self.shape(expert_mask)
tokens_per_device = origin_shape[1]
# expert_mask_trans's shape: (self.expert_parallel, self.expert_dim, tokens_per_device)
expert_mask_trans = self.transpose(expert_mask, (0, 2, 1))
# expert_mask_reshaped's shape: (self.expert_parallel*self.expert_dim, tokens_per_device)
expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device))

one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0)
other_dim = self.transpose2(one_dim, (1, 0))
# up_tri_matrix's shape: (tokens_per_device, tokens_per_device)
up_tri_matrix = self.greater(one_dim, other_dim)
up_tri_matrix = self.cast(up_tri_matrix, mstype.float32)

# cum_sum's shape: (self.expert_parallel*self.expert_dim, tokens_per_device)
cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix)
# cum_sum's shape: (self.expert_parallel, self.expert_dim, tokens_per_device)
cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device))
# cum_sum's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
cum_sum = self.transpose3(cum_sum, (0, 2, 1))
return cum_sum


@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)


class Router(Cell):
r"""
A router backbone used to calculate logits of each token, which should be cascaded by router implementations
mapping tokens to experts.

Args:
d_model (int): The hidden size of each token.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
routing_policy: The policy of mapping tokens to experts. Default: SwitchRouter
training (bool): The value indicating whether is in training phase.
parallel_config: The parallel-related configuration.
Inputs:
- **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
hidden\_size)`.

Outputs:
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
"""
def __init__(self,
d_model,
moe_config,
routing_policy=None,
training=True,
parallel_config=None):
super(Router, self).__init__()
dp = parallel_config.data_parallel
self.d_model = d_model
self.expert_dim = moe_config.expert_num
self.capacity_factor = moe_config.capacity_factor
self.training = training
self.routing_policy = routing_policy
self.noisy_policy = moe_config.noisy_policy # candidate: ["jitter", "rsample", "None"]
self.noisy_epsilon = moe_config.noisy_epsilon
self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))

self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
self.dense.matmul.shard(((dp, 1), (1, 1)))
self.mul = P.Mul().shard(((dp, 1, 1), (dp,)))
self.cast = P.Cast()

if self.routing_policy is None:
self.router = SwitchRouter(d_model=d_model, moe_config=moe_config, training=training,
parallel_config=parallel_config)
else:
self.router = routing_policy

def construct(self, input_tensor):
input_tensor = self.cast(input_tensor, mstype.float32)
if self.noisy_policy == "jitter" and self.training is True:
# Here, we temporarily implement the multiplicative jitter this way,
# for the lack of UniforReal parallel operator.
input_tensor = self.mul(input_tensor, self.noise)

router_logits = self.dense(input_tensor)
return self.router(router_logits)


class SwitchRouter(Cell):
r"""
A router implementation which maps each tokens to the top1 expert.
Reference: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

Args:
d_model (int): The hidden size of each token.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
training (bool): The value indicating whether is in training phase.
config: The parallel-related configuration.
Inputs:
- **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
hidden\_size)`.

Outputs:
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
Tensor of shape :math:`(1)`.
"""
def __init__(self,
d_model,
moe_config,
training=True,
parallel_config=None):
super(SwitchRouter, self).__init__()
dp = parallel_config.data_parallel
self.d_model = d_model
self.expert_dim = moe_config.expert_num
self.capacity_factor = moe_config.capacity_factor
self.training = training
self.expert_parallel = dp
self.noisy_policy = moe_config.noisy_policy
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))

self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)

self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
self.mul2 = P.Mul().shard(((1,), ()))
self.mul3 = P.Mul().shard(((1,), ()))
self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))

self.cumsum = _CumSum(config=parallel_config)
self.less = P.Less().shard(((dp, 1, 1), ()))
self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
self.expand = P.ExpandDims().shard(((dp, 1),))
self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))

def _auxiliary_loss(self, expert_mask, router_prob):
"""
Computing the load balance loss.
"""
# density_1's shape: (self.expert_parallel, self.expert_dim)
density_1 = self.reduce_mean(expert_mask, 1)
# density_1_proxy's shape: (self.expert_parallel, self.expert_dim)
density_1_proxy = self.reduce_mean2(router_prob, 1)
loss = self.mul(density_1, density_1_proxy)
loss = self.reduce_mean3(loss)
loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
return loss

def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate):
"""
Keeping only the tokens that fit within expert_capacity.
"""
cumsum = self.cumsum(expert_mask)
# position_in_expert's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
position_in_expert = self.mul4(cumsum, expert_mask)
less_result = self.less(position_in_expert, expert_capacity)
# expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
expert_mask = self.mul5(less_result, expert_mask)
# expert_mask_flat's shape: (self.expert_parallel, tokens_per_device)
expert_mask_flat = self.reduce_sum(expert_mask, -1)

# Mask out the experts that have overflowed the expert_capacity.
# expert_gate's shape: (self.expert_parallel, tokens_per_device)
expert_gate = self.mul6(expert_gate, expert_mask_flat)
return expert_gate, expert_mask_flat, position_in_expert


def construct(self, router_logits):
router_logits_shape = self.shape(router_logits)
router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
logits_shape = self.shape(router_logits)
tokens_per_device = logits_shape[0] / self.expert_parallel
expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim)
router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim))
# Currently, lack of gumbel sampler for router_logits.

# Probabilities for each token of what expert is should be sent to
router_prob = self.softmax(router_logits)
# shape: (self.expert_parallel, tokens_per_device)
expert_index, expert_gate = self.argmax(router_prob)
# expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)

# Computing the load balance loss:
loss = self._auxiliary_loss(expert_mask, router_prob)

expert_gate, expert_mask_flat, position_in_expert = \
self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate)

# combine_tensor's shape: (self.expert_parallel, tokens_per_device)
combine_tensor = self.mul7(expert_gate, expert_mask_flat)
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
combine_tensor = self.mul8(self.expand(combine_tensor, -1),
self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity)
combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
self.on_value, self.off_value))
dispatch_tensor = self.cast(combine_tensor, mstype.bool_)
return dispatch_tensor, combine_tensor, loss

+ 284
- 31
mindspore/parallel/nn/transformer.py View File

@@ -26,12 +26,13 @@ from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator
from mindspore import log as logger
from .layers import _LayerNorm, _Linear, _check_input_shape, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, Router
from .op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, _Config, _check_config

__all__ = [
@@ -44,10 +45,37 @@ __all__ = [
"TransformerEncoderLayer",
"TransformerDecoderLayer",
"Transformer",
"MoEConfig",
"TransformerOpParallelConfig",
"EmbeddingOpParallelConfig"]


class MoEConfig:
r"""
The configuration of MoE (Mixture of Expert).

Args:
expert_num (int): The number of experts employed. Default: 1
capacity_factor (float): The factor is used to indicate how much to expand expert capacity,
which is >=1.0. Default: 1.1.
aux_loss_factor (float): The factor is used to indicate how much the load balance loss (produced by the
router) to be added to the entire model loss, which is < 1.0. Default: 0.05.
num_experts_chosen (int): The number of experts is chosen by each token. Default: 1.
noisy_policy (string): The noisy policy is used in routing tokens to experts. Default: None.
noisy_epsilon (float): The parameter is used in adding noises in routing tokens to experts. Default: 1e-2.
"""
def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
num_experts_chosen=1, noisy_policy=None, noisy_epsilon=1e-2):
self.expert_num = expert_num
self.capacity_factor = capacity_factor
self.aux_loss_factor = aux_loss_factor
self.num_experts_chosen = num_experts_chosen
self.noisy_policy = noisy_policy
self.noisy_epsilon = noisy_epsilon

default_moe_config = MoEConfig()


class EmbeddingOpParallelConfig(_Config):
r"""
EmbeddingOpParallelConfig for the setting the data parallel or row slice for the embedding table.
@@ -265,6 +293,8 @@ class FeedForward(Cell):
hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
and the first dimension in BatchMatMul indicate expert_num. Default: 1.
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, a instance of `OpParallelConfig` with default
@@ -305,6 +335,7 @@ class FeedForward(Cell):
ffn_hidden_size,
dropout_rate,
hidden_act='gelu',
expert_num=1,
param_init_type=mstype.float32,
parallel_config=default_dpmp_config):
super(FeedForward, self).__init__()
@@ -320,22 +351,35 @@ class FeedForward(Cell):
"but got {}".format(dropout_rate))
input_size = hidden_size
output_size = ffn_hidden_size
# Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
ep = dp
# Project to ffn_hidden_size
self.mapping = _Linear(in_channels=input_size,
out_channels=output_size,
activation=hidden_act,
transpose_b=False,
expert_num=expert_num,
param_init_type=param_init_type)
self.mapping.shard(strategy_bias=((dp, mp), (mp,)),
strategy_matmul=((dp, 1), (1, mp)),
strategy_activation=((dp, 1, mp),))
if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, 1, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
strategy_activation=((dp, 1, mp),))
# Project back to embedding_size
self.projection = _Linear(in_channels=output_size,
out_channels=input_size,
transpose_b=False,
expert_num=expert_num,
param_init_type=param_init_type)
self.projection.shard(strategy_bias=((dp, 1), (1,)),
strategy_matmul=((dp, mp), (mp, 1)))
if expert_num > 1:
self.projection.shard(strategy_matmul=((ep, 1, mp), (ep, mp, 1)),
strategy_bias=((ep, 1, 1), (1,)))
else:
self.projection.shard(strategy_matmul=((dp, mp), (mp, 1)),
strategy_bias=((dp, 1), (1,)))
self.projection.bias.parallel_optimizer = False
self.dropout = nn.Dropout(1 - dropout_rate)
self.dropout.dropout.shard(((dp, 1, 1),))
@@ -352,6 +396,124 @@ class FeedForward(Cell):
output = self.dropout(output)
return output

@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)


class MoE(Cell):
"""
The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
obtained by multiplying FeedForward's output and router's combine weight.

Args:
hidden_size (int): The dimension of the inputs.
ffn_hidden_size (int): The intermediate hidden size.
dropout_rate (float): The dropout rate for the second linear's output.
hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, a instance of `OpParallelConfig` with default
args.

Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.

Outputs:
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
"""
def __init__(self, hidden_size,
ffn_hidden_size,
dropout_rate,
hidden_act='gelu',
param_init_type=mstype.float32,
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
super(MoE, self).__init__()
self.hidden_size = hidden_size
self.expert_dim = moe_config.expert_num
self.capacity_factor = moe_config.capacity_factor
self.aux_loss_factor = moe_config.aux_loss_factor
self.num_experts_chosen = moe_config.num_experts_chosen
self.expert_parallel = parallel_config.data_parallel
self.dp = parallel_config.data_parallel

self.ffn = FeedForward(hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
dropout_rate=dropout_rate,
hidden_act=hidden_act,
expert_num=self.expert_dim,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.reshape = P.Reshape()
self.shape = P.Shape()
self.transpose = P.Transpose().shard(((self.dp, 1, 1),))
self.transpose2 = P.Transpose().shard(((self.dp, 1, 1, 1),))
self.transpose3 = P.Transpose().shard(((self.dp, 1, 1, 1),))
self.transpose4 = P.Transpose().shard(((self.dp, 1, 1),))
self.transpose5 = P.Transpose().shard(((self.dp, 1, 1),))
self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
self.mul = P.Mul().shard(((), ()))
self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
training=True, parallel_config=parallel_config)
self.cast = P.Cast()


def construct(self, input_tensor):
bs = self.shape(input_tensor)[0]
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
bs_and_dmodel = self.shape(input_tensor)
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size))

expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
self.capacity_factor, self.expert_dim)
# dispatch_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

# after transpose, input_tensor's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
input_tensor = self.transpose(input_tensor, (0, 2, 1))
dispatch_tensor = self.reshape(dispatch_tensor, (self.expert_parallel, tokens_per_device,
self.expert_dim * expert_capacity))
dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
# expert_input's shape: (self.expert_parallel, self.hidden_size, self.expert_dim * expert_capacity)
expert_input = self.batch_mm(input_tensor, dispatch_tensor)
expert_input = self.reshape(expert_input, (self.expert_parallel, self.hidden_size, self.expert_dim,
expert_capacity))
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
expert_input = self.reshape(expert_input, (self.expert_dim, self.expert_parallel * expert_capacity,
self.hidden_size))

# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
expert_output = self.ffn(expert_input)
expert_output = self.reshape(expert_output, (self.expert_dim, self.expert_parallel,
expert_capacity, self.hidden_size))
# expert_output's shape: (self.expert_parallel, self.hidden_size, self.expert_dim, expert_capacity)
expert_output = self.transpose3(expert_output, (1, 3, 0, 2))
expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size,
self.expert_dim*expert_capacity))
combine_tensor = self.reshape(combine_tensor, (self.expert_parallel, tokens_per_device,
self.expert_dim*expert_capacity))
# combine_tensor's shape: (self.expert_parallel, self.expert_dim*expert_capacity, tokens_per_device)
combine_tensor = self.transpose4(combine_tensor, (0, 2, 1))
combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

# combined_output's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
combined_output = self.batch_mm2(expert_output, combine_tensor)
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
combined_output = self.transpose5(combined_output, (0, 2, 1))
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
combined_output = self.reshape(combined_output, (bs, -1, self.hidden_size))

aux_loss = self.mul(self.aux_loss_factor, aux_loss)
return combined_output, aux_loss

class AttentionMask(Cell):
r"""
@@ -903,6 +1065,7 @@ class TransformerEncoderLayer(Cell):
param_init_type(dtype.Number): The parameter initialization type of the module.
Can be dtype.float32 or dtype.float16. Default dtype.float32.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.

@@ -973,6 +1136,7 @@ class TransformerEncoderLayer(Cell):
param_init_type=mstype.float32,
hidden_act='gelu',
use_past=False,
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
super(TransformerEncoderLayer, self).__init__()
_check_config(parallel_config)
@@ -1000,13 +1164,23 @@ class TransformerEncoderLayer(Cell):
param_init_type=param_init_type,
use_past=use_past,
parallel_config=parallel_config)
# Feed Forward Network, FFN
self.output = FeedForward(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
param_init_type=param_init_type,
hidden_act=hidden_act,
parallel_config=parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe is True:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
param_init_type=param_init_type,
hidden_act=hidden_act,
moe_config=moe_config,
parallel_config=parallel_config)
else:
# Feed Forward Network, FFN
self.output = FeedForward(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
param_init_type=param_init_type,
hidden_act=hidden_act,
parallel_config=parallel_config)
self.post_layernorm_residual = post_layernorm_residual
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.dtype = mstype.float16
@@ -1056,7 +1230,11 @@ class TransformerEncoderLayer(Cell):

output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
aux_loss = None
if self.use_moe is True:
mlp_logit, aux_loss = self.output(output_x)
else:
mlp_logit = self.output(output_x)

value_update = None
key_update = None
@@ -1078,6 +1256,8 @@ class TransformerEncoderLayer(Cell):
output = self.add(output_x, mlp_logit)
else:
output = self.add(x, mlp_logit)
if self.use_moe is True:
return output, layer_present, aux_loss
return output, layer_present

def _check_input(self, x, input_mask, init_reset, batch_valid_length):
@@ -1123,6 +1303,7 @@ class TransformerDecoderLayer(Cell):
param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
Default dtype.float32.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.

@@ -1204,6 +1385,7 @@ class TransformerDecoderLayer(Cell):
softmax_comptue_type=mstype.float32,
param_init_type=mstype.float32,
hidden_act='gelu',
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
super(TransformerDecoderLayer, self).__init__()
_check_config(parallel_config)
@@ -1247,14 +1429,23 @@ class TransformerDecoderLayer(Cell):
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(
layernorm_compute_type)
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),))

# Feed Forward Network, FFN
self.output = FeedForward(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
hidden_act=hidden_act,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe is True:
self.output = MoE(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
param_init_type=param_init_type,
hidden_act=hidden_act,
moe_config=moe_config,
parallel_config=parallel_config)
else:
# Feed Forward Network, FFN
self.output = FeedForward(hidden_size=hidden_size,
dropout_rate=hidden_dropout_rate,
ffn_hidden_size=ffn_hidden_size,
hidden_act=hidden_act,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.post_layernorm_residual = post_layernorm_residual
self.add = P.TensorAdd().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.dtype = mstype.float16
@@ -1321,7 +1512,11 @@ class TransformerDecoderLayer(Cell):

output_x = self.layernorm2(x)
output_x = F.cast(output_x, self.dtype)
mlp_logit = self.output(output_x)
aux_loss = None
if self.use_moe is True:
mlp_logit, aux_loss = self.output(output_x)
else:
mlp_logit = self.output(output_x)

value_update = None
key_update = None
@@ -1343,6 +1538,8 @@ class TransformerDecoderLayer(Cell):
output = self.add(output_x, mlp_logit)
else:
output = self.add(x, mlp_logit)
if self.use_moe is True:
return output, layer_present, aux_loss
return output, layer_present

def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
@@ -1447,6 +1644,7 @@ class TransformerEncoder(Cell):
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not
overlap with the encoder layer.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
a instance of `TransformerOpParallelConfig` with default args.

@@ -1523,10 +1721,14 @@ class TransformerEncoder(Cell):
lambda_func=None,
offset=0,
use_past=False,
moe_config=default_moe_config,
parallel_config=default_transformer_config):
super(TransformerEncoder, self).__init__()
_check_config(parallel_config)

self.use_moe = (moe_config.expert_num > 1)
self.add = P.TensorAdd().shard(((), ()))
self.aux_loss = Tensor(0.0, mstype.float32)
self.num_layers = num_layers
self.blocks = nn.CellList()
for i in range(num_layers):
@@ -1543,6 +1745,7 @@ class TransformerEncoder(Cell):
post_layernorm_residual=post_layernorm_residual,
param_init_type=param_init_type,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
@@ -1554,6 +1757,17 @@ class TransformerEncoder(Cell):

def construct(self, hidden_states, attention_mask, init_reset=True, batch_valid_length=None):
present_layer = ()
if self.use_moe is True:
accum_loss = self.aux_loss
for i in range(self.num_layers):
hidden_states, present, aux_loss = self.blocks[i](hidden_states,
attention_mask,
init_reset,
batch_valid_length)
present_layer = present_layer + (present,)
accum_loss = self.add(accum_loss, aux_loss)
return hidden_states, present_layer, accum_loss

for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states,
attention_mask,
@@ -1597,6 +1811,7 @@ class TransformerDecoder(Cell):
zero, `offset(int)` means the layer_index needs a offset, if there are other modules in the net. The
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
Default: None
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
a instance of `TransformerOpParallelConfig` with default args.

@@ -1686,12 +1901,16 @@ class TransformerDecoder(Cell):
lambda_func=None,
use_past=False,
offset=0,
moe_config=default_moe_config,
parallel_config=default_transformer_config):
super(TransformerDecoder, self).__init__()
_check_config(parallel_config)

self.add = P.TensorAdd().shard(((), ()))
self.aux_loss = Tensor(0.0, mstype.float32)
self.num_layers = num_layers
self.blocks = nn.CellList()
self.use_moe = (moe_config.expert_num > 1)
for i in range(num_layers):
block = TransformerDecoderLayer(hidden_size=hidden_size,
batch_size=batch_size,
@@ -1707,6 +1926,7 @@ class TransformerDecoder(Cell):
use_past=use_past,
param_init_type=param_init_type,
post_layernorm_residual=post_layernorm_residual,
moe_config=moe_config,
parallel_config=parallel_config.dp_mp_config)
# If the user doesn't pass the fusion function, use the default one
if not lambda_func:
@@ -1720,6 +1940,19 @@ class TransformerDecoder(Cell):
def construct(self, hidden_states, attention_mask, encoder_output=None, memory_mask=None,
init_reset=True, batch_valid_length=None):
present_layer = ()
if self.use_moe is True:
accum_loss = self.aux_loss
for i in range(self.num_layers):
hidden_states, present, aux_loss = self.blocks[i](hidden_states,
attention_mask,
encoder_output,
memory_mask,
init_reset,
batch_valid_length)
present_layer = present_layer + (present,)
accum_loss = self.add(accum_loss, aux_loss)
return hidden_states, present_layer, accum_loss

# Loop through each self-attention layer
for i in range(self.num_layers):
hidden_states, present = self.blocks[i](hidden_states,
@@ -1871,6 +2104,7 @@ class Transformer(Cell):
param_init_type=mstype.float32,
lambda_func=None,
use_past=False,
moe_config=default_moe_config,
parallel_config=default_transformer_config):
super(Transformer, self).__init__()
_check_config(parallel_config)
@@ -1888,6 +2122,9 @@ class Transformer(Cell):
if not lambda_func:
lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)

self.use_moe = (moe_config.expert_num > 1)
self.add = P.TensorAdd().shard(((), ()))
self.aux_loss = Tensor(0.0, mstype.float32)
if encoder_layers > 0:
self.encoder = TransformerEncoder(num_layers=encoder_layers,
batch_size=batch_size,
@@ -1904,6 +2141,7 @@ class Transformer(Cell):
param_init_type=param_init_type,
lambda_func=lambda_func,
use_past=use_past,
moe_config=moe_config,
parallel_config=parallel_config)
else:
self.encoder = None
@@ -1929,6 +2167,7 @@ class Transformer(Cell):
use_past=use_past,
param_init_type=param_init_type,
offset=encoder_layers,
moe_config=moe_config,
parallel_config=parallel_config)

def construct(self, encoder_inputs,
@@ -1943,17 +2182,31 @@ class Transformer(Cell):
output = None
encoder_layer_present = None
decoder_layer_present = None
accum_loss = self.aux_loss
if self.encoder is not None:
encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset,
batch_valid_length)
if self.use_moe is True:
encoder_output, encoder_layer_present, encoder_aux_loss = self.encoder(encoder_inputs, encoder_masks,
init_reset, batch_valid_length)
accum_loss = self.add(accum_loss, encoder_aux_loss)
else:
encoder_output, encoder_layer_present = self.encoder(encoder_inputs, encoder_masks, init_reset,
batch_valid_length)
output = encoder_output

if self.decoder is not None:
# decoder mask can be created outside of the model
decoder_output, decoder_layer_present = self.decoder(decoder_inputs,
decoder_masks,
encoder_output,
memory_mask, init_reset,
batch_valid_length)
if self.use_moe is True:
decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks,
encoder_output, memory_mask,
init_reset, batch_valid_length)
accum_loss = self.add(accum_loss, decoder_aux_loss)
else:
decoder_output, decoder_layer_present = self.decoder(decoder_inputs,
decoder_masks,
encoder_output,
memory_mask, init_reset,
batch_valid_length)
output = decoder_output
if self.use_moe is True:
return output, encoder_layer_present, decoder_layer_present, accum_loss
return output, encoder_layer_present, decoder_layer_present

+ 95
- 0
tests/ut/python/parallel/test_parallel_moe.py View File

@@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


class Dataset(MindData):
def __init__(self, *inputs, length=3):
super(Dataset, self).__init__(size=length)
self.inputs = inputs
self.index = 0
self.length = length

def __iter__(self):
return self

def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.inputs

def reset(self):
self.index = 0


config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, vocab_emb_dp=False)
moe_config = MoEConfig(expert_num=4)


class NetWithLossFiveInputs(nn.Cell):
def __init__(self, network):
super(NetWithLossFiveInputs, self).__init__()
self.loss = VirtualLoss()
self.network = network

def construct(self, x1, x2, x3, x4, x5):
predict, _, _, _ = self.network(x1, x2, x3, x4, x5)
return self.loss(predict)


def test_transformer_model():
context.set_context(save_graphs=True)
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
moe_config=moe_config,
parallel_config=config)

encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)

model.train(1, dataset, dataset_sink_mode=False)

Loading…
Cancel
Save