|
|
|
@@ -281,10 +281,12 @@ class FeedForward(Cell): |
|
|
|
default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor. |
|
|
|
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]` or `[batch * seq_length, hidden_size]`.
|
|
|
Float tensor. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`. |
|
|
|
Tensor, the output of this layer after mapping. |
|
|
|
The shape is `[batch, seq_length, hidden_size]` or `[batch * seq_length, hidden_size]`.
|
|
|
|
|
|
|
Raises: |
|
|
|
ValueError: If `hidden_act` is not a string.
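A minimal usage sketch of the relaxed input contract follows. The sizes are illustrative and the import path `mindspore.nn.transformer` is an assumption; the point is only that both the 3-D and the flattened 2-D layouts are accepted and preserved:

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import FeedForward  # assumed import path

model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
x_3d = Tensor(np.ones((2, 20, 15)), mstype.float32)   # [batch, seq_length, hidden_size]
x_2d = Tensor(np.ones((2 * 20, 15)), mstype.float32)  # [batch * seq_length, hidden_size]
print(model(x_3d).shape)  # (2, 20, 15)
print(model(x_2d).shape)  # (40, 15)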
|
|
|
@@ -344,11 +346,11 @@ class FeedForward(Cell): |
|
|
|
if expert_num > 1: |
|
|
|
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)), |
|
|
|
strategy_bias=((ep, 1, mp), (mp,)), |
|
|
|
strategy_activation=((ep, 1, mp),)) |
|
|
|
strategy_activation=((ep, mp),)) |
|
|
|
else: |
|
|
|
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)), |
|
|
|
strategy_bias=((dp, mp), (mp,)), |
|
|
|
strategy_activation=((dp, 1, mp),)) |
|
|
|
strategy_activation=((dp, mp),)) |
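# NOTE: after the flattening change the activation inside `mapping` sees a tensor with one
# fewer axis, so the activation shard strategy drops the corresponding dimension
# ((dp, mp) / (ep, mp) rather than the previous 3-tuple forms).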
|
|
|
# Project back to hidden_size |
|
|
|
self.projection = _Linear(in_channels=output_size, |
|
|
|
out_channels=input_size, |
|
|
|
@@ -363,17 +365,17 @@ class FeedForward(Cell): |
|
|
|
strategy_bias=((dp, 1), (1,))) |
|
|
|
self.projection.bias.parallel_optimizer = False |
|
|
|
self.dropout = nn.Dropout(1 - dropout_rate) |
|
|
|
self.dropout.dropout.shard(((dp, 1, 1),)) |
|
|
|
self.dropout.dropout.shard(((dp, 1),)) |
|
|
|
self.cast = P.Cast() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
_check_input_shape(F.shape(x), "x", self.cls_name, 3) |
|
|
|
_check_input_shape(F.shape(x), "x", self.cls_name, [2, 3]) |
|
|
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
x = self.cast(x, mstype.float16) |
|
|
|
# returned shape is [bs, seq_length, ffn_hidden_size] |
|
|
|
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] |
|
|
|
hidden = self.mapping(x) |
|
|
|
output = self.projection(hidden) |
|
|
|
# returned shape is [bs, seq_length, hidden_size] |
|
|
|
# returned shape is [bs, seq_length, hidden_size] or [bs * seq_length, hidden_size]
|
|
|
output = self.dropout(output) |
|
|
|
return output |
|
|
|
|
|
|
|
@@ -556,9 +558,12 @@ class MultiHeadAttention(Cell): |
|
|
|
an instance of `OpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size). |
|
|
|
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size). |
|
|
|
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size). |
|
|
|
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or |
|
|
|
(batch_size * src_seq_length, hidden_size). |
|
|
|
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or |
|
|
|
(batch_size * tgt_seq_length, hidden_size).
|
|
|
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or |
|
|
|
(batch_size * tgt_seq_length, hidden_size).
|
|
|
- **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length, |
|
|
|
tgt_seq_length). |
|
|
|
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length). |
|
|
|
@@ -574,7 +579,7 @@ class MultiHeadAttention(Cell): |
|
|
|
Tuple, a tuple containing (`output`, `layer_present`)
|
|
|
|
|
|
|
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with |
|
|
|
shape (batch_size, src_seq_length, hidden_size) |
|
|
|
shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size) |
|
|
|
|
|
|
|
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with shape
|
|
|
((batch_size, num_heads, size_per_head, tgt_seq_length), |
|
|
|
@@ -683,11 +688,11 @@ class MultiHeadAttention(Cell): |
|
|
|
self.scale_factor = Tensor(math.sqrt(self.size_per_head)) |
|
|
|
self.use_past = use_past |
|
|
|
self.dropout = nn.Dropout(1 - hidden_dropout_rate) |
|
|
|
self.dropout.dropout.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.dropout.dropout.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate) |
|
|
|
self.prob_dropout.dropout.shard( |
|
|
|
((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),)) |
|
|
|
self.softmax = nn.Softmax() |
|
|
|
self.softmax = nn.Softmax().to_float(softmax_compute_type) |
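# the softmax itself is computed in `softmax_compute_type` (float32 by default) for numerical stability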
|
|
|
self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),)) |
|
|
|
self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
|
|
|
|
@@ -737,14 +742,11 @@ class MultiHeadAttention(Cell): |
|
|
|
value_past=None, batch_valid_length=None): |
|
|
|
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past, |
|
|
|
value_past, batch_valid_length) |
|
|
|
query_tensor_original_shape = F.shape(query_tensor) |
|
|
|
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1])) |
|
|
|
|
|
|
|
key_tensor_original_shape = F.shape(key_tensor) |
|
|
|
key_tensor = F.reshape(key_tensor, (-1, key_tensor_original_shape[-1])) |
|
|
|
|
|
|
|
value_tensor_original_shape = F.shape(value_tensor) |
|
|
|
value_tensor = F.reshape(value_tensor, (-1, value_tensor_original_shape[-1])) |
|
|
|
batch_size = F.shape(attention_mask)[0] |
|
|
|
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor, |
|
|
|
key_tensor, |
|
|
|
value_tensor, |
|
|
|
attention_mask) |
|
|
|
|
|
|
|
# multi head attention: query, key, value are derived from the same inputs |
|
|
|
query = self.dense1(query_tensor) |
|
|
|
@@ -754,18 +756,18 @@ class MultiHeadAttention(Cell): |
|
|
|
query = self.transpose( |
|
|
|
F.reshape( |
|
|
|
query, |
|
|
|
(-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)), |
|
|
|
(batch_size, -1, self.n_head, self.size_per_head)), |
|
|
|
(0, 2, 1, 3)) |
|
|
|
# the returned shape is [bs, num_heads, size_per_head, seq_length] |
|
|
|
# the returned shape is [bs, num_heads, size_per_head, seq_length]
|
|
|
key = self.transpose( |
|
|
|
F.reshape( |
|
|
|
key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)), |
|
|
|
key, (batch_size, -1, self.n_head, self.size_per_head)), |
|
|
|
(0, 2, 3, 1)) |
|
|
|
# the returned shape is [bs, num_heads, seq_length, size_per_head] |
|
|
|
value = self.transpose( |
|
|
|
F.reshape( |
|
|
|
value, |
|
|
|
(-1, value_tensor_original_shape[1], self.n_head, self.size_per_head)), |
|
|
|
(batch_size, -1, self.n_head, self.size_per_head)), |
|
|
|
(0, 2, 1, 3)) |
|
|
|
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq] |
|
|
|
if len(F.shape(attention_mask)) == 3: |
|
|
|
@@ -810,22 +812,26 @@ class MultiHeadAttention(Cell): |
|
|
|
|
|
|
|
layer_present = (key_present, value_present) |
|
|
|
# multi head attention considering attention mask |
|
|
|
# the return shape is [bs, seq_length, hidden_size] |
|
|
|
# the return shape is [bs * seq_length, hidden_size] |
|
|
|
attention = self._attn(query, key, value, attention_mask) |
|
|
|
# Output |
|
|
|
output = self.projection(attention) |
|
|
|
output = self.dropout(output) |
|
|
|
output = F.reshape(output, ori_shape) |
|
|
|
return output, layer_present |
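For illustration, a small sketch of driving the attention layer with already-flattened inputs is given below; the sizes are made up and the import path is an assumption, the point being that the flattened query layout is carried through to the output:

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import MultiHeadAttention  # assumed import path

model = MultiHeadAttention(batch_size=2, hidden_size=15, num_heads=3,
                           src_seq_length=20, tgt_seq_length=20)
query = Tensor(np.ones((2 * 20, 15)), mstype.float16)  # (batch_size * src_seq_length, hidden_size)
key = Tensor(np.ones((2 * 20, 15)), mstype.float16)    # (batch_size * tgt_seq_length, hidden_size)
value = Tensor(np.ones((2 * 20, 15)), mstype.float16)
attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
output, layer_present = model(query, key, value, attention_mask)
print(output.shape)  # (40, 15), matching the flattened query shape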
|
|
|
|
|
|
|
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None, |
|
|
|
value_past=None, batch_valid_length=None): |
|
|
|
r"""Check inputs""" |
|
|
|
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.src_seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.src_seq_length, self.hidden_size]]) |
|
|
|
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]]) |
|
|
|
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]]) |
|
|
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.tgt_seq_length]) |
|
|
|
|
|
|
|
@@ -839,20 +845,30 @@ class MultiHeadAttention(Cell): |
|
|
|
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length) |
|
|
|
return True |
|
|
|
|
|
|
|
def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask): |
|
|
|
"""convert a nd tensor to a 2d tensor""" |
|
|
|
query_shape = F.shape(query_tensor) |
|
|
|
query_tensor = F.reshape(query_tensor, (-1, query_shape[-1])) |
|
|
|
key_shape = F.shape(key_tensor) |
|
|
|
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1])) |
|
|
|
value_shape = F.shape(value_tensor) |
|
|
|
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1])) |
|
|
|
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape |
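# NOTE: the original query shape is returned as `ori_shape` so that construct() can restore the
# caller's layout at the very end via F.reshape(output, ori_shape).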
|
|
|
|
|
|
|
def _merge_heads(self, x): |
|
|
|
""" |
|
|
|
convert a 4d input to a 3d output |
|
|
|
convert a 4d input to a 2d output |
|
|
|
|
|
|
|
Inputs: |
|
|
|
x: input tensor |
|
|
|
|
|
|
|
Output: |
|
|
|
x_merge: the 3d output |
|
|
|
x_merge: the 2d output |
|
|
|
""" |
|
|
|
x = self.merger_head_transpose( |
|
|
|
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head |
|
|
|
x_shape = P.Shape()(x) |
|
|
|
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],) |
|
|
|
new_shape = (-1, x_shape[-2] * x_shape[-1]) |
|
|
|
x_merge = self.reshape(x, new_shape) |
|
|
|
return x_merge |
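# shape walk-through: (bs, num_heads, seq_length, size_per_head) is transposed to
# (bs, seq_length, num_heads, size_per_head) and then reshaped to
# (bs * seq_length, num_heads * size_per_head), i.e. the flattened 2-D layout used downstream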
|
|
|
|
|
|
|
@@ -947,7 +963,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
an instance of `OpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size]. |
|
|
|
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size]. |
|
|
|
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length]. |
|
|
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and |
|
|
|
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True. |
|
|
|
@@ -958,7 +975,7 @@ class TransformerEncoderLayer(Cell): |
|
|
|
Tuple, a tuple containing (`output`, `layer_present`).
|
|
|
|
|
|
|
- **output** (Tensor) - The float tensor of the output of the layer with |
|
|
|
shape (batch_size, seq_length, hidden_size). |
|
|
|
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size). |
|
|
|
|
|
|
|
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with shape
|
|
|
((batch_size, num_heads, size_per_head, seq_length), |
|
|
|
@@ -1034,9 +1051,9 @@ class TransformerEncoderLayer(Cell): |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.batch_size = batch_size |
|
|
|
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
|
|
|
|
self.attention = MultiHeadAttention(batch_size=batch_size, |
|
|
|
src_seq_length=seq_length, |
|
|
|
@@ -1067,7 +1084,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
hidden_act=hidden_act, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1))) |
|
|
|
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
self.key_past = None |
|
|
|
self.value_past = None |
|
|
|
@@ -1089,6 +1107,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
|
|
|
|
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None): |
|
|
|
self._check_input(x, input_mask, init_reset, batch_valid_length) |
|
|
|
x_shape = F.shape(x) |
|
|
|
x = F.reshape(x, (-1, x_shape[-1])) |
|
|
|
input_x = self.layernorm1(x) |
|
|
|
input_x = F.cast(input_x, self.dtype) |
|
|
|
|
|
|
|
@@ -1137,10 +1157,23 @@ class TransformerEncoderLayer(Cell): |
|
|
|
mlp_logit = F.depend(mlp_logit, value_update) |
|
|
|
mlp_logit = F.depend(mlp_logit, key_update) |
|
|
|
|
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
# if shape is 3d, we reshape the inputs of the add |
|
|
|
if len(x_shape) == 3: |
|
|
|
output_x = P.Reshape()(output_x, x_shape) |
|
|
|
mlp_logit = P.Reshape()(mlp_logit, x_shape) |
|
|
|
x = P.Reshape()(x, x_shape) |
|
|
|
|
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add_3d(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add_3d(x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
output = F.reshape(output, x_shape) |
|
|
|
|
|
|
|
if self.use_moe is True: |
|
|
|
return output, layer_present, aux_loss |
|
|
|
return output, layer_present |
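A short sanity-check sketch of the 2-D path through the encoder layer (illustrative sizes; `TransformerEncoderLayer` is assumed to be importable from `mindspore.nn.transformer`):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import TransformerEncoderLayer  # assumed import path

layer = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64,
                                seq_length=16, num_heads=2)
x = Tensor(np.ones((2 * 16, 8)), mstype.float32)           # [batch_size * seq_length, hidden_size]
input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
output, layer_present = layer(x, input_mask)
print(output.shape)  # (32, 8), the flattened layout is kept on the way out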
|
|
|
@@ -1148,7 +1181,8 @@ class TransformerEncoderLayer(Cell): |
|
|
|
def _check_input(self, x, input_mask, init_reset, batch_valid_length): |
|
|
|
r"""Check inputs""" |
|
|
|
_check_shape_equal(F.shape(x), "x", self.cls_name, |
|
|
|
[self.batch_size, self.seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.seq_length, self.hidden_size]]) |
|
|
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, |
|
|
|
[self.batch_size, self.seq_length, self.seq_length]) |
|
|
|
@@ -1193,10 +1227,12 @@ class TransformerDecoderLayer(Cell): |
|
|
|
an instance of `OpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size]. |
|
|
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or |
|
|
|
[batch_size * tgt_seq_length, hidden_size]. |
|
|
|
- **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, tgt_seq_length,



tgt_seq_length].
|
|
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size]. |
|
|
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size]. |
|
|
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, |
|
|
|
src_seq_length], where tgt_seq_length is the length of the decoder. |
|
|
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and |
|
|
|
@@ -1207,7 +1243,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
Outputs: |
|
|
|
Tuple, a tuple containing (`output`, `layer_present`)
|
|
|
|
|
|
|
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] |
|
|
|
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] or |
|
|
|
[batch * seq_length, hidden_size]. |
|
|
|
- **layer_present** (Tuple) - A tuple, where each tuple is the tensor of the projected key and value
|
|
|
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length), |
|
|
|
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector |
|
|
|
@@ -1295,10 +1332,10 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.use_past = use_past |
|
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
|
|
self.layernorm1 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type) |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.layernorm2 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type) |
|
|
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
|
|
|
|
self.attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
@@ -1323,9 +1360,9 @@ class TransformerDecoderLayer(Cell): |
|
|
|
use_past=use_past, |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float( |
|
|
|
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float( |
|
|
|
layernorm_compute_type) |
|
|
|
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),)) |
|
|
|
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.use_moe = (moe_config.expert_num > 1) |
|
|
|
if self.use_moe is True: |
|
|
|
self.output = MoE(hidden_size=hidden_size, |
|
|
|
@@ -1344,7 +1381,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
param_init_type=param_init_type, |
|
|
|
parallel_config=parallel_config) |
|
|
|
self.post_layernorm_residual = post_layernorm_residual |
|
|
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1))) |
|
|
|
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
|
|
|
self.dtype = mstype.float16 |
|
|
|
self.key_past = None |
|
|
|
self.value_past = None |
|
|
|
@@ -1369,7 +1407,9 @@ class TransformerDecoderLayer(Cell): |
|
|
|
memory_mask=None, |
|
|
|
init_reset=True, batch_valid_length=None): |
|
|
|
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length) |
|
|
|
# the returned shape is [bs, seq_length, embedding_size] |
|
|
|
# the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size] |
|
|
|
hidden_shape = F.shape(hidden_stats) |
|
|
|
hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1])) |
|
|
|
input_x = self.layernorm1(hidden_stats) |
|
|
|
input_x = F.cast(input_x, self.dtype) |
|
|
|
|
|
|
|
@@ -1431,10 +1471,23 @@ class TransformerDecoderLayer(Cell): |
|
|
|
mlp_logit = F.depend(mlp_logit, value_update) |
|
|
|
mlp_logit = F.depend(mlp_logit, key_update) |
|
|
|
|
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
# if shape is 3d, we reshape the inputs of the add |
|
|
|
if len(hidden_shape) == 3: |
|
|
|
output_x = P.Reshape()(output_x, hidden_shape) |
|
|
|
mlp_logit = P.Reshape()(mlp_logit, hidden_shape) |
|
|
|
x = P.Reshape()(x, hidden_shape) |
|
|
|
|
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add_3d(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add_3d(x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
if self.post_layernorm_residual: |
|
|
|
output = self.add(output_x, mlp_logit) |
|
|
|
else: |
|
|
|
output = self.add(x, mlp_logit) |
|
|
|
output = F.reshape(output, hidden_shape) |
|
|
|
|
|
|
|
if self.use_moe is True: |
|
|
|
return output, layer_present, aux_loss |
|
|
|
return output, layer_present |
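A matching sketch for the decoder layer with flattened decoder and encoder tensors (illustrative sizes, same import-path assumption):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import TransformerDecoderLayer  # assumed import path

layer = TransformerDecoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64, num_heads=2,
                                src_seq_length=20, tgt_seq_length=10)
hidden_stats = Tensor(np.ones((2 * 10, 64)), mstype.float32)    # [batch_size * tgt_seq_length, hidden_size]
decoder_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
encoder_output = Tensor(np.ones((2 * 20, 64)), mstype.float32)  # [batch_size * src_seq_length, hidden_size]
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
output, layer_present = layer(hidden_stats, decoder_mask, encoder_output, memory_mask)
print(output.shape)  # (20, 64)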
|
|
|
@@ -1442,14 +1495,16 @@ class TransformerDecoderLayer(Cell): |
|
|
|
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length): |
|
|
|
r"""Check inputs""" |
|
|
|
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]]) |
|
|
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length]) |
|
|
|
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
if encoder_output is not None: |
|
|
|
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.hidden_size]) |
|
|
|
[[self.batch_size, self.src_seq_length, self.hidden_size], |
|
|
|
[self.batch_size * self.src_seq_length, self.hidden_size]]) |
|
|
|
_check_input_dtype(F.dtype(encoder_output), "encoder_output", |
|
|
|
[mstype.float32, mstype.float16], self.cls_name) |
|
|
|
if memory_mask is not None: |
|
|
|
@@ -1547,7 +1602,8 @@ class TransformerEncoder(Cell): |
|
|
|
an instance of `TransformerOpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] |
|
|
|
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size] |
|
|
|
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length] |
|
|
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and |
|
|
|
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True |
|
|
|
@@ -1558,7 +1614,7 @@ class TransformerEncoder(Cell): |
|
|
|
Tuple, a tuple containing (`output`, `layer_present`)
|
|
|
|
|
|
|
- **output** (Tensor) - The float tensor of the output of the layer with |
|
|
|
shape (batch_size, seq_length, hidden_size) |
|
|
|
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size) |
|
|
|
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor of the
|
|
|
projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length), |
|
|
|
and (batch_size, num_heads, seq_length, size_per_head)). |
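A brief sketch of the stacked encoder consuming a flattened input (illustrative sizes, same import-path assumption):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import TransformerEncoder  # assumed import path

model = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64,
                           seq_length=16, num_heads=2)
hidden_states = Tensor(np.ones((2 * 16, 8)), mstype.float32)  # [batch_size * seq_length, hidden_size]
attention_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
output, layer_present = model(hidden_states, attention_mask)
print(output.shape)  # (32, 8)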
|
|
|
@@ -1717,9 +1773,11 @@ class TransformerDecoder(Cell): |
|
|
|
an instance of `TransformerOpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] |
|
|
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size] |
|
|
|
- **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length] |
|
|
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] |
|
|
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size] |
|
|
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, |
|
|
|
src_seq_length], where tgt_seq_length is the length of the decoder.
|
|
|
@@ -1731,7 +1789,8 @@ class TransformerDecoder(Cell): |
|
|
|
Outputs: |
|
|
|
Tuple, a tuple containing (`output`, `layer_present`)
|
|
|
|
|
|
|
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] |
|
|
|
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or |
|
|
|
[batch * tgt_seq_length, hidden_size] |
|
|
|
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected |
|
|
|
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length), |
|
|
|
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector |
|
|
|
@@ -1912,9 +1971,11 @@ class Transformer(Cell): |
|
|
|
an instance of `TransformerOpParallelConfig` with default args. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size]. |
|
|
|
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or |
|
|
|
[batch_size * seq_length, hidden_size]. |
|
|
|
- **encoder_masks** (Tensor) - the attention mask for the encoder with shape [batch_size, seq_length, seq_length].
|
|
|
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size], |
|
|
|
- **decoder_inputs** (Tensor) - the input tensor of the decoder with shape [batch_size, seq_length, hidden_size] or



[batch_size * seq_length, hidden_size].



This should be None if the number of decoder layers is 0.
|
|
|
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length] |
|
|
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length, |
|
|
|
@@ -1930,8 +1991,9 @@ class Transformer(Cell): |
|
|
|
Tuple, a tuple containing (`output`, `encoder_layer_present`, `decoder_layer_present`)
|
|
|
|
|
|
|
- **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is |
|
|
|
[batch, src_seq_length, hidden_size], if there are encoder and decoders, the output is from the |
|
|
|
decoder layer. The shape is [batch, tgt_seq_length, hidden_size]. |
|
|
|
[batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and |
|
|
|
decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or |
|
|
|
[batch * tgt_seq_length, hidden_size]. |
|
|
|
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the
|
|
|
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, |
|
|
|
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)). |
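Finally, an end-to-end sketch of the full `Transformer` with flattened encoder and decoder inputs (sizes are made up; the import path and the small layer counts are assumptions for illustration only):

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn.transformer import Transformer  # assumed import path

model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=1, hidden_size=64,
                    ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10, num_heads=2)
encoder_inputs = Tensor(np.ones((2 * 20, 64)), mstype.float32)  # [batch_size * src_seq_length, hidden_size]
encoder_masks = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_inputs = Tensor(np.ones((2 * 10, 64)), mstype.float32)  # [batch_size * tgt_seq_length, hidden_size]
decoder_masks = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
output, encoder_layer_present, decoder_layer_present = model(encoder_inputs, encoder_masks,
                                                             decoder_inputs, decoder_masks, memory_mask)
print(output.shape)  # (20, 64)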
|
|
|
|