@@ -152,15 +152,7 @@ class BertSelfOutput(nn.Module):
             bias=True,
             input_is_parallel=True,
             stride=1,
-            init_method=init_method,
-            pruning_method=config.pruning_method if config.pruning_module in [
-                'all', 'encoder', 'encoder_self', 'encoder_selfvo',
-                'encoder_selfo'
-            ] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank)
+            init_method=init_method)
         self.fp32_layernorm = config.fp32_layernorm
         if not config.pre_ln:
             self.LayerNorm = BertLayerNorm(
@@ -173,12 +165,8 @@ class BertSelfOutput(nn.Module):
         self,
         hidden_states,
         input_tensor,
-        pruning_threshold=None,
     ):
-        hidden_states = self.dense(
-            hidden_states,
-            pruning_threshold=pruning_threshold,
-        )
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         ln_input = hidden_states + input_tensor
         if self.LayerNorm is not None:
@@ -210,20 +198,13 @@ class BertAttention(nn.Module):
             output_parallel=True,
             init_method=normal_init_method(
                 mean=0.0, std=config.initializer_range),
-            separate=config.attn_separate,
-            pruning_method=config.pruning_method,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            pruning_module=config.pruning_module,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank)
+            separate=config.attn_separate)
         self.output = BertSelfOutput(config)
 
     def forward(
         self,
         input_tensor,
         attention_mask,
-        pruning_threshold=None,
     ):
         if self.LayerNorm is not None:
             ln_input = input_tensor
@@ -236,20 +217,16 @@ class BertAttention(nn.Module):
             self_output = self.self(
                 ln_output,
                 attention_mask,
-                pruning_threshold=pruning_threshold,
             )
         else:
             self_output = self.self(
                 input_tensor,
                 attention_mask,
-                pruning_threshold=pruning_threshold,
             )
-        output_pruning_threshold = pruning_threshold
 
         attention_output = self.output(
             self_output,
             input_tensor,
-            pruning_threshold=output_pruning_threshold,
         )
         return attention_output
 
@@ -265,25 +242,15 @@ class BertIntermediate(nn.Module):
             gather_output=False,
             stride=1,
             init_method=normal_init_method(
-                mean=0.0, std=config.initializer_range),
-            pruning_method=config.pruning_method if config.pruning_module
-            in ['all', 'encoder', 'encoder_ffn'] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank)
+                mean=0.0, std=config.initializer_range))
         self.intermediate_act_fn = ACT2FN[config.hidden_act] \
             if isinstance(config.hidden_act, str) else config.hidden_act
 
     def forward(
         self,
         hidden_states,
-        pruning_threshold=None,
     ):
-        hidden_states = self.dense(
-            hidden_states,
-            pruning_threshold=pruning_threshold,
-        )
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states
 
@@ -306,13 +273,7 @@ class BertOutput(nn.Module):
             bias=True,
             input_is_parallel=True,
             stride=1,
-            init_method=init_method,
-            pruning_method=config.pruning_method if config.pruning_module
-            in ['all', 'encoder', 'encoder_ffn'] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank)
+            init_method=init_method)
         self.fp32_layernorm = config.fp32_layernorm
         if not config.pre_ln:
             self.LayerNorm = BertLayerNorm(
@@ -325,12 +286,8 @@ class BertOutput(nn.Module):
         self,
         hidden_states,
         input_tensor,
-        pruning_threshold=None,
     ):
-        hidden_states = self.dense(
-            hidden_states,
-            pruning_threshold=pruning_threshold,
-        )
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         ln_input = hidden_states + input_tensor
         if self.LayerNorm is not None:
@@ -359,14 +316,8 @@ class BertLayer(nn.Module):
         else:
             self.LayerNorm = None
 
-    def forward(
-        self,
-        hidden_states,
-        attention_mask,
-        pruning_threshold=None,
-    ):
-        attention_output = self.attention(
-            hidden_states, attention_mask, pruning_threshold=pruning_threshold)
+    def forward(self, hidden_states, attention_mask):
+        attention_output = self.attention(hidden_states, attention_mask)
         if self.LayerNorm is not None:
             ln_input = attention_output
             previous_type = attention_output.type()
@@ -375,15 +326,10 @@ class BertLayer(nn.Module):
             ln_output = self.LayerNorm(ln_input)
             if self.fp32_layernorm:
                 ln_output = ln_output.type(previous_type)
-            intermediate_output = self.intermediate(
-                ln_output, pruning_threshold=pruning_threshold)
+            intermediate_output = self.intermediate(ln_output)
         else:
-            intermediate_output = self.intermediate(
-                attention_output, pruning_threshold=pruning_threshold)
-        layer_output = self.output(
-            intermediate_output,
-            attention_output,
-            pruning_threshold=pruning_threshold)
+            intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
         return layer_output
 
 
@@ -407,7 +353,6 @@ class BertEncoder(nn.Module):
         output_all_encoded_layers=True,
         checkpoint_activations=False,
         detach_index=-1,
-        pruning_threshold=None,
     ):
         all_encoder_layers = []
 
@@ -417,8 +362,7 @@ class BertEncoder(nn.Module):
                 layers = self.layer[start:end]
                 x_ = inputs[0]
                 for layer in layers:
-                    x_ = layer(
-                        x_, inputs[1], pruning_threshold=pruning_threshold)
+                    x_ = layer(x_, inputs[1])
                 return x_
 
             return custom_forward
@@ -654,7 +598,6 @@ class BertModel(PreTrainedBertModel):
         output_all_encoded_layers=True,
         checkpoint_activations=False,
         detach_index=-1,
-        pruning_threshold=None,
     ):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
@@ -683,8 +626,7 @@ class BertModel(PreTrainedBertModel):
             extended_attention_mask,
             output_all_encoded_layers=output_all_encoded_layers,
             checkpoint_activations=checkpoint_activations,
-            detach_index=detach_index,
-            pruning_threshold=pruning_threshold)
+            detach_index=detach_index)
         sequence_output = encoded_layers[-1]
         for p in self.pooler.parameters():
             if p is None:
@@ -709,18 +651,6 @@ class DecodeLayer(nn.Module):
                 std=config.initializer_range,
                 num_layers=config.num_hidden_layers)
 
-        self_pruning_method = config.pruning_method
-        cross_pruning_method = config.pruning_method
-        ffn_pruning_method = config.pruning_method
-
-        if config.ft_module is not None:
-            if 'decoder_self' in config.ft_module:
-                self_pruning_method = 'finetune'
-            if 'decoder_cross' in config.ft_module:
-                cross_pruning_method = 'finetune'
-            if 'decoder_ffn' in config.ft_module:
-                ffn_pruning_method = 'finetune'
-
         self.attention = mpu.GPT2ParallelSelfAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -728,13 +658,6 @@ class DecodeLayer(nn.Module):
             output_dropout_prob=config.hidden_dropout_prob,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
-            pruning_method=self_pruning_method if config.pruning_module in [
-                'all', 'decoder', 'decoder_self', 'decoder_self+ffn'
-            ] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank,
         )
 
         self.cross_attention = mpu.PalmParallelCrossAttention(
@@ -745,12 +668,6 @@ class DecodeLayer(nn.Module):
             init_method=init_method,
             attn_separate=False,
             output_layer_init_method=output_layer_init_method,
-            pruning_method=cross_pruning_method,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            pruning_module=config.pruning_module,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank,
         )
 
         self.input_layernorm = BertLayerNorm(
@@ -765,12 +682,6 @@ class DecodeLayer(nn.Module):
             config.intermediate_size,
             gather_output=False,
             init_method=init_method,
-            pruning_method=ffn_pruning_method if config.pruning_module
-            in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank,
         )
         self.intermediate_act_fn = ACT2FN[config.hidden_act] \
             if isinstance(config.hidden_act, str) else config.hidden_act
@@ -779,12 +690,6 @@ class DecodeLayer(nn.Module):
             config.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
-            pruning_method=ffn_pruning_method if config.pruning_module
-            in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
-            pruning_mask_init=config.pruning_mask_init,
-            pruning_mask_scale=config.pruning_mask_scale,
-            LR_weight_rank=config.LR_weight_rank,
-            LR_mask_rank=config.LR_mask_rank,
         )
 
         self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
@@ -804,8 +709,7 @@ class DecodeLayer(nn.Module):
                 enc_hidden_states,
                 enc_attn_mask,
                 dec_attn_mask,
-                is_infer=False,
-                pruning_threshold=None):
+                is_infer=False):
         residual = hidden_states
         previous_type = hidden_states.type()
         hidden_states = self.input_layernorm(
@@ -813,10 +717,7 @@ class DecodeLayer(nn.Module):
         if self.fp32_layernorm:
             hidden_states = hidden_states.type(previous_type)
         hidden_states = self.attention(
-            hidden_states,
-            dec_attn_mask,
-            is_infer=is_infer,
-            pruning_threshold=pruning_threshold)
+            hidden_states, dec_attn_mask, is_infer=is_infer)
 
         hidden_states = residual + hidden_states
 
@@ -825,23 +726,18 @@ class DecodeLayer(nn.Module):
             self.type_converter(hidden_states))
         if self.fp32_layernorm:
             hidden_states = hidden_states.type(previous_type)
-        hidden_states = self.cross_attention(
-            hidden_states,
-            enc_hidden_states,
-            enc_attn_mask,
-            pruning_threshold=pruning_threshold)
+        hidden_states = self.cross_attention(hidden_states, enc_hidden_states,
+                                             enc_attn_mask)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.post_cross_attention_layernorm(
             self.type_converter(hidden_states))
         if self.fp32_layernorm:
             hidden_states = hidden_states.type(previous_type)
-        hidden_states = self.intermediate(
-            hidden_states, pruning_threshold=pruning_threshold)
+        hidden_states = self.intermediate(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
 
-        hidden_states = self.output(
-            hidden_states, pruning_threshold=pruning_threshold)
+        hidden_states = self.output(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -866,8 +762,7 @@ class BertDecoder(nn.Module):
                 dec_attn_mask,
                 checkpoint_activations=False,
                 output_all_encoded_layers=False,
-                is_infer=False,
-                pruning_threshold=None):
+                is_infer=False):
 
         def custom(start, end):
 
@@ -880,8 +775,7 @@ class BertDecoder(nn.Module):
                         inputs[1],
                         inputs[2],
                         dec_attn_mask * 1,
-                        is_infer=is_infer,
-                        pruning_threshold=pruning_threshold)
+                        is_infer=is_infer)
                 return x_
 
             return custom_forward
@@ -904,8 +798,7 @@ class BertDecoder(nn.Module):
                     enc_hidden_states,
                     enc_attn_mask,
                     dec_attn_mask,
-                    is_infer=is_infer,
-                    pruning_threshold=pruning_threshold)
+                    is_infer=is_infer)
 
         previous_type = hidden_states.type()
         if self.fp32_layernorm:
@@ -932,8 +825,7 @@ class DecodeModel(PreTrainedBertModel):
                 enc_attn_mask=None,
                 dec_attn_mask=None,
                 checkpoint_activations=False,
-                is_infer=False,
-                pruning_threshold=None):
+                is_infer=False):
         extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
         extended_attention_mask = extended_attention_mask.to(
             dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
@@ -946,8 +838,7 @@ class DecodeModel(PreTrainedBertModel):
             extended_attention_mask,
             dec_attn_mask,
             checkpoint_activations=False,
-            is_infer=is_infer,
-            pruning_threshold=pruning_threshold)
+            is_infer=is_infer)
         return sequence_output[-1]
 
 
@@ -972,16 +863,14 @@ class PalmForPreTraining(PreTrainedBertModel):
                 checkpoint_activations=False,
                 is_infer=False,
                 sequence_output=None,
-                parallel_output=True,
-                pruning_threshold=None):
+                parallel_output=True):
         if sequence_output is None:
             sequence_output, pooled_output = self.bert(
                 input_ids,
                 token_type_ids,
                 attention_mask,
                 output_all_encoded_layers=False,
-                checkpoint_activations=checkpoint_activations,
-                pruning_threshold=pruning_threshold)
+                checkpoint_activations=checkpoint_activations)
             prediction_scores, seq_relationship_score = self.cls(
                 sequence_output, pooled_output)
         else:
@@ -998,8 +887,7 @@ class PalmForPreTraining(PreTrainedBertModel):
             attention_mask,
             decode_attention_mask,
             checkpoint_activations=checkpoint_activations,
-            is_infer=is_infer,
-            pruning_threshold=pruning_threshold)
+            is_infer=is_infer)
 
         transformer_output_parallel = mpu.copy_to_model_parallel_region(
             decode_output)
@@ -1017,6 +905,29 @@ class PlugModel(torch.nn.Module):
 
 
 class PlugModel(torch.nn.Module):
+    """
+    The bare Plug Model transformer outputting raw hidden-states without any specific head on top.
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+    Parameters:
+        config ([`PlugNLGConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~DistributedPlug.initialize_model`] method to load the model weights.
+    Example:
+
+    ```python
+    >>> # The PLUG model has 27B parameters and usually need to run on multiple GPUs. The example given
+    >>> # here only initializes a slice of the model on a single GPU.
+    >>> # Check out the [`~DistributedPipeline.__init__`] method to initialize entire PLUG model.
+    >>> from modelscope.models.nlp.plug import PlugNLGConfig, PlugModel
+
+    >>> # Initializing a Plug configuration
+    >>> configuration = PlugNLGConfig()
+
+    >>> # Initializing a model from the configuration
+    >>> model = PlugModel(configuration)
+    """
 
     def __init__(self, config):
         super(PlugModel, self).__init__()
@@ -1034,6 +945,58 @@ class PlugModel(torch.nn.Module):
                 is_infer=False,
                 sequence_output=None,
                 parallel_output=True):
+        """
+        Parameters:
+            input_tokens (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`):
+                `input_tokens_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
+                Indices can be obtained using transformers [`BertTokenizer`]. See
+                [`TextGenerationPreprocessor.__call__`] for details.
+            token_type_ids (`torch.LongTensor` of shape `(batch_size, input_tokens_length)`, *optional*, defaults to
+                None):
+                Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+                1]`:
+
+                - 0 corresponds to a *sentence A* token,
+                - 1 corresponds to a *sentence B* token.
+
+            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            target_tokens (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
+                Target token ids(labels) for language modeling. Note that the labels **are shifted** inside the model,
+                i.e. you can set `target_tokens = input_tokens` Indices are selected in
+                `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only
+                computed for labels in `[0, ..., config.vocab_size]`
+
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
+                `[0, config.max_position_embeddings - 1]`.
+
+            decode_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults
+                to None):
+                Mask to avoid performing attention on padding token indices of target tokens. Mask values selected in
+                `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            checkpoint_activations (`boolean`, *optional*, defaults to `False`):
+                Whether gradient checkpointing is activated for this model or not.
+            is_infer (`boolean`, *optional*, defaults to `False`):
+                Whether or not to perform single inference.
+            sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*,
+                defaults to None):
+                Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the
+                model. A single forward() call can produce one single token. To generate the current token, the
+                sequence_output generated by the `forward()` of the previous token is required.
+            parallel_output (`boolean`, *optional*, defaults to `True`):
+                To parallel return output, or gather it before return.
+
+
+        """
         return self.model(
             input_tokens,
             token_type_ids,