Add DeBERTa-v2 modeling and the fill_mask task, with master merged
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9966511
master
| @@ -37,6 +37,7 @@ class Models(object): | |||
| bert = 'bert' | |||
| palm = 'palm-v2' | |||
| structbert = 'structbert' | |||
| deberta_v2 = 'deberta_v2' | |||
| veco = 'veco' | |||
| translation = 'csanmt-translation' | |||
| space_dst = 'space-dst' | |||
| @@ -9,12 +9,15 @@ if TYPE_CHECKING: | |||
| from .bert_for_sequence_classification import BertForSequenceClassification | |||
| from .bert_for_document_segmentation import BertForDocumentSegmentation | |||
| from .csanmt_for_translation import CsanmtForTranslation | |||
| from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, | |||
| BertForMaskedLM) | |||
| from .masked_language import ( | |||
| StructBertForMaskedLM, | |||
| VecoForMaskedLM, | |||
| BertForMaskedLM, | |||
| DebertaV2ForMaskedLM, | |||
| ) | |||
| from .nncrf_for_named_entity_recognition import ( | |||
| TransformerCRFForNamedEntityRecognition, | |||
| LSTMCRFForNamedEntityRecognition) | |||
| from .palm_v2 import PalmForTextGeneration | |||
| from .token_classification import SbertForTokenClassification | |||
| from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification | |||
| from .space import SpaceForDialogIntent | |||
| @@ -22,7 +25,6 @@ if TYPE_CHECKING: | |||
| from .space import SpaceForDialogStateTracking | |||
| from .star_text_to_sql import StarForTextToSql | |||
| from .task_models import (InformationExtractionModel, | |||
| SequenceClassificationModel, | |||
| SingleBackboneTaskModelBase) | |||
| from .bart_for_text_error_correction import BartForTextErrorCorrection | |||
| from .gpt3 import GPT3ForTextGeneration | |||
| @@ -36,8 +38,10 @@ else: | |||
| 'csanmt_for_translation': ['CsanmtForTranslation'], | |||
| 'bert_for_sequence_classification': ['BertForSequenceClassification'], | |||
| 'bert_for_document_segmentation': ['BertForDocumentSegmentation'], | |||
| 'masked_language': | |||
| ['StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM'], | |||
| 'masked_language': [ | |||
| 'StructBertForMaskedLM', 'VecoForMaskedLM', 'BertForMaskedLM', | |||
| 'DebertaV2ForMaskedLM' | |||
| ], | |||
| 'nncrf_for_named_entity_recognition': [ | |||
| 'TransformerCRFForNamedEntityRecognition', | |||
| 'LSTMCRFForNamedEntityRecognition' | |||
| @@ -0,0 +1,73 @@ | |||
| # flake8: noqa | |||
| # There's no way to ignore "F401 '...' imported but unused" warnings in this | |||
| # module, but to preserve other warnings. So, don't check this module at all. | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright 2020 The HuggingFace Team. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| _import_structure = { | |||
| 'configuration_deberta_v2': [ | |||
| 'DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config', | |||
| 'DebertaV2OnnxConfig' | |||
| ], | |||
| 'tokenization_deberta_v2': ['DebertaV2Tokenizer'], | |||
| } | |||
| if TYPE_CHECKING: | |||
| from .configuration_deberta_v2 import DebertaV2Config | |||
| from .tokenization_deberta_v2 import DebertaV2Tokenizer | |||
| from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast | |||
| from .modeling_deberta_v2 import ( | |||
| DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST, | |||
| DebertaV2ForMaskedLM, | |||
| DebertaV2ForMultipleChoice, | |||
| DebertaV2ForQuestionAnswering, | |||
| DebertaV2ForSequenceClassification, | |||
| DebertaV2ForTokenClassification, | |||
| DebertaV2Model, | |||
| DebertaV2PreTrainedModel, | |||
| ) | |||
| else: | |||
| _import_structure = { | |||
| 'configuration_deberta_v2': | |||
| ['DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config'], | |||
| 'tokenization_deberta_v2': ['DebertaV2Tokenizer'] | |||
| } | |||
| _import_structure['tokenization_deberta_v2_fast'] = [ | |||
| 'DebertaV2TokenizerFast' | |||
| ] | |||
| _import_structure['modeling_deberta_v2'] = [ | |||
| 'DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST', | |||
| 'DebertaV2ForMaskedLM', | |||
| 'DebertaV2ForMultipleChoice', | |||
| 'DebertaV2ForQuestionAnswering', | |||
| 'DebertaV2ForSequenceClassification', | |||
| 'DebertaV2ForTokenClassification', | |||
| 'DebertaV2Model', | |||
| 'DebertaV2PreTrainedModel', | |||
| ] | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__) | |||
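The new `deberta_v2/__init__.py` above registers its symbols with `LazyImportModule`, so importing the package does not pull in the heavy `modeling_deberta_v2` module (and torch) until one of the listed names is actually touched. A minimal, self-contained sketch of that pattern, using only the standard library so it runs as-is; the class and module names below are illustrative stand-ins, not ModelScope internals:

    import importlib
    import types


    class LazyModuleSketch(types.ModuleType):
        """Illustrative stand-in for modelscope.utils.import_utils.LazyImportModule."""

        def __init__(self, name, import_structure):
            super().__init__(name)
            # Map each exported symbol to the full module path that defines it.
            self._symbol_to_module = {
                symbol: module
                for module, symbols in import_structure.items()
                for symbol in symbols
            }

        def __getattr__(self, item):
            # Called only for names not yet set on this module object.
            if item not in self._symbol_to_module:
                raise AttributeError(item)
            module = importlib.import_module(self._symbol_to_module[item])
            value = getattr(module, item)
            setattr(self, item, value)  # cache, so the submodule is imported once
            return value


    # 'json.decoder'/'JSONDecoder' stand in for 'modeling_deberta_v2'/'DebertaV2ForMaskedLM'.
    lazy = LazyModuleSketch('demo', {'json.decoder': ['JSONDecoder']})
    print(lazy.JSONDecoder)  # json.decoder is imported only at this point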
| @@ -0,0 +1,130 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright 2020, Microsoft and the HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ DeBERTa-v2 model configuration, mainly copied from :class:`~transformers.DeBERTaV2Config""" | |||
| from collections import OrderedDict | |||
| from typing import TYPE_CHECKING, Any, Mapping, Optional, Union | |||
| from transformers import PretrainedConfig | |||
| from modelscope.utils import logger as logging | |||
| logger = logging.get_logger(__name__) | |||
| class DebertaV2Config(PretrainedConfig): | |||
| r""" | |||
| This is the configuration class to store the configuration of a [`DebertaV2Model`]. It is used to instantiate a | |||
| DeBERTa-v2 model according to the specified arguments, defining the model architecture. Instantiating a | |||
| configuration with the defaults will yield a similar configuration to that of the DeBERTa | |||
| [microsoft/deberta-v2-xlarge](https://huggingface.co/microsoft/deberta-v2-xlarge) architecture. | |||
| Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | |||
| documentation from [`PretrainedConfig`] for more information. | |||
| Arguments: | |||
| vocab_size (`int`, *optional*, defaults to 128100): | |||
| Vocabulary size of the DeBERTa-v2 model. Defines the number of different tokens that can be represented by | |||
| the `inputs_ids` passed when calling [`DebertaV2Model`]. | |||
| hidden_size (`int`, *optional*, defaults to 1536): | |||
| Dimensionality of the encoder layers and the pooler layer. | |||
| num_hidden_layers (`int`, *optional*, defaults to 24): | |||
| Number of hidden layers in the Transformer encoder. | |||
| num_attention_heads (`int`, *optional*, defaults to 24): | |||
| Number of attention heads for each attention layer in the Transformer encoder. | |||
| intermediate_size (`int`, *optional*, defaults to 6144): | |||
| Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. | |||
| hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): | |||
| The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, | |||
| `"relu"`, `"silu"`, `"gelu"`, `"tanh"`, `"gelu_fast"`, `"mish"`, `"linear"`, `"sigmoid"` and `"gelu_new"` | |||
| are supported. | |||
| hidden_dropout_prob (`float`, *optional*, defaults to 0.1): | |||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
| attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): | |||
| The dropout ratio for the attention probabilities. | |||
| max_position_embeddings (`int`, *optional*, defaults to 512): | |||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
| just in case (e.g., 512 or 1024 or 2048). | |||
| type_vocab_size (`int`, *optional*, defaults to 0): | |||
| The vocabulary size of the `token_type_ids` passed when calling [`DebertaModel`] or [`TFDebertaModel`]. | |||
| initializer_range (`float`, *optional*, defaults to 0.02): | |||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
| layer_norm_eps (`float`, *optional*, defaults to 1e-7): | |||
| The epsilon used by the layer normalization layers. | |||
| relative_attention (`bool`, *optional*, defaults to `False`): | |||
| Whether to use relative position encoding. | |||
| max_relative_positions (`int`, *optional*, defaults to -1): | |||
| The range of relative positions `[-max_position_embeddings, max_position_embeddings]`. Use the same value | |||
| as `max_position_embeddings`. | |||
| pad_token_id (`int`, *optional*, defaults to 0): | |||
| The value used to pad input_ids. | |||
| position_biased_input (`bool`, *optional*, defaults to `True`): | |||
| Whether to add absolute position embeddings to the content embeddings. | |||
| pos_att_type (`List[str]`, *optional*): | |||
| The type of relative position attention. It can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`, | |||
| `["c2p"]` or `["p2c", "c2p"]`. | |||
| """ | |||
| model_type = 'deberta_v2' | |||
| def __init__(self, | |||
| vocab_size=128100, | |||
| hidden_size=1536, | |||
| num_hidden_layers=24, | |||
| num_attention_heads=24, | |||
| intermediate_size=6144, | |||
| hidden_act='gelu', | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=0, | |||
| initializer_range=0.02, | |||
| layer_norm_eps=1e-7, | |||
| relative_attention=False, | |||
| max_relative_positions=-1, | |||
| pad_token_id=0, | |||
| position_biased_input=True, | |||
| pos_att_type=None, | |||
| pooler_dropout=0, | |||
| pooler_hidden_act='gelu', | |||
| **kwargs): | |||
| super().__init__(**kwargs) | |||
| self.hidden_size = hidden_size | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_act = hidden_act | |||
| self.hidden_dropout_prob = hidden_dropout_prob | |||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.type_vocab_size = type_vocab_size | |||
| self.initializer_range = initializer_range | |||
| self.relative_attention = relative_attention | |||
| self.max_relative_positions = max_relative_positions | |||
| self.pad_token_id = pad_token_id | |||
| self.position_biased_input = position_biased_input | |||
| # Backwards compatibility | |||
| if isinstance(pos_att_type, str): | |||
| pos_att_type = [x.strip() for x in pos_att_type.lower().split('|')] | |||
| self.pos_att_type = pos_att_type | |||
| self.vocab_size = vocab_size | |||
| self.layer_norm_eps = layer_norm_eps | |||
| self.pooler_hidden_size = kwargs.get('pooler_hidden_size', hidden_size) | |||
| self.pooler_dropout = pooler_dropout | |||
| self.pooler_hidden_act = pooler_hidden_act | |||
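For reference, a quick exercise of the configuration class above, including the backwards-compatibility branch that normalizes a `'p2c|c2p'` string into a list. This is a hypothetical sketch: the import path follows the lazy `__init__.py` added in this diff, and the values shown are simply the class defaults, not a released checkpoint.

    from modelscope.models.nlp.deberta_v2 import DebertaV2Config

    config = DebertaV2Config(
        vocab_size=128100,
        hidden_size=1536,
        relative_attention=True,
        pos_att_type='p2c|c2p',   # legacy string form
    )
    print(config.pos_att_type)        # ['p2c', 'c2p'] after normalization in __init__
    print(config.pooler_hidden_size)  # falls back to hidden_size (1536) when not given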
| @@ -0,0 +1,546 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright 2020 Microsoft and the HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for DeBERTa. mainly copied from :module:`~transformers.tokenization_deberta`""" | |||
| import os | |||
| import unicodedata | |||
| from typing import Any, Dict, List, Optional, Tuple | |||
| import sentencepiece as sp | |||
| from transformers.tokenization_utils import PreTrainedTokenizer | |||
| PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} | |||
| PRETRAINED_INIT_CONFIGURATION = {} | |||
| VOCAB_FILES_NAMES = {'vocab_file': 'spm.model'} | |||
| class DebertaV2Tokenizer(PreTrainedTokenizer): | |||
| r""" | |||
| Constructs a DeBERTa-v2 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) | |||
| and [jieba](https://github.com/fxsjy/jieba). | |||
| Args: | |||
| vocab_file (`str`): | |||
| [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that | |||
| contains the vocabulary necessary to instantiate a tokenizer. | |||
| do_lower_case (`bool`, *optional*, defaults to `False`): | |||
| Whether or not to lowercase the input when tokenizing. | |||
| bos_token (`string`, *optional*, defaults to `"[CLS]"`): | |||
| The beginning of sequence token that was used during pre-training. Can be used as a sequence classifier token. | |||
| When building a sequence using special tokens, this is not the token that is used for the beginning of | |||
| sequence. The token used is the `cls_token`. | |||
| eos_token (`string`, *optional*, defaults to `"[SEP]"`): | |||
| The end of sequence token. When building a sequence using special tokens, this is not the token that is | |||
| used for the end of sequence. The token used is the `sep_token`. | |||
| unk_token (`str`, *optional*, defaults to `"[UNK]"`): | |||
| The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this | |||
| token instead. | |||
| sep_token (`str`, *optional*, defaults to `"[SEP]"`): | |||
| The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for | |||
| sequence classification or for a text and a question for question answering. It is also used as the last | |||
| token of a sequence built with special tokens. | |||
| pad_token (`str`, *optional*, defaults to `"[PAD]"`): | |||
| The token used for padding, for example when batching sequences of different lengths. | |||
| cls_token (`str`, *optional*, defaults to `"[CLS]"`): | |||
| The classifier token which is used when doing sequence classification (classification of the whole sequence | |||
| instead of per-token classification). It is the first token of the sequence when built with special tokens. | |||
| mask_token (`str`, *optional*, defaults to `"[MASK]"`): | |||
| The token used for masking values. This is the token used when training this model with masked language | |||
| modeling. This is the token which the model will try to predict. | |||
| sp_model_kwargs (`dict`, *optional*): | |||
| Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for | |||
| SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, | |||
| to set: | |||
| - `enable_sampling`: Enable subword regularization. | |||
| - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. | |||
| - `nbest_size = {0,1}`: No sampling is performed. | |||
| - `nbest_size > 1`: samples from the nbest_size results. | |||
| - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) | |||
| using forward-filtering-and-backward-sampling algorithm. | |||
| - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for | |||
| BPE-dropout. | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION | |||
| max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
| def __init__(self, | |||
| vocab_file, | |||
| do_lower_case=False, | |||
| split_by_punct=False, | |||
| split_chinese=True, | |||
| bos_token='[CLS]', | |||
| eos_token='[SEP]', | |||
| unk_token='[UNK]', | |||
| sep_token='[SEP]', | |||
| pad_token='[PAD]', | |||
| cls_token='[CLS]', | |||
| mask_token='[MASK]', | |||
| sp_model_kwargs: Optional[Dict[str, Any]] = None, | |||
| **kwargs) -> None: | |||
| self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs | |||
| super().__init__( | |||
| do_lower_case=do_lower_case, | |||
| bos_token=bos_token, | |||
| eos_token=eos_token, | |||
| unk_token=unk_token, | |||
| sep_token=sep_token, | |||
| pad_token=pad_token, | |||
| cls_token=cls_token, | |||
| mask_token=mask_token, | |||
| split_by_punct=split_by_punct, | |||
| split_chinese=split_chinese, | |||
| sp_model_kwargs=self.sp_model_kwargs, | |||
| **kwargs, | |||
| ) | |||
| if not os.path.isfile(vocab_file): | |||
| raise ValueError( | |||
| f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" | |||
| ' model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' | |||
| ) | |||
| self.do_lower_case = do_lower_case | |||
| self.split_by_punct = split_by_punct | |||
| self.split_chinese = split_chinese | |||
| self.vocab_file = vocab_file | |||
| self._tokenizer = SPMTokenizer( | |||
| vocab_file, | |||
| split_by_punct=split_by_punct, | |||
| sp_model_kwargs=self.sp_model_kwargs) | |||
| self.jieba = None | |||
| if self.split_chinese: | |||
| try: | |||
| import jieba | |||
| except ImportError: | |||
| raise ImportError( | |||
| 'You need to install jieba to split Chinese text for DebertaV2Tokenizer. ' | |||
| 'See https://pypi.org/project/jieba/ for installation.') | |||
| self.jieba = jieba | |||
| @property | |||
| def vocab_size(self): | |||
| return len(self.vocab) | |||
| @property | |||
| def vocab(self): | |||
| return self._tokenizer.vocab | |||
| def get_vocab(self): | |||
| vocab = self.vocab.copy() | |||
| vocab.update(self.get_added_vocab()) | |||
| return vocab | |||
| def _tokenize(self, text: str) -> List[str]: | |||
| """Take as input a string and return a list of strings (tokens) for words/sub-words""" | |||
| if self.do_lower_case: | |||
| text = text.lower() | |||
| if self.split_chinese: | |||
| seg_list = list(self.jieba.cut(text)) | |||
| text = ' '.join(seg_list) | |||
| return self._tokenizer.tokenize(text) | |||
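When `split_chinese` is enabled, `_tokenize` above lowercases the text (optionally), pre-segments it with jieba, and re-joins the segments with spaces before handing the string to SentencePiece. A small illustration of just the jieba step (requires `pip install jieba`; the sentence is the one used in the test file at the end of this diff, and the exact segmentation depends on the jieba version and dictionary):

    import jieba  # optional dependency; __init__ above raises ImportError if it is missing

    text = '你师父可差不动我。'
    seg_list = list(jieba.cut(text))
    print(' '.join(seg_list))  # e.g. '你 师父 可 差 不动 我 。'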
| def _convert_token_to_id(self, token): | |||
| """Converts a token (str) in an id using the vocab.""" | |||
| return self._tokenizer.spm.PieceToId(token) | |||
| def _convert_id_to_token(self, index): | |||
| """Converts an index (integer) in a token (str) using the vocab.""" | |||
| return self._tokenizer.spm.IdToPiece( | |||
| index) if index < self.vocab_size else self.unk_token | |||
| def convert_tokens_to_string(self, tokens): | |||
| """Converts a sequence of tokens (string) in a single string.""" | |||
| return self._tokenizer.decode(tokens) | |||
| def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): | |||
| """ | |||
| Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and | |||
| adding special tokens. A DeBERTa sequence has the following format: | |||
| - single sequence: [CLS] X [SEP] | |||
| - pair of sequences: [CLS] A [SEP] B [SEP] | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs to which the special tokens will be added. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. | |||
| """ | |||
| if token_ids_1 is None: | |||
| return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| sep = [self.sep_token_id] | |||
| return cls + token_ids_0 + sep + token_ids_1 + sep | |||
| def get_special_tokens_mask(self, | |||
| token_ids_0, | |||
| token_ids_1=None, | |||
| already_has_special_tokens=False): | |||
| """ | |||
| Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding | |||
| special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| already_has_special_tokens (`bool`, *optional*, defaults to `False`): | |||
| Whether or not the token list is already formatted with special tokens for the model. | |||
| Returns: | |||
| `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. | |||
| """ | |||
| if already_has_special_tokens: | |||
| return super().get_special_tokens_mask( | |||
| token_ids_0=token_ids_0, | |||
| token_ids_1=token_ids_1, | |||
| already_has_special_tokens=True) | |||
| if token_ids_1 is not None: | |||
| return [1] + ([0] * len(token_ids_0)) + [1] + ( | |||
| [0] * len(token_ids_1)) + [1] | |||
| return [1] + ([0] * len(token_ids_0)) + [1] | |||
| def create_token_type_ids_from_sequences(self, | |||
| token_ids_0, | |||
| token_ids_1=None): | |||
| """ | |||
| Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa | |||
| sequence pair mask has the following format: | |||
| ``` | |||
| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | |||
| | first sequence | second sequence | | |||
| ``` | |||
| If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). | |||
| """ | |||
| sep = [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| if token_ids_1 is None: | |||
| return len(cls + token_ids_0 + sep) * [0] | |||
| return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 | |||
| + sep) * [1] | |||
| def prepare_for_tokenization(self, | |||
| text, | |||
| is_split_into_words=False, | |||
| **kwargs): | |||
| add_prefix_space = kwargs.pop('add_prefix_space', False) | |||
| if is_split_into_words or add_prefix_space: | |||
| text = ' ' + text | |||
| return (text, kwargs) | |||
| def save_vocabulary(self, | |||
| save_directory: str, | |||
| filename_prefix: Optional[str] = None) -> Tuple[str]: | |||
| return self._tokenizer.save_pretrained( | |||
| save_directory, filename_prefix=filename_prefix) | |||
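To make the `[CLS] A [SEP] B [SEP]` layout and the two masks produced by `build_inputs_with_special_tokens`, `get_special_tokens_mask`, and `create_token_type_ids_from_sequences` concrete, here is a worked example with made-up token ids (no SentencePiece model is needed for this illustration):

    CLS, SEP = 1, 2                 # illustrative cls_token_id / sep_token_id
    token_ids_0 = [11, 12, 13]      # pieces of sentence A
    token_ids_1 = [21, 22]          # pieces of sentence B

    input_ids = [CLS] + token_ids_0 + [SEP] + token_ids_1 + [SEP]
    special_tokens_mask = ([1] + [0] * len(token_ids_0) + [1]
                           + [0] * len(token_ids_1) + [1])
    token_type_ids = ([0] * (1 + len(token_ids_0) + 1)
                      + [1] * (len(token_ids_1) + 1))

    print(input_ids)            # [1, 11, 12, 13, 2, 21, 22, 2]
    print(special_tokens_mask)  # [1, 0, 0, 0, 1, 0, 0, 1]
    print(token_type_ids)       # [0, 0, 0, 0, 0, 1, 1, 1]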
| class SPMTokenizer: | |||
| r""" | |||
| Constructs a tokenizer based on [SentencePiece](https://github.com/google/sentencepiece). | |||
| Args: | |||
| vocab_file (`str`): | |||
| [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that | |||
| contains the vocabulary necessary to instantiate a tokenizer. | |||
| sp_model_kwargs (`dict`, *optional*): | |||
| Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for | |||
| SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, | |||
| to set: | |||
| - `enable_sampling`: Enable subword regularization. | |||
| - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. | |||
| - `nbest_size = {0,1}`: No sampling is performed. | |||
| - `nbest_size > 1`: samples from the nbest_size results. | |||
| - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) | |||
| using forward-filtering-and-backward-sampling algorithm. | |||
| - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for | |||
| BPE-dropout. | |||
| """ | |||
| def __init__(self, | |||
| vocab_file, | |||
| split_by_punct=False, | |||
| sp_model_kwargs: Optional[Dict[str, Any]] = None): | |||
| self.split_by_punct = split_by_punct | |||
| self.vocab_file = vocab_file | |||
| self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs | |||
| spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) | |||
| if not os.path.exists(vocab_file): | |||
| raise FileNotFoundError(f'{vocab_file} does not exist!') | |||
| spm.load(vocab_file) | |||
| bpe_vocab_size = spm.GetPieceSize() | |||
| # Token map | |||
| # <unk> 0+1 | |||
| # <s> 1+1 | |||
| # </s> 2+1 | |||
| self.vocab = {spm.IdToPiece(i): i for i in range(bpe_vocab_size)} | |||
| self.ids_to_tokens = [spm.IdToPiece(i) for i in range(bpe_vocab_size)] | |||
| # self.vocab['[PAD]'] = 0 | |||
| # self.vocab['[CLS]'] = 1 | |||
| # self.vocab['[SEP]'] = 2 | |||
| # self.vocab['[UNK]'] = 3 | |||
| self.spm = spm | |||
| self.special_tokens = []  # populated by add_special_token(); checked in part_of_whole_word() | |||
| def __getstate__(self): | |||
| state = self.__dict__.copy() | |||
| state['spm'] = None | |||
| return state | |||
| def __setstate__(self, d): | |||
| self.__dict__ = d | |||
| # for backward compatibility | |||
| if not hasattr(self, 'sp_model_kwargs'): | |||
| self.sp_model_kwargs = {} | |||
| self.spm = sp.SentencePieceProcessor(**self.sp_model_kwargs) | |||
| self.spm.Load(self.vocab_file) | |||
| def tokenize(self, text): | |||
| return self._encode_as_pieces(text) | |||
| def convert_ids_to_tokens(self, ids): | |||
| tokens = [] | |||
| for i in ids: | |||
| tokens.append(self.ids_to_tokens[i]) | |||
| return tokens | |||
| def decode(self, tokens, start=-1, end=-1, raw_text=None): | |||
| if raw_text is None: | |||
| return self.spm.decode_pieces([t for t in tokens]) | |||
| else: | |||
| words = self.split_to_words(raw_text) | |||
| word_tokens = [self.tokenize(w) for w in words] | |||
| token2words = [0] * len(tokens) | |||
| tid = 0 | |||
| for i, w in enumerate(word_tokens): | |||
| for k, t in enumerate(w): | |||
| token2words[tid] = i | |||
| tid += 1 | |||
| word_start = token2words[start] | |||
| word_end = token2words[end] if end < len(tokens) else len(words) | |||
| text = ''.join(words[word_start:word_end]) | |||
| return text | |||
| def add_special_token(self, token): | |||
| if token not in self.special_tokens: | |||
| self.special_tokens.append(token) | |||
| if token not in self.vocab: | |||
| self.vocab[token] = len(self.vocab) - 1 | |||
| self.ids_to_tokens.append(token) | |||
| return self.id(token) | |||
| def part_of_whole_word(self, token, is_bos=False): | |||
| if is_bos: | |||
| return True | |||
| if (len(token) == 1 and (_is_whitespace(list(token)[0]))): | |||
| return False | |||
| if _is_control(list(token)[0]): | |||
| return False | |||
| if _is_punctuation(list(token)[0]): | |||
| return False | |||
| if token in self.special_tokens: | |||
| return False | |||
| word_start = b'\xe2\x96\x81'.decode('utf-8') | |||
| return not token.startswith(word_start) | |||
| def pad(self): | |||
| return '[PAD]' | |||
| def bos(self): | |||
| return '[CLS]' | |||
| def eos(self): | |||
| return '[SEP]' | |||
| def unk(self): | |||
| return '[UNK]' | |||
| def mask(self): | |||
| return '[MASK]' | |||
| def sym(self, id): | |||
| return self.ids_to_tokens[id] | |||
| def id(self, sym): | |||
| return self.vocab[sym] if sym in self.vocab else 1 | |||
| def _encode_as_pieces(self, text): | |||
| text = convert_to_unicode(text) | |||
| if self.split_by_punct: | |||
| words = self._run_split_on_punc(text) | |||
| pieces = [self.spm.encode(w, out_type=str) for w in words] | |||
| return [p for w in pieces for p in w] | |||
| else: | |||
| return self.spm.encode(text, out_type=str) | |||
| def split_to_words(self, text): | |||
| pieces = self._encode_as_pieces(text) | |||
| word_start = b'\xe2\x96\x81'.decode('utf-8') | |||
| words = [] | |||
| offset = 0 | |||
| prev_end = 0 | |||
| for i, p in enumerate(pieces): | |||
| if p.startswith(word_start): | |||
| if offset > prev_end: | |||
| words.append(text[prev_end:offset]) | |||
| prev_end = offset | |||
| w = p.replace(word_start, '') | |||
| else: | |||
| w = p | |||
| try: | |||
| s = text.index(w, offset) | |||
| pn = '' | |||
| k = i + 1 | |||
| while k < len(pieces): | |||
| pn = pieces[k].replace(word_start, '') | |||
| if len(pn) > 0: | |||
| break | |||
| k += 1 | |||
| if len(pn) > 0 and pn in text[offset:s]: | |||
| offset = offset + 1 | |||
| else: | |||
| offset = s + len(w) | |||
| except Exception: | |||
| offset = offset + 1 | |||
| if prev_end < offset: | |||
| words.append(text[prev_end:offset]) | |||
| return words | |||
| def _run_strip_accents(self, text): | |||
| """Strips accents from a piece of text.""" | |||
| text = unicodedata.normalize('NFD', text) | |||
| output = [] | |||
| for char in text: | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Mn': | |||
| continue | |||
| output.append(char) | |||
| return ''.join(output) | |||
| def _run_split_on_punc(self, text): | |||
| """Splits punctuation on a piece of text.""" | |||
| chars = list(text) | |||
| i = 0 | |||
| start_new_word = True | |||
| output = [] | |||
| while i < len(chars): | |||
| char = chars[i] | |||
| if _is_punctuation(char): | |||
| output.append([char]) | |||
| start_new_word = True | |||
| else: | |||
| if start_new_word: | |||
| output.append([]) | |||
| start_new_word = False | |||
| output[-1].append(char) | |||
| i += 1 | |||
| return [''.join(x) for x in output] | |||
| def save_pretrained(self, path: str, filename_prefix: str = None): | |||
| filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]] | |||
| if filename_prefix is not None: | |||
| filename = filename_prefix + '-' + filename | |||
| full_path = os.path.join(path, filename) | |||
| with open(full_path, 'wb') as fs: | |||
| fs.write(self.spm.serialized_model_proto()) | |||
| return (full_path, ) | |||
| def _is_whitespace(char): | |||
| """Checks whether `chars` is a whitespace character.""" | |||
| # \t, \n, and \r are technically control characters but we treat them | |||
| # as whitespace since they are generally considered as such. | |||
| if char == ' ' or char == '\t' or char == '\n' or char == '\r': | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Zs': | |||
| return True | |||
| return False | |||
| def _is_control(char): | |||
| """Checks whether `chars` is a control character.""" | |||
| # These are technically control characters but we count them as whitespace | |||
| # characters. | |||
| if char == '\t' or char == '\n' or char == '\r': | |||
| return False | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('C'): | |||
| return True | |||
| return False | |||
| def _is_punctuation(char): | |||
| """Checks whether `chars` is a punctuation character.""" | |||
| cp = ord(char) | |||
| # We treat all non-letter/number ASCII as punctuation. | |||
| # Characters such as "^", "$", and "`" are not in the Unicode | |||
| # Punctuation class but we treat them as punctuation anyways, for | |||
| # consistency. | |||
| if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or ( | |||
| cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('P'): | |||
| return True | |||
| return False | |||
| def convert_to_unicode(text): | |||
| """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" | |||
| if isinstance(text, str): | |||
| return text | |||
| elif isinstance(text, bytes): | |||
| return text.decode('utf-8', 'ignore') | |||
| else: | |||
| raise ValueError(f'Unsupported string type: {type(text)}') | |||
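The module-level helpers at the end of the file drive the optional punctuation splitting; note that `b'\xe2\x96\x81'` used above decodes to `▁`, SentencePiece's word-start marker. A quick check of the three character predicates, assuming the module is importable at the path implied by this diff:

    from modelscope.models.nlp.deberta_v2.tokenization_deberta_v2 import (
        _is_control, _is_punctuation, _is_whitespace)

    for ch in [' ', '\t', 'a', '^', '。']:
        print(repr(ch), _is_whitespace(ch), _is_control(ch), _is_punctuation(ch))
    # ' '  -> whitespace (Unicode category 'Zs')
    # '\t' -> whitespace, explicitly not control, per the comments above
    # 'a'  -> none of the three
    # '^'  -> punctuation (ASCII symbol forced into the punctuation class)
    # '。' -> punctuation (Unicode category 'Po')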
| @@ -0,0 +1,241 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
| # Copyright 2020 Microsoft and the HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Fast Tokenization class for model DeBERTa.""" | |||
| import os | |||
| from shutil import copyfile | |||
| from typing import Optional, Tuple | |||
| from transformers.file_utils import is_sentencepiece_available | |||
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast | |||
| from modelscope.utils import logger as logging | |||
| if is_sentencepiece_available(): | |||
| from .tokenization_deberta_v2 import DebertaV2Tokenizer | |||
| else: | |||
| DebertaV2Tokenizer = None | |||
| logger = logging.get_logger(__name__) | |||
| VOCAB_FILES_NAMES = { | |||
| 'vocab_file': 'spm.model', | |||
| 'tokenizer_file': 'tokenizer.json' | |||
| } | |||
| PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} | |||
| PRETRAINED_INIT_CONFIGURATION = {} | |||
| class DebertaV2TokenizerFast(PreTrainedTokenizerFast): | |||
| r""" | |||
| Constructs a DeBERTa-v2 fast tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece) | |||
| and [rjieba-py](https://github.com/messense/rjieba-py). | |||
| Args: | |||
| vocab_file (`str`): | |||
| [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that | |||
| contains the vocabulary necessary to instantiate a tokenizer. | |||
| do_lower_case (`bool`, *optional*, defaults to `False`): | |||
| Whether or not to lowercase the input when tokenizing. | |||
| bos_token (`string`, *optional*, defaults to `"[CLS]"`): | |||
| The beginning of sequence token that was used during pre-training. Can be used as a sequence classifier token. | |||
| When building a sequence using special tokens, this is not the token that is used for the beginning of | |||
| sequence. The token used is the `cls_token`. | |||
| eos_token (`string`, *optional*, defaults to `"[SEP]"`): | |||
| The end of sequence token. When building a sequence using special tokens, this is not the token that is | |||
| used for the end of sequence. The token used is the `sep_token`. | |||
| unk_token (`str`, *optional*, defaults to `"[UNK]"`): | |||
| The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this | |||
| token instead. | |||
| sep_token (`str`, *optional*, defaults to `"[SEP]"`): | |||
| The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for | |||
| sequence classification or for a text and a question for question answering. It is also used as the last | |||
| token of a sequence built with special tokens. | |||
| pad_token (`str`, *optional*, defaults to `"[PAD]"`): | |||
| The token used for padding, for example when batching sequences of different lengths. | |||
| cls_token (`str`, *optional*, defaults to `"[CLS]"`): | |||
| The classifier token which is used when doing sequence classification (classification of the whole sequence | |||
| instead of per-token classification). It is the first token of the sequence when built with special tokens. | |||
| mask_token (`str`, *optional*, defaults to `"[MASK]"`): | |||
| The token used for masking values. This is the token used when training this model with masked language | |||
| modeling. This is the token which the model will try to predict. | |||
| sp_model_kwargs (`dict`, *optional*): | |||
| Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for | |||
| SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, | |||
| to set: | |||
| - `enable_sampling`: Enable subword regularization. | |||
| - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. | |||
| - `nbest_size = {0,1}`: No sampling is performed. | |||
| - `nbest_size > 1`: samples from the nbest_size results. | |||
| - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) | |||
| using forward-filtering-and-backward-sampling algorithm. | |||
| - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for | |||
| BPE-dropout. | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION | |||
| max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
| slow_tokenizer_class = DebertaV2Tokenizer | |||
| def __init__(self, | |||
| vocab_file=None, | |||
| tokenizer_file=None, | |||
| do_lower_case=False, | |||
| split_by_punct=False, | |||
| split_chinese=True, | |||
| bos_token='[CLS]', | |||
| eos_token='[SEP]', | |||
| unk_token='[UNK]', | |||
| sep_token='[SEP]', | |||
| pad_token='[PAD]', | |||
| cls_token='[CLS]', | |||
| mask_token='[MASK]', | |||
| **kwargs) -> None: | |||
| super().__init__( | |||
| vocab_file, | |||
| tokenizer_file=tokenizer_file, | |||
| do_lower_case=do_lower_case, | |||
| bos_token=bos_token, | |||
| eos_token=eos_token, | |||
| unk_token=unk_token, | |||
| sep_token=sep_token, | |||
| pad_token=pad_token, | |||
| cls_token=cls_token, | |||
| mask_token=mask_token, | |||
| split_by_punct=split_by_punct, | |||
| split_chinese=split_chinese, | |||
| **kwargs, | |||
| ) | |||
| self.do_lower_case = do_lower_case | |||
| self.split_by_punct = split_by_punct | |||
| self.split_chinese = split_chinese | |||
| self.vocab_file = vocab_file | |||
| self.can_save_slow_tokenizer = bool(self.vocab_file) | |||
| def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): | |||
| """ | |||
| Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and | |||
| adding special tokens. A DeBERTa sequence has the following format: | |||
| - single sequence: [CLS] X [SEP] | |||
| - pair of sequences: [CLS] A [SEP] B [SEP] | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs to which the special tokens will be added. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. | |||
| """ | |||
| if token_ids_1 is None: | |||
| return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| sep = [self.sep_token_id] | |||
| return cls + token_ids_0 + sep + token_ids_1 + sep | |||
| def get_special_tokens_mask(self, | |||
| token_ids_0, | |||
| token_ids_1=None, | |||
| already_has_special_tokens=False): | |||
| """ | |||
| Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding | |||
| special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| already_has_special_tokens (`bool`, *optional*, defaults to `False`): | |||
| Whether or not the token list is already formatted with special tokens for the model. | |||
| Returns: | |||
| `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. | |||
| """ | |||
| if already_has_special_tokens: | |||
| return super().get_special_tokens_mask( | |||
| token_ids_0=token_ids_0, | |||
| token_ids_1=token_ids_1, | |||
| already_has_special_tokens=True) | |||
| if token_ids_1 is not None: | |||
| return [1] + ([0] * len(token_ids_0)) + [1] + ( | |||
| [0] * len(token_ids_1)) + [1] | |||
| return [1] + ([0] * len(token_ids_0)) + [1] | |||
| def create_token_type_ids_from_sequences(self, | |||
| token_ids_0, | |||
| token_ids_1=None): | |||
| """ | |||
| Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa | |||
| sequence pair mask has the following format: | |||
| ``` | |||
| 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | |||
| | first sequence | second sequence | | |||
| ``` | |||
| If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). | |||
| Args: | |||
| token_ids_0 (`List[int]`): | |||
| List of IDs. | |||
| token_ids_1 (`List[int]`, *optional*): | |||
| Optional second list of IDs for sequence pairs. | |||
| Returns: | |||
| `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). | |||
| """ | |||
| sep = [self.sep_token_id] | |||
| cls = [self.cls_token_id] | |||
| if token_ids_1 is None: | |||
| return len(cls + token_ids_0 + sep) * [0] | |||
| return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 | |||
| + sep) * [1] | |||
| def save_vocabulary(self, | |||
| save_directory: str, | |||
| filename_prefix: Optional[str] = None) -> Tuple[str]: | |||
| if not self.can_save_slow_tokenizer: | |||
| raise ValueError( | |||
| 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' | |||
| 'tokenizer.') | |||
| if not os.path.isdir(save_directory): | |||
| logger.error( | |||
| f'Vocabulary path ({save_directory}) should be a directory') | |||
| return | |||
| out_vocab_file = os.path.join( | |||
| save_directory, (filename_prefix + '-' if filename_prefix else '') | |||
| + VOCAB_FILES_NAMES['vocab_file']) | |||
| if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): | |||
| copyfile(self.vocab_file, out_vocab_file) | |||
| return (out_vocab_file, ) | |||
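Because `save_vocabulary` only copies the original `spm.model`, the fast tokenizer can round-trip back to a slow tokenizer only when it was constructed with a `vocab_file` (see `can_save_slow_tokenizer` above). A hedged usage sketch with placeholder paths; the file names follow `VOCAB_FILES_NAMES`, the directories are hypothetical:

    from modelscope.models.nlp.deberta_v2 import DebertaV2TokenizerFast

    tok = DebertaV2TokenizerFast(
        vocab_file='/path/to/model_dir/spm.model',          # enables slow-tokenizer export
        tokenizer_file='/path/to/model_dir/tokenizer.json')
    assert tok.can_save_slow_tokenizer  # would be False if vocab_file were omitted

    # Copies spm.model into the target directory (optionally with a filename prefix).
    tok.save_vocabulary('/path/to/export_dir')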
| @@ -6,6 +6,8 @@ from transformers import BertForMaskedLM as BertForMaskedLMTransformer | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.deberta_v2 import \ | |||
| DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer | |||
| from modelscope.models.nlp.structbert import SbertForMaskedLM | |||
| from modelscope.models.nlp.veco import \ | |||
| VecoForMaskedLM as VecoForMaskedLMTransformer | |||
| @@ -125,3 +127,40 @@ class VecoForMaskedLM(TorchModel, VecoForMaskedLMTransformer): | |||
| VecoForMaskedLM).from_pretrained( | |||
| pretrained_model_name_or_path=model_dir, | |||
| model_dir=model_dir) | |||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2) | |||
| class DebertaV2ForMaskedLM(TorchModel, DebertaV2ForMaskedLMTransformer): | |||
| """Deberta v2 for MLM model. | |||
| Inherited from deberta_v2.DebertaV2ForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
| """ | |||
| def __init__(self, config, model_dir): | |||
| super(TorchModel, self).__init__(model_dir) | |||
| DebertaV2ForMaskedLMTransformer.__init__(self, config) | |||
| def forward(self, | |||
| input_ids=None, | |||
| attention_mask=None, | |||
| token_type_ids=None, | |||
| position_ids=None, | |||
| head_mask=None, | |||
| labels=None): | |||
| output = DebertaV2ForMaskedLMTransformer.forward( | |||
| self, | |||
| input_ids=input_ids, | |||
| attention_mask=attention_mask, | |||
| token_type_ids=token_type_ids, | |||
| position_ids=position_ids, | |||
| labels=labels) | |||
| output[OutputKeys.INPUT_IDS] = input_ids | |||
| return output | |||
| @classmethod | |||
| def _instantiate(cls, **kwargs): | |||
| model_dir = kwargs.get('model_dir') | |||
| return super(DebertaV2ForMaskedLMTransformer, | |||
| DebertaV2ForMaskedLM).from_pretrained( | |||
| pretrained_model_name_or_path=model_dir, | |||
| model_dir=model_dir) | |||
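The `@MODELS.register_module(Tasks.fill_mask, module_name=Models.deberta_v2)` decorator is what lets the generic loader resolve this wrapper from a model directory whose configuration declares the `deberta_v2` model type for the fill-mask task; the wrapper's `forward` additionally copies `input_ids` into the output dict (`OutputKeys.INPUT_IDS`) so the pipeline can locate the `[MASK]` positions when decoding logits. A short sketch mirroring the tests at the end of this diff (the model id is taken from those tests):

    from modelscope.models import Model

    model = Model.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-lite')
    print(type(model).__name__)  # expected to resolve to DebertaV2ForMaskedLM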
| @@ -13,7 +13,10 @@ from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| __all__ = ['FillMaskPipeline'] | |||
| _type_map = {'veco': 'roberta', 'sbert': 'bert'} | |||
| _type_map = { | |||
| 'veco': 'roberta', | |||
| 'sbert': 'bert', | |||
| } | |||
| @PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) | |||
| @@ -65,7 +68,7 @@ class FillMaskPipeline(Pipeline): | |||
| self.config = Config.from_file( | |||
| os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.mask_id = {'roberta': 250001, 'bert': 103} | |||
| self.mask_id = {'roberta': 250001, 'bert': 103, 'deberta_v2': 4} | |||
| self.rep_map = { | |||
| 'bert': { | |||
| @@ -85,7 +88,14 @@ class FillMaskPipeline(Pipeline): | |||
| '<s>': '', | |||
| '</s>': '', | |||
| '<unk>': ' ' | |||
| } | |||
| }, | |||
| 'deberta_v2': { | |||
| '[PAD]': '', | |||
| r' +': ' ', | |||
| '[SEP]': '', | |||
| '[CLS]': '', | |||
| '[UNK]': '' | |||
| }, | |||
| } | |||
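The new `deberta_v2` entry in `rep_map` mixes literal special tokens with the regex `r' +'`, which suggests these are applied as substitutions when decoded ids are turned back into text. A standalone sketch of that clean-up step (this is not the pipeline's own postprocess code, which lies outside this diff; the decoded string is made up):

    import re

    rep_map = {'[PAD]': '', r' +': ' ', '[SEP]': '', '[CLS]': '', '[UNK]': ''}

    decoded = '[CLS] 你 师父 可 差 不 动 我 。 [SEP] [PAD]'
    for pattern, replacement in rep_map.items():
        # Bracketed special tokens are treated literally; ' +' is a real regex.
        regex = re.escape(pattern) if pattern.startswith('[') else pattern
        decoded = re.sub(regex, replacement, decoded)
    print(decoded.strip())  # -> '你 师父 可 差 不 动 我 。'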
| def forward(self, inputs: Dict[str, Any], | |||
| @@ -170,6 +170,9 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| elif model_type == Models.veco: | |||
| from modelscope.models.nlp.veco import VecoTokenizer | |||
| return VecoTokenizer.from_pretrained(model_dir) | |||
| elif model_type == Models.deberta_v2: | |||
| from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer | |||
| return DebertaV2Tokenizer.from_pretrained(model_dir) | |||
| else: | |||
| return AutoTokenizer.from_pretrained(model_dir, use_fast=False) | |||
| @@ -0,0 +1,62 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import torch | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import DebertaV2ForMaskedLM | |||
| from modelscope.models.nlp.deberta_v2 import (DebertaV2Tokenizer, | |||
| DebertaV2TokenizerFast) | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import FillMaskPipeline | |||
| from modelscope.preprocessors import FillMaskPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class DeBERTaV2TaskTest(unittest.TestCase): | |||
| model_id_deberta = 'damo/nlp_debertav2_fill-mask_chinese-lite' | |||
| ori_text = '你师父差得动你,你师父可差不动我。' | |||
| test_input = '你师父差得动你,你师父可[MASK]不动我。' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| model_dir = snapshot_download(self.model_id_deberta) | |||
| preprocessor = FillMaskPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = DebertaV2ForMaskedLM.from_pretrained(model_dir) | |||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| pipeline2 = pipeline( | |||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| ori_text = self.ori_text | |||
| test_input = self.test_input | |||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' | |||
| f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| # deberta_v2 | |||
| print(self.model_id_deberta) | |||
| model = Model.from_pretrained(self.model_id_deberta) | |||
| preprocessor = FillMaskPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| print( | |||
| f'\nori_text: {self.ori_text}\ninput: {self.test_input}\npipeline: ' | |||
| f'{pipeline_ins(self.test_input)}\n') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, model=self.model_id_deberta) | |||
| ori_text = self.ori_text | |||
| test_input = self.test_input | |||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||
| f'{pipeline_ins(test_input)}\n') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||