Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10019083 (branch: master)
| @@ -62,6 +62,7 @@ class Models(object): | |||
| gpt3 = 'gpt3' | |||
| plug = 'plug' | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| ponet = 'ponet' | |||
| # audio models | |||
| sambert_hifigan = 'sambert-hifigan' | |||
| @@ -179,6 +180,7 @@ class Pipelines(object): | |||
| sentiment_classification = 'sentiment-classification' | |||
| text_classification = 'text-classification' | |||
| fill_mask = 'fill-mask' | |||
| fill_mask_ponet = 'fill-mask-ponet' | |||
| csanmt_translation = 'csanmt-translation' | |||
| nli = 'nli' | |||
| dialog_intent_prediction = 'dialog-intent-prediction' | |||
| @@ -281,6 +283,7 @@ class Preprocessors(object): | |||
| sequence_labeling_tokenizer = 'sequence-labeling-tokenizer' | |||
| word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor' | |||
| fill_mask = 'fill-mask' | |||
| fill_mask_ponet = 'fill-mask-ponet' | |||
| faq_question_answering_preprocessor = 'faq-question-answering-preprocessor' | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| re_tokenizer = 're-tokenizer' | |||
| @@ -13,6 +13,7 @@ if TYPE_CHECKING: | |||
| from .gpt3 import GPT3ForTextGeneration | |||
| from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM, | |||
| BertForMaskedLM, DebertaV2ForMaskedLM) | |||
| from .ponet_for_masked_language import PoNetForMaskedLM | |||
| from .nncrf_for_named_entity_recognition import ( | |||
| TransformerCRFForNamedEntityRecognition, | |||
| LSTMCRFForNamedEntityRecognition) | |||
| @@ -46,6 +47,7 @@ else: | |||
| 'TransformerCRFForNamedEntityRecognition', | |||
| 'LSTMCRFForNamedEntityRecognition' | |||
| ], | |||
| 'ponet_for_masked_language': ['PoNetForMaskedLM'], | |||
| 'palm_v2': ['PalmForTextGeneration'], | |||
| 'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||
| 'star_text_to_sql': ['StarForTextToSql'], | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO Team Authors. | |||
| # All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .configuration_ponet import PoNetConfig | |||
| from .modeling_ponet import (PoNetForMaskedLM, PoNetModel, | |||
| PoNetPreTrainedModel) | |||
| from .tokenization_ponet import PoNetTokenizer | |||
| else: | |||
| _import_structure = { | |||
| 'configuration_ponet': ['PoNetConfig'], | |||
| 'modeling_ponet': | |||
| ['PoNetForMaskedLM', 'PoNetModel', 'PoNetPreTrainedModel'], | |||
| 'tokenization_ponet': ['PoNetTokenizer'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
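A minimal usage sketch (not part of this diff, assuming the package layout above): names registered in `_import_structure` resolve lazily, so importing the `ponet` package does not load the heavy submodules until first attribute access.

import modelscope.models.nlp.ponet as ponet

# Attribute access triggers the real import of the corresponding submodule.
config_cls = ponet.PoNetConfig        # imports configuration_ponet here
tokenizer_cls = ponet.PoNetTokenizer  # imports tokenization_ponet here
print(config_cls.model_type)          # -> 'ponet'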
| @@ -0,0 +1,117 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO Team Authors. | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ PoNet model configuration, mainly copied from :class:`~transformers.BertConfig` """ | |||
| from transformers import PretrainedConfig | |||
| from modelscope.utils import logger as logging | |||
| logger = logging.get_logger(__name__) | |||
| class PoNetConfig(PretrainedConfig): | |||
| r""" | |||
| This is the configuration class to store the configuration | |||
| of a :class:`~modelscope.models.nlp.ponet.PoNetModel`. | |||
| It is used to instantiate a PoNet model according to the specified arguments. | |||
| Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model | |||
| outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. | |||
| Args: | |||
| vocab_size (:obj:`int`, `optional`, defaults to 30522): | |||
| Vocabulary size of the PoNet model. Defines the number of different tokens that can be represented by the | |||
| :obj:`inputs_ids` passed when calling | |||
| :class:`~modelscope.models.nlp.ponet.PoNetModel`. | |||
| hidden_size (:obj:`int`, `optional`, defaults to 768): | |||
| Dimensionality of the encoder layers and the pooler layer. | |||
| num_hidden_layers (:obj:`int`, `optional`, defaults to 12): | |||
| Number of hidden layers in the Transformer encoder. | |||
| num_attention_heads (:obj:`int`, `optional`, defaults to 12): | |||
| Number of attention heads for each attention layer in the Transformer encoder. | |||
| intermediate_size (:obj:`int`, `optional`, defaults to 3072): | |||
| Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. | |||
| hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): | |||
| The non-linear activation function (function or string) in the encoder and pooler. If string, | |||
| :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. | |||
| hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||
| The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. | |||
| attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): | |||
| The dropout ratio for the attention probabilities. | |||
| max_position_embeddings (:obj:`int`, `optional`, defaults to 512): | |||
| The maximum sequence length that this model might ever be used with. Typically set this to something large | |||
| just in case (e.g., 512 or 1024 or 2048). | |||
| type_vocab_size (:obj:`int`, `optional`, defaults to 2): | |||
| The vocabulary size of the :obj:`token_type_ids` passed when calling | |||
| :class:`~modelscope.models.nlp.ponet.PoNetModel`. | |||
| initializer_range (:obj:`float`, `optional`, defaults to 0.02): | |||
| The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | |||
| layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): | |||
| The epsilon used by the layer normalization layers. | |||
| position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): | |||
| Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, | |||
| :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on | |||
| :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.) | |||
| <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to | |||
| `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) | |||
| <https://arxiv.org/abs/2009.13658>`__. | |||
| use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||
| Whether or not the model should return the last key/values attentions (not used by all models). Only | |||
| relevant if ``config.is_decoder=True``. | |||
| classifier_dropout (:obj:`float`, `optional`): | |||
| The dropout ratio for the classification head. | |||
| clsgsepg (:obj:`bool`, `optional`, defaults to :obj:`True`): | |||
| Whether or not to use a trick to make sure that segment and local information will not leak. | |||
| """ | |||
| model_type = 'ponet' | |||
| def __init__(self, | |||
| vocab_size=30522, | |||
| hidden_size=768, | |||
| num_hidden_layers=12, | |||
| num_attention_heads=12, | |||
| intermediate_size=3072, | |||
| hidden_act='gelu', | |||
| hidden_dropout_prob=0.1, | |||
| attention_probs_dropout_prob=0.1, | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| layer_norm_eps=1e-12, | |||
| pad_token_id=0, | |||
| position_embedding_type='absolute', | |||
| use_cache=True, | |||
| classifier_dropout=None, | |||
| clsgsepg=True, | |||
| **kwargs): | |||
| super().__init__(pad_token_id=pad_token_id, **kwargs) | |||
| self.vocab_size = vocab_size | |||
| self.hidden_size = hidden_size | |||
| self.num_hidden_layers = num_hidden_layers | |||
| self.num_attention_heads = num_attention_heads | |||
| self.hidden_act = hidden_act | |||
| self.intermediate_size = intermediate_size | |||
| self.hidden_dropout_prob = hidden_dropout_prob | |||
| self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
| self.max_position_embeddings = max_position_embeddings | |||
| self.type_vocab_size = type_vocab_size | |||
| self.initializer_range = initializer_range | |||
| self.layer_norm_eps = layer_norm_eps | |||
| self.position_embedding_type = position_embedding_type | |||
| self.use_cache = use_cache | |||
| self.classifier_dropout = classifier_dropout | |||
| self.clsgsepg = clsgsepg | |||
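A minimal sketch of instantiating the config (the values below are illustrative, not taken from any released checkpoint):

from modelscope.models.nlp.ponet import PoNetConfig

config = PoNetConfig(
    vocab_size=21128,              # hypothetical vocabulary size
    max_position_embeddings=512,
    clsgsepg=True,                 # keep segment/local information from leaking
)
print(config.model_type)           # 'ponet'
config.save_pretrained('/tmp/ponet_config')   # inherited from transformers.PretrainedConfig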
| @@ -0,0 +1,155 @@ | |||
| # Copyright 2021-2022 The Alibaba DAMO Team Authors. | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for PoNet """ | |||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | |||
| from transformers.file_utils import PaddingStrategy | |||
| from transformers.tokenization_utils_base import BatchEncoding, EncodedInput | |||
| from transformers.models.bert.tokenization_bert import BertTokenizer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger(__name__) | |||
| VOCAB_FILES_NAMES = {'vocab_file': ModelFile.VOCAB_FILE} | |||
| PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}} | |||
| PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { | |||
| 'nlp_ponet_fill-mask_chinese-base': 512, | |||
| 'nlp_ponet_fill-mask_english-base': 512, | |||
| } | |||
| PRETRAINED_INIT_CONFIGURATION = { | |||
| 'nlp_ponet_fill-mask_chinese-base': { | |||
| 'do_lower_case': True | |||
| }, | |||
| 'nlp_ponet_fill-mask_english-base': { | |||
| 'do_lower_case': True | |||
| }, | |||
| } | |||
| class PoNetTokenizer(BertTokenizer): | |||
| r""" | |||
| Construct a PoNet tokenizer, based on :class:`~transformers.BertTokenizer`. | |||
| This tokenizer inherits from :class:`~transformers.BertTokenizer` which contains most of the main methods. | |||
| Users should refer to this superclass for more information regarding those methods. | |||
| Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning | |||
| parameters. | |||
| """ | |||
| vocab_files_names = VOCAB_FILES_NAMES | |||
| pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP | |||
| max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES | |||
| pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION | |||
| def _pad( | |||
| self, | |||
| encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], | |||
| max_length: Optional[int] = None, | |||
| padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, | |||
| pad_to_multiple_of: Optional[int] = None, | |||
| return_attention_mask: Optional[bool] = None, | |||
| ) -> dict: | |||
| """ | |||
| Pad encoded inputs (on left/right and up to predefined length or max length in the batch) | |||
| Args: | |||
| encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or | |||
| batch of tokenized inputs (`List[List[int]]`). | |||
| max_length: maximum length of the returned list and optionally padding length (see below). | |||
| Will truncate by taking into account the special tokens. | |||
| padding_strategy: PaddingStrategy to use for padding. | |||
| - PaddingStrategy.LONGEST Pad to the longest sequence in the batch | |||
| - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) | |||
| - PaddingStrategy.DO_NOT_PAD: Do not pad | |||
| The tokenizer padding sides are defined in self.padding_side: | |||
| - 'left': pads on the left of the sequences | |||
| - 'right': pads on the right of the sequences | |||
| pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. | |||
| This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability | |||
| >= 7.5 (Volta). | |||
| return_attention_mask: (optional) Set to False to avoid returning | |||
| attention mask (default: set to model specifics) | |||
| """ | |||
| # Load from model defaults | |||
| if return_attention_mask is None: | |||
| return_attention_mask = 'attention_mask' in self.model_input_names | |||
| required_input = encoded_inputs[self.model_input_names[0]] | |||
| if padding_strategy == PaddingStrategy.LONGEST: | |||
| max_length = len(required_input) | |||
| if max_length is not None and pad_to_multiple_of is not None and ( | |||
| max_length % pad_to_multiple_of != 0): | |||
| max_length = ( | |||
| (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of | |||
| needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len( | |||
| required_input) != max_length | |||
| if needs_to_be_padded: | |||
| difference = max_length - len(required_input) | |||
| if self.padding_side == 'right': | |||
| if return_attention_mask: | |||
| encoded_inputs['attention_mask'] = [1] * len( | |||
| required_input) + [0] * difference | |||
| if 'token_type_ids' in encoded_inputs: | |||
| encoded_inputs['token_type_ids'] = ( | |||
| encoded_inputs['token_type_ids'] | |||
| + [self.pad_token_type_id] * difference) | |||
| if 'special_tokens_mask' in encoded_inputs: | |||
| encoded_inputs['special_tokens_mask'] = encoded_inputs[ | |||
| 'special_tokens_mask'] + [1] * difference | |||
| if 'segment_ids' in encoded_inputs: | |||
| encoded_inputs[ | |||
| 'segment_ids'] = encoded_inputs['segment_ids'] + [ | |||
| encoded_inputs['segment_ids'][-1] + 1 | |||
| ] * difference # noqa * | |||
| encoded_inputs[self.model_input_names[ | |||
| 0]] = required_input + [self.pad_token_id] * difference | |||
| elif self.padding_side == 'left': | |||
| if return_attention_mask: | |||
| encoded_inputs['attention_mask'] = [0] * difference + [ | |||
| 1 | |||
| ] * len(required_input) | |||
| if 'token_type_ids' in encoded_inputs: | |||
| encoded_inputs['token_type_ids'] = [ | |||
| self.pad_token_type_id | |||
| ] * difference + encoded_inputs['token_type_ids'] | |||
| if 'segment_ids' in encoded_inputs: | |||
| encoded_inputs['segment_ids'] = [encoded_inputs['segment_ids'][-1] + 1] * difference + \ | |||
| encoded_inputs['segment_ids'] # noqa * | |||
| if 'special_tokens_mask' in encoded_inputs: | |||
| encoded_inputs['special_tokens_mask'] = [ | |||
| 1 | |||
| ] * difference + encoded_inputs['special_tokens_mask'] | |||
| encoded_inputs[self.model_input_names[ | |||
| 0]] = [self.pad_token_id] * difference + required_input | |||
| else: | |||
| raise ValueError('Invalid padding strategy:' | |||
| + str(self.padding_side)) | |||
| elif return_attention_mask and 'attention_mask' not in encoded_inputs: | |||
| encoded_inputs['attention_mask'] = [1] * len(required_input) | |||
| return encoded_inputs | |||
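The only behavioural difference from `BertTokenizer._pad` is the extra `segment_ids` branch. A toy, dependency-free sketch of the right-padding case (hypothetical token ids, no real tokenizer involved):

def pad_right(encoded, pad_token_id, pad_token_type_id, max_length):
    # Mirrors the padding_side == 'right' branch above.
    diff = max_length - len(encoded['input_ids'])
    encoded['attention_mask'] = [1] * len(encoded['input_ids']) + [0] * diff
    encoded['token_type_ids'] = encoded['token_type_ids'] + [pad_token_type_id] * diff
    # Padding positions get a fresh segment index so they never merge with the last sentence.
    encoded['segment_ids'] = encoded['segment_ids'] + [encoded['segment_ids'][-1] + 1] * diff
    encoded['input_ids'] = encoded['input_ids'] + [pad_token_id] * diff
    return encoded

sample = {'input_ids': [101, 2023, 102], 'token_type_ids': [0, 0, 0], 'segment_ids': [0, 1, 1]}
print(pad_right(sample, pad_token_id=0, pad_token_type_id=0, max_length=6))
# segment_ids -> [0, 1, 1, 2, 2, 2]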
| @@ -0,0 +1,53 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.nlp.ponet import \ | |||
| PoNetForMaskedLM as PoNetForMaskedLMTransformer | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['PoNetForMaskedLM'] | |||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet) | |||
| class PoNetForMaskedLM(TorchModel, PoNetForMaskedLMTransformer): | |||
| """PoNet for MLM model.'. | |||
| Inherited from ponet.PoNetForMaskedLM and TorchModel, so this class can be registered into Model sets. | |||
| """ | |||
| def __init__(self, config, model_dir): | |||
| super(TorchModel, self).__init__(model_dir) | |||
| PoNetForMaskedLMTransformer.__init__(self, config) | |||
| def forward(self, | |||
| input_ids=None, | |||
| attention_mask=None, | |||
| token_type_ids=None, | |||
| segment_ids=None, | |||
| position_ids=None, | |||
| head_mask=None, | |||
| labels=None): | |||
| output = PoNetForMaskedLMTransformer.forward( | |||
| self, | |||
| input_ids=input_ids, | |||
| attention_mask=attention_mask, | |||
| token_type_ids=token_type_ids, | |||
| segment_ids=segment_ids, | |||
| position_ids=position_ids, | |||
| head_mask=head_mask, | |||
| labels=labels) | |||
| output[OutputKeys.INPUT_IDS] = input_ids | |||
| return output | |||
| @classmethod | |||
| def _instantiate(cls, **kwargs): | |||
| model_dir = kwargs.get('model_dir') | |||
| return super(PoNetForMaskedLMTransformer, | |||
| PoNetForMaskedLM).from_pretrained( | |||
| pretrained_model_name_or_path=model_dir, | |||
| model_dir=model_dir) | |||
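Because of the `@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet)` decorator, a checkpoint whose configuration declares the ponet model type resolves to this wrapper. A sketch, assuming access to the ModelScope hub (the model id is the one used in the tests below):

from modelscope.models import Model

# Downloads the checkpoint from the ModelScope hub on first use.
model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
print(type(model).__name__)   # expected: 'PoNetForMaskedLM'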
| @@ -11,6 +11,7 @@ if TYPE_CHECKING: | |||
| from .document_segmentation_pipeline import DocumentSegmentationPipeline | |||
| from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline | |||
| from .fill_mask_pipeline import FillMaskPipeline | |||
| from .fill_mask_ponet_pipeline import FillMaskPoNetPipeline | |||
| from .information_extraction_pipeline import InformationExtractionPipeline | |||
| from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline | |||
| from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline | |||
| @@ -36,6 +37,7 @@ else: | |||
| 'document_segmentation_pipeline': ['DocumentSegmentationPipeline'], | |||
| 'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'], | |||
| 'fill_mask_pipeline': ['FillMaskPipeline'], | |||
| 'fill_mask_ponet_pipeline': ['FillMaskPoNetPipeline'], | |||
| 'named_entity_recognition_pipeline': | |||
| ['NamedEntityRecognitionPipeline'], | |||
| 'information_extraction_pipeline': ['InformationExtractionPipeline'], | |||
| @@ -0,0 +1,136 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Pipeline, Tensor | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import FillMaskPoNetPreprocessor, Preprocessor | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| __all__ = ['FillMaskPoNetPipeline'] | |||
| _type_map = {'ponet': 'bert'} | |||
| @PIPELINES.register_module( | |||
| Tasks.fill_mask, module_name=Pipelines.fill_mask_ponet) | |||
| class FillMaskPoNetPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| first_sequence='sentence', | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supports the fill-mask task, | |||
| or a fill-mask model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits | |||
| the model if supplied. | |||
| first_sequence: The key to read the sentence in. | |||
| NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' | |||
| param will have no effect. | |||
| Example: | |||
| >>> from modelscope.pipelines import pipeline | |||
| >>> pipeline_ins = pipeline( | |||
| 'fill-mask', model='damo/nlp_ponet_fill-mask_english-base') | |||
| >>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].' | |||
| >>> print(pipeline_ins(input)) | |||
| NOTE2: Please pay attention to the model's special tokens. | |||
| If a bert-based model (bert, structbert, etc.) is used, the mask token is '[MASK]'. | |||
| If an xlm-roberta-based model (xlm-roberta, veco, etc.) is used, the mask token is '<mask>'. | |||
| To view other examples, please check tests/pipelines/test_fill_mask.py. | |||
| """ | |||
| fill_mask_model = model if isinstance( | |||
| model, Model) else Model.from_pretrained(model) | |||
| self.config = Config.from_file( | |||
| os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION)) | |||
| if preprocessor is None: | |||
| preprocessor = FillMaskPoNetPreprocessor( | |||
| fill_mask_model.model_dir, | |||
| first_sequence=first_sequence, | |||
| second_sequence=None, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| fill_mask_model.eval() | |||
| super().__init__( | |||
| model=fill_mask_model, preprocessor=preprocessor, **kwargs) | |||
| self.preprocessor = preprocessor | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.mask_id = {'roberta': 250001, 'bert': 103} | |||
| self.rep_map = { | |||
| 'bert': { | |||
| '[unused0]': '', | |||
| '[PAD]': '', | |||
| '[unused1]': '', | |||
| r' +': ' ', | |||
| '[SEP]': '', | |||
| '[unused2]': '', | |||
| '[CLS]': '', | |||
| '[UNK]': '' | |||
| }, | |||
| 'roberta': { | |||
| r' +': ' ', | |||
| '<mask>': '<q>', | |||
| '<pad>': '', | |||
| '<s>': '', | |||
| '</s>': '', | |||
| '<unk>': ' ' | |||
| } | |||
| } | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return self.model(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Tensor]): the model outputs, including logits and input ids | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| import numpy as np | |||
| logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy() | |||
| input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy() | |||
| pred_ids = np.argmax(logits, axis=-1) | |||
| model_type = self.model.config.model_type | |||
| process_type = model_type if model_type in self.mask_id else _type_map[ | |||
| model_type] | |||
| rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids, | |||
| input_ids) | |||
| def rep_tokens(string, rep_map): | |||
| for k, v in rep_map.items(): | |||
| string = string.replace(k, v) | |||
| return string.strip() | |||
| pred_strings = [] | |||
| for ids in rst_ids: # batch | |||
| if 'language' in self.config.model and self.config.model.language == 'zh': | |||
| pred_string = self.tokenizer.convert_ids_to_tokens(ids) | |||
| pred_string = ''.join(pred_string) | |||
| else: | |||
| pred_string = self.tokenizer.decode(ids) | |||
| pred_string = rep_tokens(pred_string, self.rep_map[process_type]) | |||
| pred_strings.append(pred_string) | |||
| return {OutputKeys.TEXT: pred_strings} | |||
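The heart of `postprocess` is filling each mask position with the argmax prediction before detokenizing; a toy numpy sketch with made-up token ids:

import numpy as np

mask_id = 103                                          # '[MASK]' id for bert-style vocabs (see self.mask_id)
input_ids = np.array([[101, 7592, 103, 2088, 102]])    # one sequence containing a single [MASK]
pred_ids = np.array([[101, 7592, 2003, 2088, 102]])    # np.argmax(logits, axis=-1) in the real pipeline
rst_ids = np.where(input_ids == mask_id, pred_ids, input_ids)
print(rst_ids)                                         # the [MASK] slot now holds the predicted id 2003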
| @@ -22,8 +22,8 @@ if TYPE_CHECKING: | |||
| PairSentenceClassificationPreprocessor, FillMaskPreprocessor, | |||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | |||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor) | |||
| from .slp import DocumentSegmentationPreprocessor | |||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | |||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor) | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| DialogStateTrackingPreprocessor) | |||
| @@ -52,9 +52,9 @@ else: | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'RelationExtractionPreprocessor' | |||
| 'RelationExtractionPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| ], | |||
| 'slp': ['DocumentSegmentationPreprocessor'], | |||
| 'space': [ | |||
| 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', | |||
| 'DialogStateTrackingPreprocessor', 'InputFeatures' | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| import re | |||
| import uuid | |||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||
| @@ -11,13 +12,17 @@ from transformers import AutoTokenizer, BertTokenizerFast | |||
| from modelscope.metainfo import Models, Preprocessors | |||
| from modelscope.models.nlp.structbert import SbertTokenizerFast | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.config import ConfigFields | |||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys | |||
| from modelscope.utils.config import Config, ConfigFields | |||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | |||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.nlp.nlp_utils import import_external_nltk_data | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| logger = get_logger() | |||
| __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', | |||
| @@ -25,7 +30,8 @@ __all__ = [ | |||
| 'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor', | |||
| 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor' | |||
| 'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| ] | |||
| @@ -903,3 +909,297 @@ class FaqQuestionAnsweringPreprocessor(Preprocessor): | |||
| max_length = self.MAX_LEN | |||
| return self.tokenizer.batch_encode_plus( | |||
| sentence_list, padding=True, max_length=max_length) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.document_segmentation) | |||
| class DocumentSegmentationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, config, *args, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.tokenizer = BertTokenizerFast.from_pretrained( | |||
| model_dir, | |||
| use_fast=True, | |||
| ) | |||
| self.question_column_name = 'labels' | |||
| self.context_column_name = 'sentences' | |||
| self.example_id_column_name = 'example_id' | |||
| self.label_to_id = {'B-EOP': 0, 'O': 1} | |||
| self.target_specical_ids = set() | |||
| self.target_specical_ids.add(self.tokenizer.eos_token_id) | |||
| self.max_seq_length = config.max_position_embeddings | |||
| self.label_list = ['B-EOP', 'O'] | |||
| def __call__(self, examples) -> Dict[str, Any]: | |||
| questions = examples[self.question_column_name] | |||
| contexts = examples[self.context_column_name] | |||
| example_ids = examples[self.example_id_column_name] | |||
| num_examples = len(questions) | |||
| sentences = [] | |||
| for sentence_list in contexts: | |||
| sentence_list = [_ + '[EOS]' for _ in sentence_list] | |||
| sentences.append(sentence_list) | |||
| try: | |||
| tokenized_examples = self.tokenizer( | |||
| sentences, | |||
| is_split_into_words=True, | |||
| add_special_tokens=False, | |||
| return_token_type_ids=True, | |||
| return_attention_mask=True, | |||
| ) | |||
| except Exception as e: | |||
| logger.error(e) | |||
| return {} | |||
| segment_ids = [] | |||
| token_seq_labels = [] | |||
| for example_index in range(num_examples): | |||
| example_input_ids = tokenized_examples['input_ids'][example_index] | |||
| example_labels = questions[example_index] | |||
| example_labels = [ | |||
| self.label_to_id[_] if _ in self.label_to_id else -100 | |||
| for _ in example_labels | |||
| ] | |||
| example_token_labels = [] | |||
| segment_id = [] | |||
| cur_seg_id = 1 | |||
| for token_index in range(len(example_input_ids)): | |||
| if example_input_ids[token_index] in self.target_specical_ids: | |||
| example_token_labels.append(example_labels[cur_seg_id - 1]) | |||
| segment_id.append(cur_seg_id) | |||
| cur_seg_id += 1 | |||
| else: | |||
| example_token_labels.append(-100) | |||
| segment_id.append(cur_seg_id) | |||
| segment_ids.append(segment_id) | |||
| token_seq_labels.append(example_token_labels) | |||
| tokenized_examples['segment_ids'] = segment_ids | |||
| tokenized_examples['token_seq_labels'] = token_seq_labels | |||
| new_segment_ids = [] | |||
| new_token_seq_labels = [] | |||
| new_input_ids = [] | |||
| new_token_type_ids = [] | |||
| new_attention_mask = [] | |||
| new_example_ids = [] | |||
| new_sentences = [] | |||
| for example_index in range(num_examples): | |||
| example_input_ids = tokenized_examples['input_ids'][example_index] | |||
| example_token_type_ids = tokenized_examples['token_type_ids'][ | |||
| example_index] | |||
| example_attention_mask = tokenized_examples['attention_mask'][ | |||
| example_index] | |||
| example_segment_ids = tokenized_examples['segment_ids'][ | |||
| example_index] | |||
| example_token_seq_labels = tokenized_examples['token_seq_labels'][ | |||
| example_index] | |||
| example_sentences = contexts[example_index] | |||
| example_id = example_ids[example_index] | |||
| example_total_num_sentences = len(questions[example_index]) | |||
| example_total_num_tokens = len( | |||
| tokenized_examples['input_ids'][example_index]) | |||
| accumulate_length = [ | |||
| i for i, x in enumerate(tokenized_examples['input_ids'] | |||
| [example_index]) | |||
| if x == self.tokenizer.eos_token_id | |||
| ] | |||
| samples_boundary = [] | |||
| left_index = 0 | |||
| sent_left_index = 0 | |||
| sent_i = 0 | |||
| # for sent_i, length in enumerate(accumulate_length): | |||
| while sent_i < len(accumulate_length): | |||
| length = accumulate_length[sent_i] | |||
| right_index = length + 1 | |||
| sent_right_index = sent_i + 1 | |||
| if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens: | |||
| samples_boundary.append([left_index, right_index]) | |||
| sample_input_ids = [ | |||
| self.tokenizer.cls_token_id | |||
| ] + example_input_ids[left_index:right_index] | |||
| sample_input_ids = sample_input_ids[:self.max_seq_length] | |||
| sample_token_type_ids = [ | |||
| 0 | |||
| ] + example_token_type_ids[left_index:right_index] | |||
| sample_token_type_ids = sample_token_type_ids[:self. | |||
| max_seq_length] | |||
| sample_attention_mask = [ | |||
| 1 | |||
| ] + example_attention_mask[left_index:right_index] | |||
| sample_attention_mask = sample_attention_mask[:self. | |||
| max_seq_length] | |||
| sample_segment_ids = [ | |||
| 0 | |||
| ] + example_segment_ids[left_index:right_index] | |||
| sample_segment_ids = sample_segment_ids[:self. | |||
| max_seq_length] | |||
| sample_token_seq_labels = [ | |||
| -100 | |||
| ] + example_token_seq_labels[left_index:right_index] | |||
| sample_token_seq_labels = sample_token_seq_labels[:self. | |||
| max_seq_length] | |||
| if sent_right_index - 1 == sent_left_index: | |||
| left_index = right_index | |||
| sample_input_ids[-1] = self.tokenizer.eos_token_id | |||
| sample_token_seq_labels[-1] = -100 | |||
| else: | |||
| left_index = accumulate_length[sent_i - 1] + 1 | |||
| if sample_token_seq_labels[-1] != -100: | |||
| sample_token_seq_labels[-1] = -100 | |||
| if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens: | |||
| sample_sentences = example_sentences[ | |||
| sent_left_index:sent_right_index] | |||
| sent_left_index = sent_right_index | |||
| sent_i += 1 | |||
| else: | |||
| sample_sentences = example_sentences[ | |||
| sent_left_index:sent_right_index - 1] | |||
| sent_left_index = sent_right_index - 1 | |||
| if (len([_ for _ in sample_token_seq_labels if _ != -100 | |||
| ])) != len(sample_sentences) - 1 and (len([ | |||
| _ | |||
| for _ in sample_token_seq_labels if _ != -100 | |||
| ])) != len(sample_sentences): | |||
| tmp = [] | |||
| for w_i, w, l in zip( | |||
| sample_input_ids, | |||
| self.tokenizer.decode(sample_input_ids).split( | |||
| ' '), sample_token_seq_labels): | |||
| tmp.append((w_i, w, l)) | |||
| while len(sample_input_ids) < self.max_seq_length: | |||
| sample_input_ids.append(self.tokenizer.pad_token_id) | |||
| sample_token_type_ids.append(0) | |||
| sample_attention_mask.append(0) | |||
| sample_segment_ids.append(example_total_num_sentences | |||
| + 1) | |||
| sample_token_seq_labels.append(-100) | |||
| new_input_ids.append(sample_input_ids) | |||
| new_token_type_ids.append(sample_token_type_ids) | |||
| new_attention_mask.append(sample_attention_mask) | |||
| new_segment_ids.append(sample_segment_ids) | |||
| new_token_seq_labels.append(sample_token_seq_labels) | |||
| new_example_ids.append(example_id) | |||
| new_sentences.append(sample_sentences) | |||
| else: | |||
| sent_i += 1 | |||
| continue | |||
| output_samples = {} | |||
| output_samples['input_ids'] = new_input_ids | |||
| output_samples['token_type_ids'] = new_token_type_ids | |||
| output_samples['attention_mask'] = new_attention_mask | |||
| output_samples['segment_ids'] = new_segment_ids | |||
| output_samples['example_id'] = new_example_ids | |||
| output_samples['labels'] = new_token_seq_labels | |||
| output_samples['sentences'] = new_sentences | |||
| return output_samples | |||
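A toy illustration (hypothetical token ids) of the per-token labelling loop above: only tokens equal to the [EOS] id carry a sentence-level label ('B-EOP' -> 0, 'O' -> 1), every other token is set to -100 so the loss ignores it, and each token records its sentence's segment id.

eos_id = 105                                   # hypothetical [EOS] token id
input_ids = [12, 13, 105, 24, 105]             # two sentences, each ending in [EOS]
sentence_labels = [0, 1]                       # 'B-EOP' -> 0, 'O' -> 1

token_labels, segment_id, cur = [], [], 1
for tok in input_ids:
    if tok == eos_id:
        token_labels.append(sentence_labels[cur - 1])
        segment_id.append(cur)
        cur += 1
    else:
        token_labels.append(-100)              # ignored by the loss
        segment_id.append(cur)
print(token_labels)   # [-100, -100, 0, -100, 1]
print(segment_id)     # [1, 1, 1, 2, 2]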
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.fill_mask_ponet) | |||
| class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in MLM task. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 512) | |||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | |||
| True) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| self.language = self.cfg.model.get('language', 'en') | |||
| if self.language == 'en': | |||
| from nltk.tokenize import sent_tokenize | |||
| import_external_nltk_data( | |||
| osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt') | |||
| elif self.language in ['zh', 'cn']: | |||
| def sent_tokenize(para): | |||
| para = re.sub(r'([。!!?\?])([^”’])', r'\1\n\2', para) # noqa * | |||
| para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa * | |||
| para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa * | |||
| para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', | |||
| para) # noqa * | |||
| para = para.rstrip() | |||
| return [_ for _ in para.split('\n') if _] | |||
| else: | |||
| raise NotImplementedError | |||
| self.sent_tokenize = sent_tokenize | |||
| self.max_length = kwargs['max_length'] | |||
| def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (tuple): [sentence1, sentence2] | |||
| sentence1 (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| sentence2 (str): a sentence | |||
| Example: | |||
| 'you are so beautiful.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| text_a, text_b, labels = self.parse_text_and_label(data) | |||
| output = self.tokenizer( | |||
| text_a, | |||
| text_b, | |||
| return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, | |||
| **self.tokenize_kwargs) | |||
| max_seq_length = self.max_length | |||
| if text_b is None: | |||
| segment_ids = [] | |||
| seg_lens = list( | |||
| map( | |||
| len, | |||
| self.tokenizer( | |||
| self.sent_tokenize(text_a), | |||
| add_special_tokens=False, | |||
| truncation=True)['input_ids'])) | |||
| segment_id = [0] + sum( | |||
| [[i] * sl for i, sl in enumerate(seg_lens, start=1)], []) | |||
| segment_id = segment_id[:max_seq_length - 1] | |||
| segment_ids.append(segment_id + [segment_id[-1] + 1] | |||
| * (max_seq_length - len(segment_id))) | |||
| output['segment_ids'] = segment_ids | |||
| output = { | |||
| k: np.array(v) if isinstance(v, list) else v | |||
| for k, v in output.items() | |||
| } | |||
| self.labels_to_id(labels, output) | |||
| return output | |||
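A self-contained sketch of the `segment_ids` construction above (sentence lengths are made up): each sentence's tokens share one segment index starting from 1 after the [CLS] position, and padding continues with a fresh index.

seg_lens = [3, 2]                 # hypothetical token counts of two sentences
max_seq_length = 8

segment_id = [0] + sum([[i] * sl for i, sl in enumerate(seg_lens, start=1)], [])
segment_id = segment_id[:max_seq_length - 1]
segment_id = segment_id + [segment_id[-1] + 1] * (max_seq_length - len(segment_id))
print(segment_id)                 # [0, 1, 1, 1, 2, 2, 3, 3]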
| @@ -1,223 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict | |||
| from transformers import BertTokenizerFast | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields | |||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| __all__ = ['DocumentSegmentationPreprocessor'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.document_segmentation) | |||
| class DocumentSegmentationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, config, *args, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.tokenizer = BertTokenizerFast.from_pretrained( | |||
| model_dir, | |||
| use_fast=True, | |||
| ) | |||
| self.question_column_name = 'labels' | |||
| self.context_column_name = 'sentences' | |||
| self.example_id_column_name = 'example_id' | |||
| self.label_to_id = {'B-EOP': 0, 'O': 1} | |||
| self.target_specical_ids = set() | |||
| self.target_specical_ids.add(self.tokenizer.eos_token_id) | |||
| self.max_seq_length = config.max_position_embeddings | |||
| self.label_list = ['B-EOP', 'O'] | |||
| def __call__(self, examples) -> Dict[str, Any]: | |||
| questions = examples[self.question_column_name] | |||
| contexts = examples[self.context_column_name] | |||
| example_ids = examples[self.example_id_column_name] | |||
| num_examples = len(questions) | |||
| sentences = [] | |||
| for sentence_list in contexts: | |||
| sentence_list = [_ + '[EOS]' for _ in sentence_list] | |||
| sentences.append(sentence_list) | |||
| try: | |||
| tokenized_examples = self.tokenizer( | |||
| sentences, | |||
| is_split_into_words=True, | |||
| add_special_tokens=False, | |||
| return_token_type_ids=True, | |||
| return_attention_mask=True, | |||
| ) | |||
| except Exception as e: | |||
| print(str(e)) | |||
| return {} | |||
| segment_ids = [] | |||
| token_seq_labels = [] | |||
| for example_index in range(num_examples): | |||
| example_input_ids = tokenized_examples['input_ids'][example_index] | |||
| example_labels = questions[example_index] | |||
| example_labels = [ | |||
| self.label_to_id[_] if _ in self.label_to_id else -100 | |||
| for _ in example_labels | |||
| ] | |||
| example_token_labels = [] | |||
| segment_id = [] | |||
| cur_seg_id = 1 | |||
| for token_index in range(len(example_input_ids)): | |||
| if example_input_ids[token_index] in self.target_specical_ids: | |||
| example_token_labels.append(example_labels[cur_seg_id - 1]) | |||
| segment_id.append(cur_seg_id) | |||
| cur_seg_id += 1 | |||
| else: | |||
| example_token_labels.append(-100) | |||
| segment_id.append(cur_seg_id) | |||
| segment_ids.append(segment_id) | |||
| token_seq_labels.append(example_token_labels) | |||
| tokenized_examples['segment_ids'] = segment_ids | |||
| tokenized_examples['token_seq_labels'] = token_seq_labels | |||
| new_segment_ids = [] | |||
| new_token_seq_labels = [] | |||
| new_input_ids = [] | |||
| new_token_type_ids = [] | |||
| new_attention_mask = [] | |||
| new_example_ids = [] | |||
| new_sentences = [] | |||
| for example_index in range(num_examples): | |||
| example_input_ids = tokenized_examples['input_ids'][example_index] | |||
| example_token_type_ids = tokenized_examples['token_type_ids'][ | |||
| example_index] | |||
| example_attention_mask = tokenized_examples['attention_mask'][ | |||
| example_index] | |||
| example_segment_ids = tokenized_examples['segment_ids'][ | |||
| example_index] | |||
| example_token_seq_labels = tokenized_examples['token_seq_labels'][ | |||
| example_index] | |||
| example_sentences = contexts[example_index] | |||
| example_id = example_ids[example_index] | |||
| example_total_num_sentences = len(questions[example_index]) | |||
| example_total_num_tokens = len( | |||
| tokenized_examples['input_ids'][example_index]) | |||
| accumulate_length = [ | |||
| i for i, x in enumerate(tokenized_examples['input_ids'] | |||
| [example_index]) | |||
| if x == self.tokenizer.eos_token_id | |||
| ] | |||
| samples_boundary = [] | |||
| left_index = 0 | |||
| sent_left_index = 0 | |||
| sent_i = 0 | |||
| # for sent_i, length in enumerate(accumulate_length): | |||
| while sent_i < len(accumulate_length): | |||
| length = accumulate_length[sent_i] | |||
| right_index = length + 1 | |||
| sent_right_index = sent_i + 1 | |||
| if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens: | |||
| samples_boundary.append([left_index, right_index]) | |||
| sample_input_ids = [ | |||
| self.tokenizer.cls_token_id | |||
| ] + example_input_ids[left_index:right_index] | |||
| sample_input_ids = sample_input_ids[:self.max_seq_length] | |||
| sample_token_type_ids = [ | |||
| 0 | |||
| ] + example_token_type_ids[left_index:right_index] | |||
| sample_token_type_ids = sample_token_type_ids[:self. | |||
| max_seq_length] | |||
| sample_attention_mask = [ | |||
| 1 | |||
| ] + example_attention_mask[left_index:right_index] | |||
| sample_attention_mask = sample_attention_mask[:self. | |||
| max_seq_length] | |||
| sample_segment_ids = [ | |||
| 0 | |||
| ] + example_segment_ids[left_index:right_index] | |||
| sample_segment_ids = sample_segment_ids[:self. | |||
| max_seq_length] | |||
| sample_token_seq_labels = [ | |||
| -100 | |||
| ] + example_token_seq_labels[left_index:right_index] | |||
| sample_token_seq_labels = sample_token_seq_labels[:self. | |||
| max_seq_length] | |||
| if sent_right_index - 1 == sent_left_index: | |||
| left_index = right_index | |||
| sample_input_ids[-1] = self.tokenizer.eos_token_id | |||
| sample_token_seq_labels[-1] = -100 | |||
| else: | |||
| left_index = accumulate_length[sent_i - 1] + 1 | |||
| if sample_token_seq_labels[-1] != -100: | |||
| sample_token_seq_labels[-1] = -100 | |||
| if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens: | |||
| sample_sentences = example_sentences[ | |||
| sent_left_index:sent_right_index] | |||
| sent_left_index = sent_right_index | |||
| sent_i += 1 | |||
| else: | |||
| sample_sentences = example_sentences[ | |||
| sent_left_index:sent_right_index - 1] | |||
| sent_left_index = sent_right_index - 1 | |||
| if (len([_ for _ in sample_token_seq_labels if _ != -100 | |||
| ])) != len(sample_sentences) - 1 and (len([ | |||
| _ | |||
| for _ in sample_token_seq_labels if _ != -100 | |||
| ])) != len(sample_sentences): | |||
| tmp = [] | |||
| for w_i, w, l in zip( | |||
| sample_input_ids, | |||
| self.tokenizer.decode(sample_input_ids).split( | |||
| ' '), sample_token_seq_labels): | |||
| tmp.append((w_i, w, l)) | |||
| while len(sample_input_ids) < self.max_seq_length: | |||
| sample_input_ids.append(self.tokenizer.pad_token_id) | |||
| sample_token_type_ids.append(0) | |||
| sample_attention_mask.append(0) | |||
| sample_segment_ids.append(example_total_num_sentences | |||
| + 1) | |||
| sample_token_seq_labels.append(-100) | |||
| new_input_ids.append(sample_input_ids) | |||
| new_token_type_ids.append(sample_token_type_ids) | |||
| new_attention_mask.append(sample_attention_mask) | |||
| new_segment_ids.append(sample_segment_ids) | |||
| new_token_seq_labels.append(sample_token_seq_labels) | |||
| new_example_ids.append(example_id) | |||
| new_sentences.append(sample_sentences) | |||
| else: | |||
| sent_i += 1 | |||
| continue | |||
| output_samples = {} | |||
| output_samples['input_ids'] = new_input_ids | |||
| output_samples['token_type_ids'] = new_token_type_ids | |||
| output_samples['attention_mask'] = new_attention_mask | |||
| output_samples['segment_ids'] = new_segment_ids | |||
| output_samples['example_id'] = new_example_ids | |||
| output_samples['labels'] = new_token_seq_labels | |||
| output_samples['sentences'] = new_sentences | |||
| return output_samples | |||
| @@ -1,3 +1,4 @@ | |||
| import os.path as osp | |||
| from typing import List | |||
| from modelscope.outputs import OutputKeys | |||
| @@ -41,3 +42,22 @@ def tracking_and_print_dialog_states( | |||
| print(json.dumps(result)) | |||
| history_states.extend([result[OutputKeys.OUTPUT], {}]) | |||
| def import_external_nltk_data(nltk_data_dir, package_name): | |||
| """import external nltk_data, and extract nltk zip package. | |||
| Args: | |||
| nltk_data_dir (str): external nltk_data dir path, eg. /home/xx/nltk_data | |||
| package_name (str): nltk package name, eg. tokenizers/punkt | |||
| """ | |||
| import nltk | |||
| nltk.data.path.append(nltk_data_dir) | |||
| filepath = osp.join(nltk_data_dir, package_name + '.zip') | |||
| zippath = osp.join(nltk_data_dir, package_name) | |||
| packagepath = osp.dirname(zippath) | |||
| if not osp.exists(zippath): | |||
| import zipfile | |||
| with zipfile.ZipFile(filepath) as zf: | |||
| zf.extractall(osp.join(packagepath)) | |||
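A usage sketch, assuming the module path the preprocessor imports it from (modelscope.utils.nlp.nlp_utils); '/path/to/model_dir' is a placeholder for a downloaded model directory that ships an nltk_data folder with the punkt zip.

from modelscope.utils.nlp.nlp_utils import import_external_nltk_data
from nltk.tokenize import sent_tokenize

import_external_nltk_data('/path/to/model_dir/nltk_data', 'tokenizers/punkt')
print(sent_tokenize('PoNet handles long inputs. It pools context instead of using full attention.'))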
| @@ -0,0 +1,48 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class FillMaskPonetTest(unittest.TestCase): | |||
| model_id_ponet = { | |||
| 'zh': 'damo/nlp_ponet_fill-mask_chinese-base', | |||
| 'en': 'damo/nlp_ponet_fill-mask_english-base' | |||
| } | |||
| ori_texts = { | |||
| 'zh': | |||
| '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。' | |||
| '你师父差得动你,你师父可差不动我。', | |||
| 'en': | |||
| 'Everything in what you call reality is really just a reflection of your ' | |||
| 'consciousness. Your whole universe is just a mirror reflection of your story.' | |||
| } | |||
| test_inputs = { | |||
| 'zh': | |||
| '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你' | |||
| '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。', | |||
| 'en': | |||
| 'Everything in [MASK] you call reality is really [MASK] a reflection of your ' | |||
| '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.' | |||
| } | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_ponet_model(self): | |||
| for language in ['zh', 'en']: | |||
| ori_text = self.ori_texts[language] | |||
| test_input = self.test_inputs[language] | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, model=self.model_id_ponet[language]) | |||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||
| f'{pipeline_ins(test_input)}\n') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||