From 02c913a0fee0bbb0a6ade9086c9c142f508ab3e0 Mon Sep 17 00:00:00 2001 From: "suluyan.sly" Date: Tue, 11 Oct 2022 17:26:43 +0800 Subject: [PATCH] [to #42322933] add plug doc string Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10337105 --- .../models/nlp/plug/configuration_plug.py | 165 +++++++----- .../models/nlp/plug/distributed_plug.py | 44 +++- modelscope/models/nlp/plug/modeling_plug.py | 243 ++++++++---------- 3 files changed, 240 insertions(+), 212 deletions(-) diff --git a/modelscope/models/nlp/plug/configuration_plug.py b/modelscope/models/nlp/plug/configuration_plug.py index 64807392..c3a526a9 100644 --- a/modelscope/models/nlp/plug/configuration_plug.py +++ b/modelscope/models/nlp/plug/configuration_plug.py @@ -40,8 +40,6 @@ class PlugNLUConfig(PretrainedConfig): max_position_embeddings=2048, type_vocab_size=3, initializer_range=0.00707, - deep_init=False, - deepspeed=False, lr_decay_style='linear', weight_decay=1e-2, clip_grad=1.0, @@ -53,20 +51,7 @@ class PlugNLUConfig(PretrainedConfig): fp32_tokentypes=False, layernorm_epsilon=1e-5, dec_hidden_layers=6, - pruning_method=None, - pruning_mask_init='constant', - pruning_mask_scale=0.0, - pruning_initial_threshold=1.0, - pruning_final_threshold=0.01, - pruning_initial_warmup=1, - pruning_final_warmup=20, - pruning_module='decoder', - pruning_decay_step=50, - pruning_decay_type='exp', - ft_module=None, attn_separate=False, - LR_weight_rank=8, - LR_mask_rank=8, **kwargs): super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) @@ -82,8 +67,6 @@ class PlugNLUConfig(PretrainedConfig): self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.deep_init = deep_init - self.deepspeed = deepspeed self.lr_decay_style = lr_decay_style self.weight_decay = weight_decay self.clip_grad = clip_grad @@ -95,20 +78,7 @@ class PlugNLUConfig(PretrainedConfig): self.layernorm_epsilon = layernorm_epsilon self.fp32_tokentypes = fp32_tokentypes self.dec_hidden_layers = dec_hidden_layers - self.pruning_method = pruning_method - self.pruning_mask_init = pruning_mask_init - self.pruning_mask_scale = pruning_mask_scale - self.pruning_module = pruning_module - self.pruning_initial_threshold = pruning_initial_threshold - self.pruning_final_threshold = pruning_final_threshold - self.pruning_initial_warmup = pruning_initial_warmup - self.pruning_final_warmup = pruning_final_warmup - self.pruning_decay_step = pruning_decay_step - self.pruning_decay_type = pruning_decay_type - self.ft_module = ft_module self.attn_separate = attn_separate - self.LR_weight_rank = LR_weight_rank - self.LR_mask_rank = LR_mask_rank @classmethod def from_dict(cls, json_object): @@ -148,47 +118,115 @@ class PlugNLUConfig(PretrainedConfig): class PlugNLGConfig(PlugNLUConfig): + """ + This is the configuration class to store the configuration of a [`PlugModel`]. It is used to instantiate a + PLUG understanding model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the PLUG + [PLUG](https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary) architecture. + + Configuration objects inherit from [`PlugNLUConfig`] and can be used to control the model outputs. Read the + documentation from [`PlugNLUConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 21504): + Padded vocabulary size of the PLUG model for vocab tensor parallel. 
Defines the number of different tokens
+            that can be represented by the `input_tokens` passed when calling [`PlugModel`].
+        original_vocab_size (`int`, *optional*, defaults to 21128):
+            True (unpadded) vocabulary size of the PLUG model. Defines the number of different tokens that can be
+            represented.
+        hidden_size (`int`, *optional*, defaults to 8192):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        dec_hidden_layers (`int`, *optional*, defaults to 6):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 128):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 32768):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512, 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 3):
+            The vocabulary size of the `token_type_ids` passed when calling [`PlugModel`].
+        initializer_range (`float`, *optional*, defaults to 0.00707):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        lr_decay_style (`str`, *optional*, defaults to 'linear'):
+            The decay style of the learning rate during fine-tuning. If string, `"linear"`, `"cosine"`,
+            `"exponential"`, `"constant"` and `"None"` are supported.
+        weight_decay (`float`, *optional*, defaults to 1e-2):
+            Decoupled weight decay to apply.
+        clip_grad (`float`, *optional*, defaults to 1.0):
+            Maximum gradient norm for gradient clipping.
+        warmup (`float`, *optional*, defaults to 0.01):
+            Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+        pre_ln (`boolean`, *optional*, defaults to `True`):
+            Whether or not to apply LayerNorm to the input instead of the output of each Transformer block (pre-LN).
+        fp16 (`boolean`, *optional*, defaults to `True`):
+            Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
+        fp32_layernorm (`boolean`, *optional*, defaults to `True`):
+            Whether to keep the LayerNorm computation in fp32 32-bit precision when `fp16` is set to `True`.
+        fp32_embedding (`boolean`, *optional*, defaults to `False`):
+            Whether to keep the embeddings in fp32 32-bit precision when `fp16` is set to `True`.
+        fp32_tokentypes (`boolean`, *optional*, defaults to `False`):
+            Whether to keep the token type embeddings in fp32 32-bit precision when `fp16` is set to `True`.
+        layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
+            The epsilon to use in the layer normalization layers.
+        attn_separate (`boolean`, *optional*, defaults to `False`):
+            Whether or not to split the fused query-key-value projection in the attention into separate query, key
+            and value projections.
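+
+    Individual arguments can also be overridden when constructing the configuration, which can be handy for
+    debugging on a single GPU. The sizes below are illustrative only and do not correspond to any released PLUG
+    checkpoint:
+
+    ```python
+    >>> from modelscope.models.nlp.plug import PlugNLGConfig
+
+    >>> # A deliberately small configuration for smoke tests (hypothetical sizes)
+    >>> small_config = PlugNLGConfig(
+    ...     hidden_size=1024,
+    ...     num_hidden_layers=4,
+    ...     dec_hidden_layers=2,
+    ...     num_attention_heads=16,
+    ...     intermediate_size=4096)
+    ```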
+ + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = 'plugNLG' def __init__(self, vocab_size=21504, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, + original_vocab_size=21128, + hidden_size=8192, + num_hidden_layers=24, + dec_hidden_layers=6, + num_attention_heads=128, + intermediate_size=32768, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, + max_position_embeddings=2048, + type_vocab_size=3, initializer_range=0.00707, - deep_init=False, - deepspeed=False, lr_decay_style='linear', weight_decay=1e-2, clip_grad=1.0, warmup=0.01, - pre_ln=False, - fp16=False, - fp32_layernorm=False, + pre_ln=True, + fp16=True, + fp32_layernorm=True, fp32_embedding=False, fp32_tokentypes=False, - layernorm_epsilon=1e-12, - dec_hidden_layers=6, - pruning_method=None, - pruning_mask_init='constant', - pruning_mask_scale=0.0, - pruning_initial_threshold=1.0, - pruning_final_threshold=0.01, - pruning_initial_warmup=1, - pruning_final_warmup=20, - pruning_module='decoder', - pruning_decay_step=50, - pruning_decay_type='exp', - ft_module=None, + layernorm_epsilon=1e-5, attn_separate=False, - LR_weight_rank=8, - LR_mask_rank=8, **kwargs): super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) @@ -203,8 +241,6 @@ class PlugNLGConfig(PlugNLUConfig): self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range - self.deep_init = deep_init - self.deepspeed = deepspeed self.lr_decay_style = lr_decay_style self.weight_decay = weight_decay self.clip_grad = clip_grad @@ -216,17 +252,4 @@ class PlugNLGConfig(PlugNLUConfig): self.layernorm_epsilon = layernorm_epsilon self.fp32_tokentypes = fp32_tokentypes self.dec_hidden_layers = dec_hidden_layers - self.pruning_method = pruning_method - self.pruning_mask_init = pruning_mask_init - self.pruning_mask_scale = pruning_mask_scale - self.pruning_module = pruning_module - self.pruning_initial_threshold = pruning_initial_threshold - self.pruning_final_threshold = pruning_final_threshold - self.pruning_initial_warmup = pruning_initial_warmup - self.pruning_final_warmup = pruning_final_warmup - self.pruning_decay_step = pruning_decay_step - self.pruning_decay_type = pruning_decay_type - self.ft_module = ft_module self.attn_separate = attn_separate - self.LR_weight_rank = LR_weight_rank - self.LR_mask_rank = LR_mask_rank diff --git a/modelscope/models/nlp/plug/distributed_plug.py b/modelscope/models/nlp/plug/distributed_plug.py index 2992f595..06009ba1 100644 --- a/modelscope/models/nlp/plug/distributed_plug.py +++ b/modelscope/models/nlp/plug/distributed_plug.py @@ -20,6 +20,48 @@ logger = get_logger(__name__) class DistributedPlug(TorchModel): + """ + The wapper class of PLUG Model to initialize parallel environment, load model weights, generate sentences. 
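+
+    A minimal usage sketch is shown below (the model path is a placeholder; in practice one process is launched
+    per GPU, typically via `torch.multiprocessing`, so that every rank constructs its own `DistributedPlug`
+    instance):
+
+    ```python
+    >>> from modelscope.models.nlp.plug.distributed_plug import DistributedPlug
+
+    >>> # Sketch for rank 0 of an 8-GPU tensor-parallel setup; the remaining ranks
+    >>> # are started the same way with rank=1..7.
+    >>> model = DistributedPlug(
+    ...     model_dir='/path/to/nlp_plug_text-generation_27B',
+    ...     rank=0,
+    ...     world_size=8,
+    ...     model_parallel_size=8,
+    ...     master_ip='127.0.0.1',
+    ...     master_port='29500')
+    ```
+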
+ Parameters: + model_dir (`str`, *required*): + Path to model damo/nlp_plug_text-generation_27B. + The model structure in model_dir should be like this: + model_dir + |_ config.json + |_ configuration.json + |_ ds_zero-offload_10B_config.json + |_ vocab.txt + |_ model <-- an empty directory + + Model binaries shall be downloaded separately to populate the model directory, so that + the model directory would contain the following binaries: + |_ model + |_ mp_rank_00_model_states.pt + |_ mp_rank_01_model_states.pt + |_ mp_rank_02_model_states.pt + |_ mp_rank_03_model_states.pt + |_ mp_rank_04_model_states.pt + |_ mp_rank_05_model_states.pt + |_ mp_rank_06_model_states.pt + |_ mp_rank_07_model_states.pt + rank (`int`, *required*): + Used to identify different GPUs in a tensor parallel environment. eg. The rank of GPU #0 is 0, and the + model file `mp_rank_00_model_states.pt` will be loaded on this GPU. + world_size (`int`, *required*, defaults to 8): + The parallel size in total. + model_parallel_size (`int`, *required*, defaults to 8): + The parallel size of model(tensor parallel). + master_ip (`str`, *required*): + The master IP, can usually be set to `"127.0.0.1"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + master_port (`str`, *required*): + The master port, can usually be set to `"29500"`, used as part of + [`~torch.distributed.init_process_group`] method parameter `init_method`. + `init_method` = `"tcp://{master_ip}:{master_port}"` + seed (`int`, *optional*, defaults to 42): + Random seed to control sampling. + """ def __init__(self, model_dir, rank, **kwargs): super().__init__(model_dir, **kwargs) @@ -29,7 +71,7 @@ class DistributedPlug(TorchModel): initialize_distributed(rank, mpu, kwargs['world_size'], kwargs['model_parallel_size'], kwargs['master_ip'], kwargs['master_port']) - seed = 0 if 'seed' not in kwargs else kwargs['seed'] + seed = 42 if 'seed' not in kwargs else kwargs['seed'] set_random_seed_mpu(seed) self.iteration = 0 self.dist_model = self.initialize_model(path_load_tag='model') diff --git a/modelscope/models/nlp/plug/modeling_plug.py b/modelscope/models/nlp/plug/modeling_plug.py index 9d2bb14f..df00006b 100644 --- a/modelscope/models/nlp/plug/modeling_plug.py +++ b/modelscope/models/nlp/plug/modeling_plug.py @@ -152,15 +152,7 @@ class BertSelfOutput(nn.Module): bias=True, input_is_parallel=True, stride=1, - init_method=init_method, - pruning_method=config.pruning_method if config.pruning_module in [ - 'all', 'encoder', 'encoder_self', 'encoder_selfvo', - 'encoder_selfo' - ] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + init_method=init_method) self.fp32_layernorm = config.fp32_layernorm if not config.pre_ln: self.LayerNorm = BertLayerNorm( @@ -173,12 +165,8 @@ class BertSelfOutput(nn.Module): self, hidden_states, input_tensor, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) ln_input = hidden_states + input_tensor if self.LayerNorm is not None: @@ -210,20 +198,13 @@ class BertAttention(nn.Module): output_parallel=True, init_method=normal_init_method( mean=0.0, std=config.initializer_range), - separate=config.attn_separate, - pruning_method=config.pruning_method, - 
pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - pruning_module=config.pruning_module, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + separate=config.attn_separate) self.output = BertSelfOutput(config) def forward( self, input_tensor, attention_mask, - pruning_threshold=None, ): if self.LayerNorm is not None: ln_input = input_tensor @@ -236,20 +217,16 @@ class BertAttention(nn.Module): self_output = self.self( ln_output, attention_mask, - pruning_threshold=pruning_threshold, ) else: self_output = self.self( input_tensor, attention_mask, - pruning_threshold=pruning_threshold, ) - output_pruning_threshold = pruning_threshold attention_output = self.output( self_output, input_tensor, - pruning_threshold=output_pruning_threshold, ) return attention_output @@ -265,25 +242,15 @@ class BertIntermediate(nn.Module): gather_output=False, stride=1, init_method=normal_init_method( - mean=0.0, std=config.initializer_range), - pruning_method=config.pruning_method if config.pruning_module - in ['all', 'encoder', 'encoder_ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + mean=0.0, std=config.initializer_range)) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act def forward( self, hidden_states, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states @@ -306,13 +273,7 @@ class BertOutput(nn.Module): bias=True, input_is_parallel=True, stride=1, - init_method=init_method, - pruning_method=config.pruning_method if config.pruning_module - in ['all', 'encoder', 'encoder_ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank) + init_method=init_method) self.fp32_layernorm = config.fp32_layernorm if not config.pre_ln: self.LayerNorm = BertLayerNorm( @@ -325,12 +286,8 @@ class BertOutput(nn.Module): self, hidden_states, input_tensor, - pruning_threshold=None, ): - hidden_states = self.dense( - hidden_states, - pruning_threshold=pruning_threshold, - ) + hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) ln_input = hidden_states + input_tensor if self.LayerNorm is not None: @@ -359,14 +316,8 @@ class BertLayer(nn.Module): else: self.LayerNorm = None - def forward( - self, - hidden_states, - attention_mask, - pruning_threshold=None, - ): - attention_output = self.attention( - hidden_states, attention_mask, pruning_threshold=pruning_threshold) + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) if self.LayerNorm is not None: ln_input = attention_output previous_type = attention_output.type() @@ -375,15 +326,10 @@ class BertLayer(nn.Module): ln_output = self.LayerNorm(ln_input) if self.fp32_layernorm: ln_output = ln_output.type(previous_type) - intermediate_output = self.intermediate( - ln_output, pruning_threshold=pruning_threshold) + intermediate_output = self.intermediate(ln_output) else: - intermediate_output = self.intermediate( - attention_output, pruning_threshold=pruning_threshold) - layer_output = self.output( - 
intermediate_output, - attention_output, - pruning_threshold=pruning_threshold) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) return layer_output @@ -407,7 +353,6 @@ class BertEncoder(nn.Module): output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, - pruning_threshold=None, ): all_encoder_layers = [] @@ -417,8 +362,7 @@ class BertEncoder(nn.Module): layers = self.layer[start:end] x_ = inputs[0] for layer in layers: - x_ = layer( - x_, inputs[1], pruning_threshold=pruning_threshold) + x_ = layer(x_, inputs[1]) return x_ return custom_forward @@ -654,7 +598,6 @@ class BertModel(PreTrainedBertModel): output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, - pruning_threshold=None, ): if attention_mask is None: attention_mask = torch.ones_like(input_ids) @@ -683,8 +626,7 @@ class BertModel(PreTrainedBertModel): extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers, checkpoint_activations=checkpoint_activations, - detach_index=detach_index, - pruning_threshold=pruning_threshold) + detach_index=detach_index) sequence_output = encoded_layers[-1] for p in self.pooler.parameters(): if p is None: @@ -709,18 +651,6 @@ class DecodeLayer(nn.Module): std=config.initializer_range, num_layers=config.num_hidden_layers) - self_pruning_method = config.pruning_method - cross_pruning_method = config.pruning_method - ffn_pruning_method = config.pruning_method - - if config.ft_module is not None: - if 'decoder_self' in config.ft_module: - self_pruning_method = 'finetune' - if 'decoder_cross' in config.ft_module: - cross_pruning_method = 'finetune' - if 'decoder_ffn' in config.ft_module: - ffn_pruning_method = 'finetune' - self.attention = mpu.GPT2ParallelSelfAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, @@ -728,13 +658,6 @@ class DecodeLayer(nn.Module): output_dropout_prob=config.hidden_dropout_prob, init_method=init_method, output_layer_init_method=output_layer_init_method, - pruning_method=self_pruning_method if config.pruning_module in [ - 'all', 'decoder', 'decoder_self', 'decoder_self+ffn' - ] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.cross_attention = mpu.PalmParallelCrossAttention( @@ -745,12 +668,6 @@ class DecodeLayer(nn.Module): init_method=init_method, attn_separate=False, output_layer_init_method=output_layer_init_method, - pruning_method=cross_pruning_method, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - pruning_module=config.pruning_module, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.input_layernorm = BertLayerNorm( @@ -765,12 +682,6 @@ class DecodeLayer(nn.Module): config.intermediate_size, gather_output=False, init_method=init_method, - pruning_method=ffn_pruning_method if config.pruning_module - in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) else config.hidden_act @@ -779,12 +690,6 @@ class DecodeLayer(nn.Module): config.hidden_size, input_is_parallel=True, 
init_method=output_layer_init_method, - pruning_method=ffn_pruning_method if config.pruning_module - in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None, - pruning_mask_init=config.pruning_mask_init, - pruning_mask_scale=config.pruning_mask_scale, - LR_weight_rank=config.LR_weight_rank, - LR_mask_rank=config.LR_mask_rank, ) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) @@ -804,8 +709,7 @@ class DecodeLayer(nn.Module): enc_hidden_states, enc_attn_mask, dec_attn_mask, - is_infer=False, - pruning_threshold=None): + is_infer=False): residual = hidden_states previous_type = hidden_states.type() hidden_states = self.input_layernorm( @@ -813,10 +717,7 @@ class DecodeLayer(nn.Module): if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) hidden_states = self.attention( - hidden_states, - dec_attn_mask, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + hidden_states, dec_attn_mask, is_infer=is_infer) hidden_states = residual + hidden_states @@ -825,23 +726,18 @@ class DecodeLayer(nn.Module): self.type_converter(hidden_states)) if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) - hidden_states = self.cross_attention( - hidden_states, - enc_hidden_states, - enc_attn_mask, - pruning_threshold=pruning_threshold) + hidden_states = self.cross_attention(hidden_states, enc_hidden_states, + enc_attn_mask) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_cross_attention_layernorm( self.type_converter(hidden_states)) if self.fp32_layernorm: hidden_states = hidden_states.type(previous_type) - hidden_states = self.intermediate( - hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.intermediate(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.output( - hidden_states, pruning_threshold=pruning_threshold) + hidden_states = self.output(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = residual + hidden_states @@ -866,8 +762,7 @@ class BertDecoder(nn.Module): dec_attn_mask, checkpoint_activations=False, output_all_encoded_layers=False, - is_infer=False, - pruning_threshold=None): + is_infer=False): def custom(start, end): @@ -880,8 +775,7 @@ class BertDecoder(nn.Module): inputs[1], inputs[2], dec_attn_mask * 1, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) return x_ return custom_forward @@ -904,8 +798,7 @@ class BertDecoder(nn.Module): enc_hidden_states, enc_attn_mask, dec_attn_mask, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) previous_type = hidden_states.type() if self.fp32_layernorm: @@ -932,8 +825,7 @@ class DecodeModel(PreTrainedBertModel): enc_attn_mask=None, dec_attn_mask=None, checkpoint_activations=False, - is_infer=False, - pruning_threshold=None): + is_infer=False): extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = extended_attention_mask.to( dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility @@ -946,8 +838,7 @@ class DecodeModel(PreTrainedBertModel): extended_attention_mask, dec_attn_mask, checkpoint_activations=False, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) return sequence_output[-1] @@ -972,16 +863,14 @@ class PalmForPreTraining(PreTrainedBertModel): checkpoint_activations=False, is_infer=False, sequence_output=None, - parallel_output=True, - pruning_threshold=None): + parallel_output=True): if sequence_output is None: 
sequence_output, pooled_output = self.bert( input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, - checkpoint_activations=checkpoint_activations, - pruning_threshold=pruning_threshold) + checkpoint_activations=checkpoint_activations) prediction_scores, seq_relationship_score = self.cls( sequence_output, pooled_output) else: @@ -998,8 +887,7 @@ class PalmForPreTraining(PreTrainedBertModel): attention_mask, decode_attention_mask, checkpoint_activations=checkpoint_activations, - is_infer=is_infer, - pruning_threshold=pruning_threshold) + is_infer=is_infer) transformer_output_parallel = mpu.copy_to_model_parallel_region( decode_output) @@ -1017,6 +905,29 @@ class PalmForPreTraining(PreTrainedBertModel): class PlugModel(torch.nn.Module): + """ + The bare Plug Model transformer outputting raw hidden-states without any specific head on top. + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights. + Example: + + ```python + >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given + >>> # here only initializes a slice of the model on a single GPU. + >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model. + >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel + + >>> # Initializing a Plug configuration + >>> configuration = PlugNLGConfig() + + >>> # Initializing a model from the configuration + >>> model = PlugModel(configuration) + """ def __init__(self, config): super(PlugModel, self).__init__() @@ -1034,6 +945,58 @@ class PlugModel(torch.nn.Module): is_infer=False, sequence_output=None, parallel_output=True): + """ + Parameters: + input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`): + `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. + Indices can be obtained using transformers [`BertTokenizer`]. See + [`TextGenerationPreprocessor.__call__`] for details. + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to + None): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None): + Target token ids(labels) for language modeling. Note that the labels **are shifted** inside the model, + i.e. 
you can set `target_tokens = input_tokens`. Indices are selected in
+                `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are ignored (masked); the loss is only
+                computed for labels in `[0, ..., config.vocab_size]`.
+
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
+                Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                `[0, config.max_position_embeddings - 1]`.
+
+            decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults
+                to None):
+                Mask to avoid performing attention on padding token indices of the target tokens. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            checkpoint_activations (`boolean`, *optional*, defaults to `False`):
+                Whether or not to use gradient (activation) checkpointing for this forward pass.
+            is_infer (`boolean`, *optional*, defaults to `False`):
+                Whether or not to run the forward pass in inference (generation) mode.
+            sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
+                defaults to None):
+                Also known as last_hidden_state: the sequence of hidden-states at the output of the last layer of the
+                model. A single forward() call can generate one token; generating the current token requires the
+                sequence_output returned by the forward() call for the previous token.
+            parallel_output (`boolean`, *optional*, defaults to `True`):
+                Whether to return the model-parallel output as is, or gather it before returning.
+        """
         return self.model(
             input_tokens,
             token_type_ids,