From 62496d75f3c1ec75406883012464c13350c30050 Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Wed, 25 Aug 2021 15:19:54 +0800
Subject: [PATCH] reduce the exposed interfaces

---
 mindspore/parallel/nn/op_parallel_config.py  | 16 ++--
 mindspore/parallel/nn/transformer.py         | 85 ++++++++++---------
 tests/ut/python/nn/test_transformer.py       | 14 +++
 .../parallel/test_parallel_transformer.py    | 29 +++++++
 4 files changed, 98 insertions(+), 46 deletions(-)

diff --git a/mindspore/parallel/nn/op_parallel_config.py b/mindspore/parallel/nn/op_parallel_config.py
index 4bcc544318..9184c0122b 100644
--- a/mindspore/parallel/nn/op_parallel_config.py
+++ b/mindspore/parallel/nn/op_parallel_config.py
@@ -21,6 +21,7 @@ from mindspore import context
 import mindspore.communication.management as D
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode
+from mindspore import log as logger
 
 __all__ = [
     "OpParallelConfig"
 ]
@@ -56,8 +57,8 @@ class OpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1):
         Validator.check_positive_int(data_parallel, "data_parallel")
         Validator.check_positive_int(model_parallel, "model_parallel")
-        self._data_parallel = data_parallel
-        self._model_parallel = model_parallel
+        self.data_parallel = data_parallel
+        self.model_parallel = model_parallel
 
     @property
     def data_parallel(self):
@@ -95,8 +96,8 @@ class _PipeLineConfig(_Config):
     def __init__(self, pipeline_stage=1, micro_batch_num=1):
         Validator.check_positive_int(pipeline_stage, "pipeline_stage")
         Validator.check_positive_int(micro_batch_num, "micro_batch_num")
-        self._pipeline_stage = pipeline_stage
-        self._micro_batch_num = micro_batch_num
+        self.pipeline_stage = pipeline_stage
+        self.micro_batch_num = micro_batch_num
 
     @property
     def pipeline_stage(self):
@@ -150,9 +151,10 @@ def _check_config(config):
                          "should be less than device_num {device_num}")
 
     # the config optimizer_shard is same with context.optimizer_shard
-    if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
-        raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
-                         f"optimizer_shard {config.optimizer_shard} in the config")
+    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
+        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
+                       f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
+                       f"optimizer_shard to make them consistent.")
 
     # pipeline_stage <= micro_batch_num
     if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
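Because the constructors above now assign through the public properties, the same Validator checks run both at construction time and on later reassignment, and a mismatch between the config and the auto-parallel context is downgraded from an error to a warning. The same pattern is applied to TransformerOpParallelConfig further down, whose optimizer_shard setter now flips the context's enable_parallel_optimizer flag. A minimal sketch of the resulting behaviour (not part of the patch; it assumes OpParallelConfig is re-exported from mindspore.parallel.nn):

    from mindspore.parallel.nn import OpParallelConfig

    cfg = OpParallelConfig(data_parallel=4, model_parallel=2)
    cfg.data_parallel = 8      # routed through the property setter, so it is re-validated
    # OpParallelConfig(data_parallel=0) raises ValueError from Validator.check_positive_int
    # A config whose optimizer_shard disagrees with the auto-parallel context now only
    # triggers logger.warning (see the hunk above) instead of raising.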
diff --git a/mindspore/parallel/nn/transformer.py b/mindspore/parallel/nn/transformer.py
index 6a11d621a5..52bd69dd20 100644
--- a/mindspore/parallel/nn/transformer.py
+++ b/mindspore/parallel/nn/transformer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-NOTE:
+Note:
     Transformer Networks. This is an experimental interface that is subject to change and/or deletion.
 """
 import math
@@ -50,7 +50,7 @@ __all__ = [
 @constexpr
 def _check_input_shape(input_shape, param_name, func_name, target_len):
     if len(input_shape) != target_len:
-        raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
+        raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
     return True
 
 
@@ -107,7 +107,7 @@ class EmbeddingOpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1, vocab_emb_dp=True):
         self._dp_mp_config = OpParallelConfig(data_parallel=data_parallel, model_parallel=model_parallel)
         Validator.check_bool(vocab_emb_dp, "vocab_emb_dp")
-        self._vocab_emb_dp = vocab_emb_dp
+        self.vocab_emb_dp = vocab_emb_dp
 
     @property
     def data_parallel(self):
@@ -180,15 +180,12 @@ class TransformerOpParallelConfig(_Config):
     def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
                  optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
-        Validator.check_bool(recompute, "recompute")
-        Validator.check_bool(optimizer_shard, "optimizer_shard")
-        Validator.check_positive_int(gradient_aggregation_group, "gradient_aggregation_group")
+        self.recompute = recompute
+        self.optimizer_shard = optimizer_shard
+        self.gradient_aggregation_group = gradient_aggregation_group
         self._embed_dp_mp_config = EmbeddingOpParallelConfig(data_parallel=data_parallel,
                                                              model_parallel=model_parallel, vocab_emb_dp=vocab_emb_dp)
         self._pp_config = _PipeLineConfig(pipeline_stage=pipeline_stage, micro_batch_num=micro_batch_num)
-        self._recompute = recompute
-        self._optimizer_shard = optimizer_shard
-        self._gradient_aggregation_group = gradient_aggregation_group
 
     @property
     def recompute(self):
@@ -256,7 +253,7 @@ class TransformerOpParallelConfig(_Config):
     def optimizer_shard(self, value):
         Validator.check_bool(value, "optimizer_shard")
         self._optimizer_shard = value
-        context.set_auto_parallel_context(optimizer_shard=value)
+        context.set_auto_parallel_context(enable_parallel_optimizer=value)
 
     @property
     def embedding_dp_mp_config(self):
@@ -322,6 +319,8 @@ class FeedForward(Cell):
     Raises:
         ValueError: `hidden_act` is not a string.
         ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+        ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
+        ValueError: `hidden_size` is not a multiple of the model parallel way.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -349,6 +348,13 @@ class FeedForward(Cell):
             f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         dp = parallel_config.data_parallel
         mp = parallel_config.model_parallel
+        if ffn_hidden_size % mp != 0:
+            raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
+        if hidden_size % mp != 0:
+            raise ValueError(f"hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
+        if dropout_rate < 0 or dropout_rate >= 1:
+            raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(dropout_rate))
         input_size = hidden_size
         output_size = ffn_hidden_size
         # Project to ffn_hidden_size
@@ -428,7 +434,7 @@ class AttentionMask(Cell):
             raise ValueError(
                 f"The parallel_config should be a OpParallelConfig type, but found {type(parallel_config)}")
         self.seq_length = seq_length
-        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1),))
+        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
         self.reshape = P.Reshape()
         self.mul = P.BatchMatMul().shard(
             ((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
@@ -492,6 +498,7 @@ class VocabEmbedding(Cell):
 
     Supported Platforms:
         ``Ascend`` ``GPU``
+
     Examples:
         >>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
         >>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
@@ -612,7 +619,17 @@ class MultiHeadAttention(Cell):
         _check_config(parallel_config)
         self.src_seq_length = src_seq_length
         self.tgt_seq_length = tgt_seq_length
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size
         Validator.check_positive_int(num_heads, "num_heads")
+        if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
+            raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(hidden_dropout_rate))
+        if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
+            raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
+                             "but got {}".format(attention_dropout_rate))
+        if hidden_size % num_heads != 0:
+            raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
         if num_heads % parallel_config.model_parallel != 0:
             raise ValueError(f"The number of heads {num_heads} must be a "
                              f"multiple of parallel_config.model_parallel {parallel_config.model_parallel}.")
@@ -719,7 +736,8 @@ class MultiHeadAttention(Cell):
             output: Tensor, the output logits of this layer
             layer_present: Tensor, the feature map of current layer
         """
-
+        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
+                           value_past, batch_valid_length)
         query_tensor_original_shape = F.shape(query_tensor)
         query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
@@ -795,7 +813,7 @@ class MultiHeadAttention(Cell):
         # multi head attention considering attention mask
         attention = self._attn(query, key, value, attention_mask)
         # [bs, seq_length, embedding_size]
-        attention_merge = self.merge_heads(attention)
+        attention_merge = self._merge_heads(attention)
         # Output
         output = self.projection(attention_merge)
         output = self.dropout(output)
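The _check_inputs hunk below tightens the old rank-only checks into exact shape checks against the configured batch_size and sequence lengths. A usage sketch that satisfies them (not part of the patch; the tensor values, the direct cell call and the mindspore.parallel.nn import path are illustrative):

    import numpy as np
    from mindspore import Tensor, dtype
    from mindspore.parallel.nn import MultiHeadAttention

    model = MultiHeadAttention(hidden_size=16, num_heads=4, batch_size=2,
                               src_seq_length=20, tgt_seq_length=20)
    query = Tensor(np.ones((2, 20, 16)), dtype.float32)  # [batch_size, src_seq_length, hidden_size]
    key = Tensor(np.ones((2, 20, 16)), dtype.float16)    # [batch_size, tgt_seq_length, hidden_size]
    value = Tensor(np.ones((2, 20, 16)), dtype.float16)  # [batch_size, tgt_seq_length, hidden_size]
    mask = Tensor(np.ones((2, 20, 20)), dtype.float16)   # [batch_size, src_seq_length, tgt_seq_length]
    output, layer_present = model(query, key, value, mask)
    # A batch of 3 with batch_size=2 now fails the shape check (see the new unit test below).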
@@ -804,10 +822,14 @@ class MultiHeadAttention(Cell):
     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                       value_past=None, batch_valid_length=None):
         r"""Check inputs"""
-        _check_input_shape(F.shape(query_tensor), "query_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(key_tensor), "key_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(value_tensor), "value_tensor", self.cls_name, 3)
-        _check_input_shape(F.shape(attention_mask), "attention_mask", self.cls_name, 3)
+        _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
+                           [self.batch_size, self.tgt_seq_length, self.hidden_size])
+        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                           [self.batch_size, self.src_seq_length, self.tgt_seq_length])
 
         _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
@@ -816,23 +838,8 @@ class MultiHeadAttention(Cell):
         _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
         _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
         _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
-
-    def split_heads(self, x, transpose):
-        """
-        split 3d tensor to 4d and switch certain axes
-        Inputs:
-            x: input tensor
-            transpose: tuple, the transpose sequence
-        Outputs:
-            x_transpose: the 4d output
-        """
-        x_size = P.Shape()(x)
-        new_x_shape = x_size[:-1] + (self.n_head, self.size_per_head)
-        x = self.reshape(x, new_x_shape)
-        x_transpose = self.transpose(x, transpose)
-        return x_transpose
-
-    def merge_heads(self, x):
+        return True
+
+    def _merge_heads(self, x):
         """
         convert a 4d input to a 3d output
@@ -1235,8 +1242,8 @@ class TransformerDecoderLayer(Cell):
         self.cross_attention = MultiHeadAttention(hidden_size=hidden_size,
                                                   num_heads=num_heads,
                                                   batch_size=batch_size,
-                                                  src_seq_length=src_seq_length,
-                                                  tgt_seq_length=tgt_seq_length,
+                                                  src_seq_length=tgt_seq_length,
+                                                  tgt_seq_length=src_seq_length,
                                                   hidden_dropout_rate=hidden_dropout_rate,
                                                   attention_dropout_rate=attention_dropout_rate,
                                                   softmax_comptue_type=softmax_comptue_type,
@@ -1705,9 +1712,9 @@ class Transformer(Cell):
     r"""
     Transformer module including encoder and decoder. The difference with the original implements is the module use
     the residual addition before the layernormalization. And the default hidden act is `gelu`.
-    The detials can be found in `Attention is all you need`_.
+    The details can be found in `Attention is all you need <https://arxiv.org/abs/1706.03762>`_.
 
-    NOTE:
+    Note:
         This is an experimental interface that is subject to change and/or deletion.
 
     Args:
@@ -1832,7 +1839,7 @@ class Transformer(Cell):
             raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
                              f"layer {decoder_layers}, please use TransformerDecoder")
         if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
-            raise ValueError("The transformer with encoder and decoder does not support use_past.")
+            raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
         # The shard setting of Transformer is set within the class StackedTransformer
         if not lambda_func:
             lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)
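The swapped lengths in TransformerDecoderLayer.cross_attention follow from those shape checks: inside MultiHeadAttention the query length is named src_seq_length and the key/value length tgt_seq_length, while in decoder cross-attention the query comes from the decoder states and the key/value come from the encoder output. A sketch of that asymmetric case (not part of the patch; values and the import path are illustrative):

    import numpy as np
    from mindspore import Tensor, dtype
    from mindspore.parallel.nn import MultiHeadAttention

    # Decoder states are 10 tokens long, encoder output is 20 tokens long.
    cross_attn = MultiHeadAttention(hidden_size=16, num_heads=4, batch_size=2,
                                    src_seq_length=10,   # query (decoder) length
                                    tgt_seq_length=20)   # key/value (encoder) length
    decoder_states = Tensor(np.ones((2, 10, 16)), dtype.float32)
    encoder_output = Tensor(np.ones((2, 20, 16)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
    output, _ = cross_attn(decoder_states, encoder_output, encoder_output, memory_mask)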
diff --git a/tests/ut/python/nn/test_transformer.py b/tests/ut/python/nn/test_transformer.py
index 8731a5ea7b..2c7fb2369a 100644
--- a/tests/ut/python/nn/test_transformer.py
+++ b/tests/ut/python/nn/test_transformer.py
@@ -203,6 +203,20 @@ def test_multihead_attention():
     _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
 
 
+def test_multihead_attention_wrong_batch():
+    model = MultiHeadAttention(hidden_size=15,
+                               src_seq_length=20,
+                               tgt_seq_length=20,
+                               batch_size=2,
+                               num_heads=3)
+    from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
+
+    with pytest.raises(ValueError):
+        _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
 def test_feedforward_layer():
     model = FeedForward(hidden_size=15,
                         ffn_hidden_size=30,
diff --git a/tests/ut/python/parallel/test_parallel_transformer.py b/tests/ut/python/parallel/test_parallel_transformer.py
index bc3c97ef50..8c9db5540f 100644
--- a/tests/ut/python/parallel/test_parallel_transformer.py
+++ b/tests/ut/python/parallel/test_parallel_transformer.py
@@ -212,6 +212,35 @@ def test_pipeline_single_transformer():
     model.train(1, dataset, dataset_sink_mode=False)
 
 
+def test_transformer_wrong_head():
+    set_auto_parallel_context(device_num=32,
+                              full_batch=True,
+                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
+                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+    error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=64,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+    with pytest.raises(ValueError):
+        net = Transformer(batch_size=4,
+                          src_seq_length=20,
+                          tgt_seq_length=10,
+                          encoder_layers=2,
+                          decoder_layers=2,
+                          hidden_size=63,
+                          num_heads=7,
+                          ffn_hidden_size=64,
+                          parallel_config=error_test_config)
+        del net
+
+
 def test_encoder():
     class NetWithLoss(nn.Cell):
         def __init__(self, network):
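For reference, both configurations rejected in test_transformer_wrong_head trip the divisibility checks introduced above; the arithmetic in short (sketch, not part of the patch):

    # First case: hidden_size=64, num_heads=7, model_parallel=8
    assert 7 % 8 != 0    # num_heads not divisible by model_parallel -> MultiHeadAttention raises
    assert 64 % 7 != 0   # hidden_size not a multiple of num_heads -> also rejected
    # Second case: hidden_size=63 additionally violates the new FeedForward check
    assert 63 % 8 != 0   # hidden_size not a multiple of the model parallel way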