From 780330897a47bf24437090e48cf4350dae7af8ed Mon Sep 17 00:00:00 2001 From: "peter.lx" Date: Thu, 1 Sep 2022 22:17:14 +0800 Subject: [PATCH] [to #42322933] add Deberta v2 modeling and fill_mask task, with master merged add Deberta v2 modeling and fill_mask task, with master merged Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9966511 --- modelscope/metainfo.py | 1 + modelscope/models/nlp/__init__.py | 16 +- modelscope/models/nlp/deberta_v2/__init__.py | 73 + .../deberta_v2/configuration_deberta_v2.py | 130 ++ .../nlp/deberta_v2/modeling_deberta_v2.py | 1789 +++++++++++++++++ .../nlp/deberta_v2/tokenization_deberta_v2.py | 546 +++++ .../tokenization_deberta_v2_fast.py | 241 +++ modelscope/models/nlp/masked_language.py | 39 + .../pipelines/nlp/fill_mask_pipeline.py | 16 +- modelscope/preprocessors/nlp.py | 3 + tests/pipelines/test_deberta_tasks.py | 62 + 11 files changed, 2907 insertions(+), 9 deletions(-) create mode 100644 modelscope/models/nlp/deberta_v2/__init__.py create mode 100644 modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py create mode 100644 modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py create mode 100644 modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py create mode 100644 modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py create mode 100644 tests/pipelines/test_deberta_tasks.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 7c5afe80..971dd3f1 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -37,6 +37,7 @@ class Models(object): bert = 'bert' palm = 'palm-v2' structbert = 'structbert' + deberta_v2 = 'deberta_v2' veco = 'veco' translation = 'csanmt-translation' space_dst = 'space-dst' diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index e17a1d31..fd61e40b 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -9,12 +9,15 @@ if TYPE_CHECKING: from .bert_for_sequence_classification import BertForSequenceClassification from .bert_for_document_segmentation import BertForDocumentSegmentation from .csanmt_for_translation import CsanmtForTranslation - from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, - BertForMaskedLM) + from .masked_language import ( + StructBertForMaskedLM, + VecoForMaskedLM, + BertForMaskedLM, + DebertaV2ForMaskedLM, + ) from .nncrf_for_named_entity_recognition import ( TransformerCRFForNamedEntityRecognition, LSTMCRFForNamedEntityRecognition) - from .palm_v2 import PalmForTextGeneration from .token_classification import SbertForTokenClassification from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification from .space import SpaceForDialogIntent @@ -22,7 +25,6 @@ if TYPE_CHECKING: from .space import SpaceForDialogStateTracking from .star_text_to_sql import StarForTextToSql from .task_models import (InformationExtractionModel, - SequenceClassificationModel, SingleBackboneTaskModelBase) from .bart_for_text_error_correction import BartForTextErrorCorrection from .gpt3 import GPT3ForTextGeneration @@ -36,8 +38,10 @@ else: 'csanmt_for_translation': ['CsanmtForTranslation'], 'bert_for_sequence_classification': ['BertForSequenceClassification'], 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], - 'masked_language': - ['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'], + 'masked_language': [ + 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', + 'DebertaV2ForMaskedLM' + ], 'nncrf_for_named_entity_recognition': [ 'TransformerCRFForNamedEntityRecognition', 'LSTMCRFForNamedEntityRecognition' diff --git a/modelscope/models/nlp/deberta_v2/__init__.py b/modelscope/models/nlp/deberta_v2/__init__.py new file mode 100644 index 00000000..664fc6c6 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/__init__.py @@ -0,0 +1,73 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +_import_structure = { + 'configuration_deberta_v2': [ + 'DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config', + 'DebertaV2OnnxConfig' + ], + 'tokenization_deberta_v2': ['DebertaV2Tokenizer'], +} + +if TYPE_CHECKING: + from .configuration_deberta_v2 import DebertaV2Config + from .tokenization_deberta_v2 import DebertaV2Tokenizer + from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast + + from .modeling_deberta_v2 import ( + DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, + DebertaV2ForMaskedLM, + DebertaV2ForMultipleChoice, + DebertaV2ForQuestionAnswering, + DebertaV2ForSequenceClassification, + DebertaV2ForTokenClassification, + DebertaV2Model, + DebertaV2PreTrainedModel, + ) + +else: + _import_structure = { + 'configuration_deberta_v2': + ['DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config'], + 'tokenization_deberta_v2': ['DebertaV2Tokenizer'] + } + _import_structure['tokenization_deberta_v2_fast'] = [ + 'DebertaV2TokenizerFast' + ] + _import_structure['modeling_deberta_v2'] = [ + 'DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST', + 'DebertaV2ForMaskedLM', + 'DebertaV2ForMultipleChoice', + 'DebertaV2ForQuestionAnswering', + 'DebertaV2ForSequenceClassification', + 'DebertaV2ForTokenClassification', + 'DebertaV2Model', + 'DebertaV2PreTrainedModel', + ] + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__) diff --git a/modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py b/modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py new file mode 100644 index 00000000..65e8f0b7 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/configuration_deberta_v2.py @@ -0,0 +1,130 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020, Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeBERTa-v2 model configuration, mainly copied from :class:`~transformers.DeBERTaV2Config""" +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Mapping, Optional, Union + +from transformers import PretrainedConfig + +from modelscope.utils import logger as logging + +logger = logging.get_logger(__name__) + + +class DebertaV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a + DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DeBERTa + [microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 128100): + Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DebertaV2Model`]. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"` + are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 0): + The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-7): + The epsilon used by the layer normalization layers. + relative_attention (`bool`, *optional*, defaults to `True`): + Whether use relative position encoding. + max_relative_positions (`int`, *optional*, defaults to -1): + The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value + as `max_position_embeddings`. + pad_token_id (`int`, *optional*, defaults to 0): + The value used to pad input_ids. + position_biased_input (`bool`, *optional*, defaults to `False`): + Whether add absolute position embedding to content embedding. + pos_att_type (`List[str]`, *optional*): + The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`, + `["p2c", "c2p"]`, `["p2c", "c2p"]`. + layer_norm_eps (`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + """ + model_type = 'deberta_v2' + + def __init__(self, + vocab_size=128100, + hidden_size=1536, + num_hidden_layers=24, + num_attention_heads=24, + intermediate_size=6144, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=0, + initializer_range=0.02, + layer_norm_eps=1e-7, + relative_attention=False, + max_relative_positions=-1, + pad_token_id=0, + position_biased_input=True, + pos_att_type=None, + pooler_dropout=0, + pooler_hidden_act='gelu', + **kwargs): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.pad_token_id = pad_token_id + self.position_biased_input = position_biased_input + + # Backwards compatibility + if type(pos_att_type) == str: + pos_att_type = [x.strip() for x in pos_att_type.lower().split('|')] + + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get('pooler_hidden_size', hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act diff --git a/modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py b/modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py new file mode 100644 index 00000000..1c6b9071 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/modeling_deberta_v2.py @@ -0,0 +1,1789 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model.""" + +from collections.abc import Sequence +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from transformers.activations import ACT2FN +from transformers.file_utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward) +from transformers.modeling_outputs import (BaseModelOutput, MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import softmax_backward_data + +from modelscope.utils import logger as logging +from .configuration_deberta_v2 import DebertaV2Config + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'DebertaV2Config' +_TOKENIZER_FOR_DOC = 'DebertaV2Tokenizer' +_CHECKPOINT_FOR_DOC = 'nlp_debertav2_fill-mask_chinese-lite' + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, + config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (`torch.tensor`): The input tensor that will apply softmax. + mask (`torch.IntTensor`): + The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example: + + ```python + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4, 20, 100]) + + >>> # Create a mask + >>> mask = (x > 0).int() + + >>> # Specify the dimension to apply softmax + >>> dim = -1 + + >>> y = XSoftmax.apply(x, mask, dim) + ```""" + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.to(torch.bool)) + + output = input.masked_fill(rmask, + torch.tensor(torch.finfo(input.dtype).min)) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output, ) = self.saved_tensors + inputGrad = softmax_backward_data(self, grad_output, output, self.dim, + output) + return inputGrad, None, None + + @staticmethod + def symbolic(g, self, mask, dim): + import torch.onnx.symbolic_helper as sym_help + from torch.onnx.symbolic_opset9 import masked_fill, softmax + + mask_cast_value = g.op( + 'Cast', mask, to_i=sym_help.cast_pytorch_to_onnx['Long']) + r_mask = g.op( + 'Cast', + g.op('Sub', + g.op('Constant', value_t=torch.tensor(1, dtype=torch.int64)), + mask_cast_value), + to_i=sym_help.cast_pytorch_to_onnx['Byte'], + ) + output = masked_fill( + g, self, r_mask, + g.op( + 'Constant', + value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = softmax(g, output, dim) + return masked_fill( + g, output, r_mask, + g.op('Constant', value_t=torch.tensor(0, dtype=torch.uint8))) + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to( + torch.bool) + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask, ) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, + local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + from torch.onnx import symbolic_opset12 + + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. + # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return symbolic_opset12.dropout(g, input, dropout_p, train) + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if output_attentions: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + query_states=None, + relative_pos=None, + rel_embeddings=None, + output_attentions=False, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if output_attentions: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if output_attentions: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, 'conv_kernel_size', 3) + groups = getattr(config, 'conv_groups', 1) + self.conv_act = getattr(config, 'conv_act', 'tanh') + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size, + padding=(kernel_size - 1) // 2, + groups=groups) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute( + 0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList( + [DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, 'relative_attention', False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, + 'max_relative_positions', -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, 'position_buckets', -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, + config.hidden_size) + + self.norm_rel_ebd = [ + x.strip() + for x in getattr(config, 'norm_rel_ebd', 'none').lower().split('|') + ] + + if 'layer_norm' in self.norm_rel_ebd: + self.LayerNorm = LayerNorm( + config.hidden_size, + config.layer_norm_eps, + elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, 'conv_kernel_size', + 0) > 0 else None + self.gradient_checkpointing = False + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ('layer_norm' in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze( + -2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size( + -2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, + relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states, ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + output_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + next_kv, + attention_mask, + query_states, + relative_pos, + rel_embeddings, + ) + else: + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) + + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + output_states = self.conv(hidden_states, output_states, + input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len( + self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m, ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states, ) + + if not return_dict: + return tuple( + v for v in [output_states, all_hidden_states, all_attentions] + if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, + hidden_states=all_hidden_states, + attentions=all_attentions) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = torch.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = torch.where( + (relative_pos < mid) & (relative_pos > -mid), + torch.tensor(mid - 1).type_as(relative_pos), + torch.abs(relative_pos), + ) + log_pos = ( + torch.ceil( + torch.log(abs_pos / mid) + / torch.log(torch.tensor( + (max_position - 1) / mid)) * (mid - 1)) + mid) + bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), + log_pos * sign) + return bucket_pos + + +def build_relative_position(query_size, + key_size, + bucket_size=-1, + max_position=-1): + """ + Build relative position according to the query and key + + We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key + \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q - + P_k\\) + + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + + Return: + `torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = torch.arange(0, query_size) + k_ids = torch.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - k_ids[None, :] + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, + max_position) + rel_pos_ids = rel_pos_ids.to(torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([ + query_layer.size(0), + query_layer.size(1), + query_layer.size(2), + relative_pos.size(-1) + ]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([ + query_layer.size(0), + query_layer.size(1), + key_layer.size(-2), + key_layer.size(-2) + ]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + *BertConfig*, for more details, please refer [`DebertaV2Config`] + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention ' + f'heads ({config.num_attention_heads})') + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, 'attention_head_size', + _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, 'share_att_key', False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, 'relative_attention', False) + + if self.relative_attention: + self.position_buckets = getattr(config, 'position_buckets', -1) + self.max_relative_positions = getattr(config, + 'max_relative_positions', -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if 'c2p' in self.pos_att_type: + self.pos_key_proj = nn.Linear( + config.hidden_size, self.all_head_size, bias=True) + if 'p2c' in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, + self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, attention_heads): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), + x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + output_attentions=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ + Call the module + + Args: + hidden_states (`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + *Attention(Q,K,V)* + + attention_mask (`torch.ByteTensor`): + An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum + sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j* + th token. + + output_attentions (`bool`, optional): + Whether return the attention matrix. + + query_states (`torch.FloatTensor`, optional): + The *Q* state in *Attention(Q,K,V)*. + + relative_pos (`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with + values ranging in [*-max_relative_positions*, *max_relative_positions*]. + + rel_embeddings (`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [\\(2 \\times + \\text{max_relative_positions}\\), *hidden_size*]. + + + """ + if query_states is None: + query_states = hidden_states + query_layer = self.transpose_for_scores( + self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores( + self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores( + self.value_proj(hidden_states), self.num_attention_heads) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if 'c2p' in self.pos_att_type: + scale_factor += 1 + if 'p2c' in self.pos_att_type: + scale_factor += 1 + scale = torch.sqrt( + torch.tensor(query_layer.size(-1), dtype=torch.float) + * scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose( + -1, -2)) / torch.tensor( + scale, dtype=query_layer.dtype) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias(query_layer, key_layer, + relative_pos, + rel_embeddings, + scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + attention_scores = attention_scores + attention_scores = attention_scores.view(-1, self.num_attention_heads, + attention_scores.size(-2), + attention_scores.size(-1)) + + # bsz x height x length x dimension + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), + attention_probs.size(-1)), value_layer) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, + context_layer.size(-2), + context_layer.size(-1)).permute(0, 2, 1, + 3).contiguous()) + new_context_layer_shape = context_layer.size()[:-2] + (-1, ) + context_layer = context_layer.view(new_context_layer_shape) + if output_attentions: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, + rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError( + f'Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}' + ) + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + rel_embeddings = rel_embeddings[0:att_span * 2, :].unsqueeze(0) + if self.share_att_key: + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores( + self.key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) + else: + if 'c2p' in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + if 'p2c' in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if 'c2p' in self.pos_att_type: + scale = torch.sqrt( + torch.tensor(pos_key_layer.size(-1), dtype=torch.float) + * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([ + query_layer.size(0), + query_layer.size(1), + relative_pos.size(-1) + ]), + ) + score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype) + + # position->content + if 'p2c' in self.pos_att_type: + scale = torch.sqrt( + torch.tensor(pos_query_layer.size(-1), dtype=torch.float) + * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([ + query_layer.size(0), + key_layer.size(-2), + key_layer.size(-2) + ]), + ).transpose(-1, -2) + score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype) + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, 'pad_token_id', 0) + self.embedding_size = getattr(config, 'embedding_size', + config.hidden_size) + self.word_embeddings = nn.Embedding( + config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, 'position_biased_input', + True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear( + self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + 'position_ids', + torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + mask=None, + inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = 'deberta' + _keys_to_ignore_on_load_missing = ['position_ids'] + _keys_to_ignore_on_load_unexpected = ['position_embeddings'] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DebertaV2Encoder): + module.gradient_checkpointing = value + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled + Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build + on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + + Parameters: + config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.', + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError( + 'The prune function is not implemented in DeBERTa model.') + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + output_attentions=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output, ) + encoder_outputs[ + (1 if output_hidden_states else 2):] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states + if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """DeBERTa Model with a `language modeling` head on top.""", + DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r'pooler'] + _keys_to_ignore_on_load_missing = [ + r'position_ids', r'predictions.decoder.bias' + ] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1)) + + if not return_dict: + output = (prediction_scores, ) + outputs[1:] + return ((masked_lm_loss, ) + + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, 'num_labels', 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, 'cls_dropout', None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather( + logits, 0, + label_index.expand( + label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct( + labeled_logits.view(-1, self.num_labels).float(), + labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + elif self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits, ) + outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits, ) + outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r'pooler'] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss, ) + + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, 'num_labels', 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, 1) + drop_out = getattr(config, 'cls_dropout', None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward( + DEBERTA_INPUTS_DOCSTRING.format('batch_size, sequence_length')) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[ + 1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view( + -1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view( + -1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view( + -1, + token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view( + -1, + attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), + inputs_embeds.size(-1)) + if inputs_embeds is not None else None) + + outputs = self.deberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits, ) + outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py b/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py new file mode 100644 index 00000000..adb60288 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2.py @@ -0,0 +1,546 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DeBERTa. mainly copied from :module:`~transformers.tokenization_deberta`""" + +import os +import unicodedata +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as sp +from transformers.tokenization_utils import PreTrainedTokenizer + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + +PRETRAINED_INIT_CONFIGURATION = {} + +VOCAB_FILES_NAMES = {'vocab_file': 'spm.model'} + + +class DebertaV2Tokenizer(PreTrainedTokenizer): + r""" + Constructs a DeBERTa-v2 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) + and [jieba](https://github.com/fxsjy/jieba). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, + vocab_file, + do_lower_case=False, + split_by_punct=False, + split_chinese=True, + bos_token='[CLS]', + eos_token='[SEP]', + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + super().__init__( + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + split_chinese=split_chinese, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + ' model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.split_chinese = split_chinese + self.vocab_file = vocab_file + self._tokenizer = SPMTokenizer( + vocab_file, + split_by_punct=split_by_punct, + sp_model_kwargs=self.sp_model_kwargs) + self.jieba = None + if self.split_chinese: + try: + import jieba + except ImportError: + raise ImportError( + 'You need to install jieba to split chinese and use DebertaV2Tokenizer. ' + 'See https://pypi.org/project/jieba/ for installation.') + self.jieba = jieba + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self._tokenizer.vocab + + def get_vocab(self): + vocab = self.vocab.copy() + vocab.update(self.get_added_vocab()) + return vocab + + def _tokenize(self, text: str) -> List[str]: + """Take as input a string and return a list of strings (tokens) for words/sub-words""" + if self.do_lower_case: + text = text.lower() + if self.split_chinese: + seg_list = [x for x in self.jieba.cut(text)] + text = ' '.join(seg_list) + return self._tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self._tokenizer.spm.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self._tokenizer.spm.IdToPiece( + index) if index < self.vocab_size else self.unk_token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + return self._tokenizer.decode(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def prepare_for_tokenization(self, + text, + is_split_into_words=False, + **kwargs): + add_prefix_space = kwargs.pop('add_prefix_space', False) + if is_split_into_words or add_prefix_space: + text = ' ' + text + return (text, kwargs) + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + return self._tokenizer.save_pretrained( + save_directory, filename_prefix=filename_prefix) + + +class SPMTokenizer: + r""" + Constructs a tokenizer based on [SentencePiece](https://github.com/google/sentencepiece). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + def __init__(self, + vocab_file, + split_by_punct=False, + sp_model_kwargs: Optional[Dict[str, Any]] = None): + self.split_by_punct = split_by_punct + self.vocab_file = vocab_file + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + if not os.path.exists(vocab_file): + raise FileNotFoundError(f'{vocab_file} does not exist!') + spm.load(vocab_file) + bpe_vocab_size = spm.GetPieceSize() + # Token map + # 0+1 + # 1+1 + # 2+1 + self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} + self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] + # self.vocab['[PAD]'] = 0 + # self.vocab['[CLS]'] = 1 + # self.vocab['[SEP]'] = 2 + # self.vocab['[UNK]'] = 3 + + self.spm = spm + + def __getstate__(self): + state = self.__dict__.copy() + state['spm'] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, 'sp_model_kwargs'): + self.sp_model_kwargs = {} + + self.spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) + self.spm.Load(self.vocab_file) + + def tokenize(self, text): + return self._encode_as_pieces(text) + + def convert_ids_to_tokens(self, ids): + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def decode(self, tokens, start=-1, end=-1, raw_text=None): + if raw_text is None: + return self.spm.decode_pieces([t for t in tokens]) + else: + words = self.split_to_words(raw_text) + word_tokens = [self.tokenize(w) for w in words] + token2words = [0] * len(tokens) + tid = 0 + for i, w in enumerate(word_tokens): + for k, t in enumerate(w): + token2words[tid] = i + tid += 1 + word_start = token2words[start] + word_end = token2words[end] if end < len(tokens) else len(words) + text = ''.join(words[word_start:word_end]) + return text + + def add_special_token(self, token): + if token not in self.special_tokens: + self.special_tokens.append(token) + if token not in self.vocab: + self.vocab[token] = len(self.vocab) - 1 + self.ids_to_tokens.append(token) + return self.id(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + if (len(token) == 1 and (_is_whitespace(list(token)[0]))): + return False + if _is_control(list(token)[0]): + return False + if _is_punctuation(list(token)[0]): + return False + if token in self.add_special_token: + return False + + word_start = b'\xe2\x96\x81'.decode('utf-8') + return not token.startswith(word_start) + + def pad(self): + return '[PAD]' + + def bos(self): + return '[CLS]' + + def eos(self): + return '[SEP]' + + def unk(self): + return '[UNK]' + + def mask(self): + return '[MASK]' + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] if sym in self.vocab else 1 + + def _encode_as_pieces(self, text): + text = convert_to_unicode(text) + if self.split_by_punct: + words = self._run_split_on_punc(text) + pieces = [self.spm.encode(w, out_type=str) for w in words] + return [p for w in pieces for p in w] + else: + return self.spm.encode(text, out_type=str) + + def split_to_words(self, text): + pieces = self._encode_as_pieces(text) + word_start = b'\xe2\x96\x81'.decode('utf-8') + words = [] + offset = 0 + prev_end = 0 + for i, p in enumerate(pieces): + if p.startswith(word_start): + if offset > prev_end: + words.append(text[prev_end:offset]) + prev_end = offset + w = p.replace(word_start, '') + else: + w = p + try: + s = text.index(w, offset) + pn = '' + k = i + 1 + while k < len(pieces): + pn = pieces[k].replace(word_start, '') + if len(pn) > 0: + break + k += 1 + + if len(pn) > 0 and pn in text[offset:s]: + offset = offset + 1 + else: + offset = s + len(w) + except Exception: + offset = offset + 1 + + if prev_end < offset: + words.append(text[prev_end:offset]) + + return words + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize('NFD', text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == 'Mn': + continue + output.append(char) + return ''.join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return [''.join(x) for x in output] + + def save_pretrained(self, path: str, filename_prefix: str = None): + filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] + if filename_prefix is not None: + filename = filename_prefix + '-' + filename + full_path = os.path.join(path, filename) + with open(full_path, 'wb') as fs: + fs.write(self.spm.serialized_model_proto()) + return (full_path, ) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically control characters but we treat them + # as whitespace since they are generally considered as such. + if char == ' ' or char == '\t' or char == '\n' or char == '\r': + return True + cat = unicodedata.category(char) + if cat == 'Zs': + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == '\t' or char == '\n' or char == '\r': + return False + cat = unicodedata.category(char) + if cat.startswith('C'): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or ( + cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith('P'): + return True + return False + + +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode('utf-8', 'ignore') + else: + raise ValueError(f'Unsupported string type: {type(text)}') diff --git a/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py b/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py new file mode 100644 index 00000000..a1fcecf4 --- /dev/null +++ b/modelscope/models/nlp/deberta_v2/tokenization_deberta_v2_fast.py @@ -0,0 +1,241 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright 2020 Microsoft and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization class for model DeBERTa.""" + +import os +from shutil import copyfile +from typing import Optional, Tuple + +from transformers.file_utils import is_sentencepiece_available +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modelscope.utils import logger as logging + +if is_sentencepiece_available(): + from .tokenization_deberta_v2 import DebertaV2Tokenizer +else: + DebertaV2Tokenizer = None + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + 'vocab_file': 'spm.model', + 'tokenizer_file': 'tokenizer.json' +} + +PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + +PRETRAINED_INIT_CONFIGURATION = {} + + +class DebertaV2TokenizerFast(PreTrainedTokenizerFast): + r""" + Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) + and [rjieba-py](https://github.com/messense/rjieba-py). + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that + contains the vocabulary necessary to instantiate a tokenizer. + do_lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + bos_token (`string`, *optional*, defaults to `"[CLS]"`): + The beginning of sequence token that was used during pre-training. Can be used a sequence classifier token. + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + eos_token (`string`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. When building a sequence using special tokens, this is not the token that is + used for the end of sequence. The token used is the `sep_token`. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = DebertaV2Tokenizer + + def __init__(self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=False, + split_by_punct=False, + split_chinese=True, + bos_token='[CLS]', + eos_token='[SEP]', + unk_token='[UNK]', + sep_token='[SEP]', + pad_token='[PAD]', + cls_token='[CLS]', + mask_token='[MASK]', + **kwargs) -> None: + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + split_by_punct=split_by_punct, + split_chinese=split_chinese, + **kwargs, + ) + + self.do_lower_case = do_lower_case + self.split_by_punct = split_by_punct + self.split_chinese = split_chinese + self.vocab_file = vocab_file + self.can_save_slow_tokenizer = False if not self.vocab_file else True + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A DeBERTa sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, + token_ids_0, + token_ids_1=None, + already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ( + [0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, + token_ids_0, + token_ids_1=None): + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' + 'tokenizer.') + + if not os.path.isdir(save_directory): + logger.error( + f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + + VOCAB_FILES_NAMES['vocab_file']) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file, ) diff --git a/modelscope/models/nlp/masked_language.py b/modelscope/models/nlp/masked_language.py index 17324be9..4f466c23 100644 --- a/modelscope/models/nlp/masked_language.py +++ b/modelscope/models/nlp/masked_language.py @@ -6,6 +6,8 @@ from transformers import BertForMaskedLM as BertForMaskedLMTransformer from modelscope.metainfo import Models from modelscope.models.base import TorchModel from modelscope.models.builder import MODELS +from modelscope.models.nlp.deberta_v2 import \ + DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer from modelscope.models.nlp.structbert import SbertForMaskedLM from modelscope.models.nlp.veco import \ VecoForMaskedLM as VecoForMaskedLMTransformer @@ -125,3 +127,40 @@ class VecoForMaskedLM(TorchModel, VecoForMaskedLMTransformer): VecoForMaskedLM).from_pretrained( pretrained_model_name_or_path=model_dir, model_dir=model_dir) + + +@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) +class DebertaV2ForMaskedLM(TorchModel, DebertaV2ForMaskedLMTransformer): + """Deberta v2 for MLM model. + + Inherited from deberta_v2.DebertaV2ForMaskedLM and TorchModel, so this class can be registered into Model sets. + """ + + def __init__(self, config, model_dir): + super(TorchModel, self).__init__(model_dir) + DebertaV2ForMaskedLMTransformer.__init__(self, config) + + def forward(self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + labels=None): + output = DebertaV2ForMaskedLMTransformer.forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + labels=labels) + output[OutputKeys.INPUT_IDS] = input_ids + return output + + @classmethod + def _instantiate(cls, **kwargs): + model_dir = kwargs.get('model_dir') + return super(DebertaV2ForMaskedLMTransformer, + DebertaV2ForMaskedLM).from_pretrained( + pretrained_model_name_or_path=model_dir, + model_dir=model_dir) diff --git a/modelscope/pipelines/nlp/fill_mask_pipeline.py b/modelscope/pipelines/nlp/fill_mask_pipeline.py index 60a9631b..caba4122 100644 --- a/modelscope/pipelines/nlp/fill_mask_pipeline.py +++ b/modelscope/pipelines/nlp/fill_mask_pipeline.py @@ -13,7 +13,10 @@ from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks __all__ = ['FillMaskPipeline'] -_type_map = {'veco': 'roberta', 'sbert': 'bert'} +_type_map = { + 'veco': 'roberta', + 'sbert': 'bert', +} @PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) @@ -65,7 +68,7 @@ class FillMaskPipeline(Pipeline): self.config = Config.from_file( os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) self.tokenizer = preprocessor.tokenizer - self.mask_id = {'roberta': 250001, 'bert': 103} + self.mask_id = {'roberta': 250001, 'bert': 103, 'deberta_v2': 4} self.rep_map = { 'bert': { @@ -85,7 +88,14 @@ class FillMaskPipeline(Pipeline): '': '', '': '', '': ' ' - } + }, + 'deberta_v2': { + '[PAD]': '', + r' +': ' ', + '[SEP]': '', + '[CLS]': '', + '[UNK]': '' + }, } def forward(self, inputs: Dict[str, Any], diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 4882c477..825611d6 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -170,6 +170,9 @@ class NLPTokenizerPreprocessorBase(Preprocessor): elif model_type == Models.veco: from modelscope.models.nlp.veco import VecoTokenizer return VecoTokenizer.from_pretrained(model_dir) + elif model_type == Models.deberta_v2: + from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer + return DebertaV2Tokenizer.from_pretrained(model_dir) else: return AutoTokenizer.from_pretrained(model_dir, use_fast=False) diff --git a/tests/pipelines/test_deberta_tasks.py b/tests/pipelines/test_deberta_tasks.py new file mode 100644 index 00000000..4f3206cd --- /dev/null +++ b/tests/pipelines/test_deberta_tasks.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +import torch + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.models.nlp import DebertaV2ForMaskedLM +from modelscope.models.nlp.deberta_v2 import (DebertaV2Tokenizer, + DebertaV2TokenizerFast) +from modelscope.pipelines import pipeline +from modelscope.pipelines.nlp import FillMaskPipeline +from modelscope.preprocessors import FillMaskPreprocessor +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class DeBERTaV2TaskTest(unittest.TestCase): + model_id_deberta = 'damo/nlp_debertav2_fill-mask_chinese-lite' + + ori_text = '你师父差得动你,你师父可差不动我。' + test_input = '你师父差得动你,你师父可[MASK]不动我。' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id_deberta) + preprocessor = FillMaskPreprocessor( + model_dir, first_sequence='sentence', second_sequence=None) + model = DebertaV2ForMaskedLM.from_pretrained(model_dir) + pipeline1 = FillMaskPipeline(model, preprocessor) + pipeline2 = pipeline( + Tasks.fill_mask, model=model, preprocessor=preprocessor) + ori_text = self.ori_text + test_input = self.test_input + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' + f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + # sbert + print(self.model_id_deberta) + model = Model.from_pretrained(self.model_id_deberta) + preprocessor = FillMaskPreprocessor( + model.model_dir, first_sequence='sentence', second_sequence=None) + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=model, preprocessor=preprocessor) + print( + f'\nori_text: {self.ori_text}\ninput: {self.test_input}\npipeline: ' + f'{pipeline_ins(self.test_input)}\n') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + pipeline_ins = pipeline( + task=Tasks.fill_mask, model=self.model_id_deberta) + ori_text = self.ori_text + test_input = self.test_input + print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' + f'{pipeline_ins(test_input)}\n') + + +if __name__ == '__main__': + unittest.main()