|
|
|
@@ -13,7 +13,7 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
""" |
|
|
|
NOTE: |
|
|
|
Note: |
|
|
|
Transformer Networks. This is an experimental interface that is subject to change and/or deletion. |
|
|
|
""" |
|
|
|
import math |
|
|
|
@@ -50,7 +50,7 @@ __all__ = [ |
|
|
|
@constexpr |
|
|
|
def _check_input_shape(input_shape, param_name, func_name, target_len): |
|
|
|
if len(input_shape) != target_len: |
|
|
|
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}") |
|
|
|
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}") |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
@@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config): |
|
|
|
def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True): |
|
|
|
self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel) |
|
|
|
Validator.check_bool(vocab_emb_dp, "vocab_emb_dp") |
|
|
|
self._vocab_emb_dp = vocab_emb_dp |
|
|
|
self.vocab_emb_dp = vocab_emb_dp |
|
|
|
|
|
|
|
@property |
|
|
|
def data_parallel(self): |
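
# A minimal usage sketch of the config above (not part of the patch). The import path is an
# assumption: this experimental module has moved between MindSpore releases.
from mindspore.nn.transformer import EmbeddingOpParallelConfig

embed_cfg = EmbeddingOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=True)
print(embed_cfg.data_parallel)  # 8, read through the inner OpParallelConfig
print(embed_cfg.vocab_emb_dp)   # True: shard the embedding lookup along the data-parallel axis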
|
|
|
@@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
|
|
|
|
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False, |
|
|
|
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True): |
|
|
|
Validator.check_bool(recompute, "recompute") |
|
|
|
Validator.check_bool(optimizer_shard, "optimizer_shard") |
|
|
|
Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group") |
|
|
|
self.recompute = recompute |
|
|
|
self.optimizer_shard = optimizer_shard |
|
|
|
self.gradient_aggregation_group = gradient_aggregation_group |
|
|
|
self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel, |
|
|
|
vocab_emb_dp=vocab_emb_dp) |
|
|
|
self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num) |
|
|
|
self._recompute = recompute |
|
|
|
self._optimizer_shard = optimizer_shard |
|
|
|
self._gradient_aggregation_group = gradient_aggregation_group |
|
|
|
|
|
|
|
@property |
|
|
|
def recompute(self): |
|
|
|
@@ -256,7 +253,7 @@ class TransformerOpParallelConfig(_Config): |
|
|
|
def optimizer_shard(self, value): |
|
|
|
Validator.check_bool(value, "optimizer_shard") |
|
|
|
self._optimizer_shard = value |
|
|
|
context.set_auto_parallel_context(optimizer_shard=value) |
|
|
|
context.set_auto_parallel_context(enable_parallel_optimizer=value) |
|
|
|
|
|
|
|
@property |
|
|
|
def embedding_dp_mp_config(self): |
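
# Hedged sketch of the setter change above: assigning optimizer_shard validates the flag and
# now forwards it to the global auto-parallel context as enable_parallel_optimizer.
# The import path and the parallel sizes are assumptions used only for illustration.
from mindspore.nn.transformer import TransformerOpParallelConfig

parallel_cfg = TransformerOpParallelConfig(data_parallel=2, model_parallel=4)
parallel_cfg.optimizer_shard = True  # also calls context.set_auto_parallel_context(enable_parallel_optimizer=True)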
|
|
|
@@ -322,6 +319,8 @@ class FeedForward(Cell): |
|
|
|
Raises: |
|
|
|
ValueError: `hidden_act` is not a string. |
|
|
|
ValueError: `parallel_config` is not a subclass of OpParallelConfig. |
|
|
|
ValueError: `ffn_hidden_size` is not a multiple of the model parallel way. |
|
|
|
ValueError: `hidden_size` is not a multiple of the model parallel way. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
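
# A hedged construction sketch for the constraints listed under Raises above: hidden_size and
# ffn_hidden_size must both be divisible by parallel_config.model_parallel, and dropout_rate
# must lie in [0, 1.0). The import path and sizes are assumptions.
from mindspore.nn.transformer import FeedForward, OpParallelConfig

mp_cfg = OpParallelConfig(data_parallel=2, model_parallel=4)
ffn = FeedForward(hidden_size=1024, ffn_hidden_size=4096, dropout_rate=0.1,
                  hidden_act="gelu", parallel_config=mp_cfg)
# hidden_size=1000 with model_parallel=4 would trigger the ValueError added below.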
|
|
|
@@ -349,6 +348,13 @@ class FeedForward(Cell): |
|
|
|
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}") |
|
|
|
dp = parallel_config.data_parallel |
|
|
|
mp = parallel_config.model_parallel |
|
|
|
if ffn_hidden_size % mp != 0: |
|
|
|
raise ValueError("ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}") |
|
|
|
if hidden_size % mp != 0: |
|
|
|
raise ValueError("hidden_size {hidden_size} should be a multiple of the model parallel way {mp}") |
|
|
|
if dropout_rate < 0 or dropout_rate >= 1: |
|
|
|
raise ValueError("dropout_rate probability should be a number in range [0, 1.0), " |
|
|
|
"but got {}".format(dropout_rate)) |
|
|
|
input_size = hidden_size |
|
|
|
output_size = ffn_hidden_size |
|
|
|
# Project to ffn_hidden_size |
|
|
|
@@ -428,7 +434,7 @@ class AttentionMask(Cell): |
|
|
|
raise ValueError( |
|
|
|
f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}") |
|
|
|
self.seq_length = seq_length |
|
|
|
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),)) |
|
|
|
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ())) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.mul = P.BatchMatMul().shard( |
|
|
|
((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1))) |
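
# Illustrative numpy-only sketch of the mask this layer builds: the padding mask is expanded
# to a pairwise (bs, seq, seq) mask by the batch matmul above; the full implementation then
# multiplies it by a lower-triangular causal mask. Values below are made up.
import numpy as np

input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)     # (bs=1, seq=4), last position padded
pairwise = np.einsum("bi,bj->bij", input_mask, input_mask)  # what the BatchMatMul produces
causal = np.tril(np.ones((4, 4), dtype=np.float32))         # lower-triangular matrix
attention_mask = pairwise * causal                          # final (1, 4, 4) attention mask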
|
|
|
@@ -492,6 +498,7 @@ class VocabEmbedding(Cell): |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> model = VocabEmbedding(vocab_size=30, embedding_size=30) |
|
|
|
>>> tensor = Tensor(np.ones((20, 15)), dtype.int32) |
|
|
|
@@ -612,7 +619,17 @@ class MultiHeadAttention(Cell): |
|
|
|
_check_config(parallel_config) |
|
|
|
self.src_seq_length = src_seq_length |
|
|
|
self.tgt_seq_length = tgt_seq_length |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.batch_size = batch_size |
|
|
|
Validator.check_positive_int(num_heads, "num_heads") |
|
|
|
if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1: |
|
|
|
raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), " |
|
|
|
"but got {}".format(hidden_dropout_rate)) |
|
|
|
if attention_dropout_rate < 0 or attention_dropout_rate >= 1: |
|
|
|
raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), " |
|
|
|
"but got {}".format(attention_dropout_rate)) |
|
|
|
if hidden_size % num_heads != 0: |
|
|
|
raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}") |
|
|
|
if num_heads % parallel_config.model_parallel != 0: |
|
|
|
raise ValueError(f"The number of heads {num_heads} must be a " |
|
|
|
f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.") |
|
|
|
@@ -719,7 +736,8 @@ class MultiHeadAttention(Cell): |
|
|
|
output: Tensor, the output logits of this layer |
|
|
|
layer_present: Tensor, the feature map of current layer |
|
|
|
""" |
|
|
|
|
|
|
|
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past, |
|
|
|
value_past, batch_valid_length) |
|
|
|
query_tensor_original_shape = F.shape(query_tensor) |
|
|
|
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1])) |
|
|
|
|
|
|
|
@@ -795,7 +813,7 @@ class MultiHeadAttention(Cell): |
|
|
|
# multi head attention considering attention mask |
|
|
|
attention = self._attn(query, key, value, attention_mask) |
|
|
|
# [bs, seq_length, embedding_size] |
|
|
|
attention_merge = self.merge_heads(attention) |
|
|
|
attention_merge = self._merge_heads(attention) |
|
|
|
# Output |
|
|
|
output = self.projection(attention_merge) |
|
|
|
output = self.dropout(output) |
|
|
|
@@ -804,10 +822,14 @@ class MultiHeadAttention(Cell): |
|
|
|
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None, |
|
|
|
value_past=None, batch_valid_length=None): |
|
|
|
r"""Check inputs""" |
|
|
|
_check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3) |
|
|
|
_check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3) |
|
|
|
_check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3) |
|
|
|
_check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3) |
|
|
|
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.hidden_size]) |
|
|
|
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name, |
|
|
|
[self.batch_size, self.tgt_seq_length, self.hidden_size]) |
|
|
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, |
|
|
|
[self.batch_size, self.src_seq_length, self.tgt_seq_length]) |
|
|
|
|
|
|
|
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name) |
|
|
|
@@ -816,23 +838,8 @@ class MultiHeadAttention(Cell): |
|
|
|
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past) |
|
|
|
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past) |
|
|
|
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length) |
|
|
|
|
|
|
|
def split_heads(self, x, transpose): |
|
|
|
""" |
|
|
|
split 3d tensor to 4d and switch certain axes |
|
|
|
Inputs: |
|
|
|
x: input tensor |
|
|
|
transpose: tuple, the transpose sequence |
|
|
|
Outputs: |
|
|
|
x_transpose: the 4d output |
|
|
|
""" |
|
|
|
x_size = P.Shape()(x) |
|
|
|
new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head) |
|
|
|
x = self.reshape(x, new_x_shape) |
|
|
|
x_transpose = self.transpose(x, transpose) |
|
|
|
return x_transpose |
|
|
|
|
|
|
|
def merge_heads(self, x): |
|
|
|
return True |
|
|
|
def _merge_heads(self, x): |
|
|
|
""" |
|
|
|
convert a 4d input to a 3d output |
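
# Numpy-only sketch of the head merge described in the docstring above: the per-head layout
# (bs, num_heads, seq_length, size_per_head) is transposed and reshaped back to
# (bs, seq_length, hidden_size). Sizes below are made up for illustration.
import numpy as np

bs, num_heads, seq_length, size_per_head = 2, 4, 8, 16
x = np.zeros((bs, num_heads, seq_length, size_per_head), dtype=np.float32)
x = x.transpose(0, 2, 1, 3)                                    # (bs, seq_length, num_heads, size_per_head)
merged = x.reshape(bs, seq_length, num_heads * size_per_head)  # (bs, seq_length, hidden_size=64)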
|
|
|
|
|
|
|
@@ -1235,8 +1242,8 @@ class TransformerDecoderLayer(Cell): |
|
|
|
self.cross_attention = MultiHeadAttention(hidden_size=hidden_size, |
|
|
|
num_heads=num_heads, |
|
|
|
batch_size=batch_size, |
|
|
|
src_seq_length=src_seq_length, |
|
|
|
tgt_seq_length=tgt_seq_length, |
|
|
|
src_seq_length=tgt_seq_length, |
|
|
|
tgt_seq_length=src_seq_length, |
|
|
|
hidden_dropout_rate=hidden_dropout_rate, |
|
|
|
attention_dropout_rate=attention_dropout_rate, |
|
|
|
softmax_compute_type=softmax_compute_type,
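
# Presumed rationale for swapping the two sequence lengths above: in cross attention the
# query comes from the decoder (length tgt_seq_length) while key and value come from the
# encoder (length src_seq_length), so the score matrix is (bs, tgt_len, src_len).
# Numpy shape check with made-up sizes:
import numpy as np

bs, tgt_len, src_len, hidden = 2, 10, 20, 64
query = np.zeros((bs, tgt_len, hidden), dtype=np.float32)
key = np.zeros((bs, src_len, hidden), dtype=np.float32)
scores = query @ key.transpose(0, 2, 1)  # (2, 10, 20)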
|
|
|
@@ -1705,9 +1712,9 @@ class Transformer(Cell): |
|
|
|
r""" |
|
|
|
Transformer module including encoder and decoder. The difference from the original implementation is that this module applies
|
|
|
the residual addition before the layer normalization, and the default hidden act is `gelu`.
|
|
|
The detials can be found in `Attention is all you need<https://arxiv.org/pdf/1706.03762v5.pdf>`_. |
|
|
|
The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
|
|
|
|
|
|
|
NOTE: |
|
|
|
Note: |
|
|
|
This is an experimental interface that is subject to change and/or deletion. |
|
|
|
|
|
|
|
Args: |
|
|
|
@@ -1832,7 +1839,7 @@ class Transformer(Cell): |
|
|
|
raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder" |
|
|
|
f"layer {decoder_layers}, please use TransformerDecoder") |
|
|
|
if encoder_layers > 0 and decoder_layers > 0 and use_past is True: |
|
|
|
raise ValueError("The transformer with encoder and decoder does not support use_past.") |
|
|
|
raise ValueError("The transformer with encoder and decoder does not support use_past=True.") |
|
|
|
# The shard setting of Transformer is set within the class StackedTransformer |
|
|
|
if not lambda_func: |
|
|
|
lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers) |
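
# Hedged end-to-end sketch of the constraint above: a Transformer configured with both
# encoder and decoder layers must keep use_past=False. Import path, argument names and
# sizes are assumptions based on this patch, not a verified invocation.
from mindspore.nn.transformer import Transformer

model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=1, hidden_size=64,
                    ffn_hidden_size=256, src_seq_length=20, tgt_seq_length=10,
                    num_heads=4, use_past=False)  # use_past=True here raises ValueError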
|
|
|
|