1. Add support for the original BERT model (not the EasyNLP backbone-prefix version).
2. Support BERT in backbone + head form for sequence classification, fill mask, and token classification.
3. Unify the pipelines of the several sequence classification tasks into a single class.
4. Support the backbone + head form for fill mask.
5. Unify the preprocessors of the token classification sub-tasks (NER, word segmentation, part of speech) into TokenClassificationPreprocessor.
6. Unify the preprocessors of the sequence classification sub-tasks (single-sentence classification, pair-sentence classification) into SequenceClassificationPreprocessor.
7. Move where the register assigns the cls group_key: previously, with multiple stacked decorators, the group_key was overwritten and obj_cls ended up with incorrect group_key information (see the sketch after this list).
8. Based on the backbone + head form, adjust the cases where the group_key and the module name were identical. For example, in modelscope/pipelines/nlp/sequence_classification_pipeline.py,
   before:
   @PIPELINES.register_module(
       Tasks.sentiment_classification, module_name=Pipelines.sentiment_classification)
   after:
   @PIPELINES.register_module(
       Tasks.text_classification, module_name=Pipelines.sentiment_classification)
   The corresponding configuration.json files change accordingly; this better reflects the relationship between a task and its pipelines (sub-tasks).
9. Other related changes to support the features above.
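A minimal sketch of the group_key issue in item 7, using a toy registry with hypothetical names (not the actual modelscope Registry implementation): if register_module stamps group_key onto the class itself, the second of two stacked decorators overwrites the value written by the first, so the fix is to keep the (group_key, module_name) -> cls mapping inside the registry table.

from typing import Dict, Type


class ToyRegistry:
    """Toy stand-in for the registry; the real code lives in modelscope.utils.registry."""

    def __init__(self):
        self._modules: Dict[str, Dict[str, Type]] = {}

    def register_module(self, group_key: str, module_name: str):

        def decorator(cls):
            # Buggy variant: `cls.group_key = group_key` here runs once per
            # stacked decorator, so only the last-applied value survives on
            # obj_cls.
            # Fixed variant: record the mapping in the registry table instead.
            self._modules.setdefault(group_key, {})[module_name] = cls
            return cls

        return decorator


PIPELINES = ToyRegistry()


@PIPELINES.register_module('text-classification', 'sentiment-classification')
@PIPELINES.register_module('nli', 'nli')
class DemoPipeline:
    pass


# Both registrations survive because nothing is written onto the class.
assert PIPELINES._modules['text-classification']['sentiment-classification'] is DemoPipeline
assert PIPELINES._modules['nli']['nli'] is DemoPipeline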
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10041463
master
@@ -91,17 +91,22 @@ class TaskModels(object):
    text_classification = 'text-classification'
    token_classification = 'token-classification'
    information_extraction = 'information-extraction'
    fill_mask = 'fill-mask'
    feature_extraction = 'feature-extraction'

class Heads(object):
    # nlp heads

    # text cls
    text_classification = 'text-classification'
    # mlm
    # fill mask
    fill_mask = 'fill-mask'
    bert_mlm = 'bert-mlm'
    # roberta mlm
    roberta_mlm = 'roberta-mlm'
    # token cls
    token_classification = 'token-classification'
    # extraction
    information_extraction = 'information-extraction'

@@ -203,6 +208,7 @@ class Pipelines(object):
    passage_ranking = 'passage-ranking'
    relation_extraction = 'relation-extraction'
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'

    # audio tasks
    sambert_hifigan_tts = 'sambert-hifigan-tts'

@@ -306,6 +312,7 @@ class Preprocessors(object):
    table_question_answering_preprocessor = 'table-question-answering-preprocessor'
    re_tokenizer = 're-tokenizer'
    document_segmentation = 'document-segmentation'
    feature_extraction = 'feature-extraction'

    # audio preprocessor
    linear_aec_fbank = 'linear-aec-fbank'
@@ -37,13 +37,16 @@ def build_backbone(cfg: ConfigDict,
        cfg, BACKBONES, group_key=field, default_args=default_args)

def build_head(cfg: ConfigDict, default_args: dict = None):
def build_head(cfg: ConfigDict,
               group_key: str = None,
               default_args: dict = None):
    """ build head given config dict

    Args:
        cfg (:obj:`ConfigDict`): config dict for head object.
        default_args (dict, optional): Default initialization arguments.
    """
    if group_key is None:
        group_key = cfg[TYPE_NAME]
    return build_from_cfg(
        cfg, HEADS, group_key=cfg[TYPE_NAME], default_args=default_args)
        cfg, HEADS, group_key=group_key, default_args=default_args)
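A minimal usage sketch of the new build_head signature above. It assumes build_head is importable from modelscope.models.builder and that the 'bert-mlm' head registered under Tasks.fill_mask later in this diff is available; the config fields are illustrative, not a verified head config.

from modelscope.models.builder import build_head
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks

# The head's type name no longer has to double as its registry group.
head_cfg = ConfigDict({'type': 'bert-mlm', 'hidden_size': 768, 'vocab_size': 30522})
# Old behavior (still the default): group_key falls back to cfg['type'].
# New behavior: callers such as SingleBackboneTaskModelBase.build_head pass
# the task group explicitly, so one head class can serve several tasks.
head = build_head(head_cfg, group_key=Tasks.fill_mask)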
@@ -6,7 +6,6 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .backbones import SbertModel
    from .bart_for_text_error_correction import BartForTextErrorCorrection
    from .bert_for_sequence_classification import BertForSequenceClassification
    from .bert_for_document_segmentation import BertForDocumentSegmentation
    from .csanmt_for_translation import CsanmtForTranslation
    from .heads import SequenceClassificationHead

@@ -20,12 +19,15 @@ if TYPE_CHECKING:
    from .palm_v2 import PalmForTextGeneration
    from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
    from .star_text_to_sql import StarForTextToSql
    from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
    from .sequence_classification import (VecoForSequenceClassification,
                                          SbertForSequenceClassification,
                                          BertForSequenceClassification)
    from .space import SpaceForDialogIntent
    from .space import SpaceForDialogModeling
    from .space import SpaceForDialogStateTracking
    from .table_question_answering import TableQuestionAnswering
    from .task_models import (InformationExtractionModel,
    from .task_models import (FeatureExtractionModel,
                              InformationExtractionModel,
                              SequenceClassificationModel,
                              SingleBackboneTaskModelBase,
                              TokenClassificationModel)

@@ -37,7 +39,6 @@ else:
    _import_structure = {
        'backbones': ['SbertModel'],
        'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
        'bert_for_sequence_classification': ['BertForSequenceClassification'],
        'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
        'csanmt_for_translation': ['CsanmtForTranslation'],
        'heads': ['SequenceClassificationHead'],

@@ -54,15 +55,20 @@ else:
        'palm_v2': ['PalmForTextGeneration'],
        'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
        'star_text_to_sql': ['StarForTextToSql'],
        'sequence_classification':
        ['VecoForSequenceClassification', 'SbertForSequenceClassification'],
        'sequence_classification': [
            'VecoForSequenceClassification', 'SbertForSequenceClassification',
            'BertForSequenceClassification'
        ],
        'space': [
            'SpaceForDialogIntent', 'SpaceForDialogModeling',
            'SpaceForDialogStateTracking'
        ],
        'task_models': [
            'InformationExtractionModel', 'SequenceClassificationModel',
            'SingleBackboneTaskModelBase', 'TokenClassificationModel'
            'FeatureExtractionModel',
            'InformationExtractionModel',
            'SequenceClassificationModel',
            'SingleBackboneTaskModelBase',
            'TokenClassificationModel',
        ],
        'token_classification': ['SbertForTokenClassification'],
        'table_question_answering': ['TableQuestionAnswering'],
@@ -0,0 +1,7 @@
from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES
from modelscope.models.nlp.bert import BertModel
from modelscope.utils.constant import Fields

BACKBONES.register_module(
    group_key=Fields.nlp, module_name=Models.bert, module_cls=BertModel)
@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .modeling_bert import (
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertLayer,
        BertLMHeadModel,
        BertModel,
        BertPreTrainedModel,
        load_tf_weights_in_bert,
    )
    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
    from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
    from .tokenization_bert_fast import BertTokenizerFast
else:
    _import_structure = {
        'configuration_bert':
        ['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
        'tokenization_bert':
        ['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
    }
    _import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']
    _import_structure['modeling_bert'] = [
        'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
        'BertForMaskedLM',
        'BertForMultipleChoice',
        'BertForNextSentencePrediction',
        'BertForPreTraining',
        'BertForQuestionAnswering',
        'BertForSequenceClassification',
        'BertForTokenClassification',
        'BertLayer',
        'BertLMHeadModel',
        'BertModel',
        'BertPreTrainedModel',
        'load_tf_weights_in_bert',
    ]

    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,162 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
from collections import OrderedDict
from typing import Mapping

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfig

from modelscope.utils.logger import get_logger

logger = get_logger(__name__)


class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`BertModel`] or a [`TFBertModel`]. It is used to instantiate a BERT model
    according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar
    configuration to that of the BERT
    [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different
            tokens that can be represented by the `inputs_ids` passed when
            calling [`BertModel`] or [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward)
            layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the
            encoder and pooler. If string, `"gelu"`, `"relu"`, `"silu"` and
            `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the
            embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or
            1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling
            [`BertModel`] or [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`,
            `"relative_key"`, `"relative_key_query"`. For positional embeddings
            use `"absolute"`. For more information on `"relative_key"`, please
            refer to [Self-Attention with Relative Position Representations
            (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more
            information on `"relative_key_query"`, please refer to *Method 4* in
            [Improve Transformer Models with Better Relative Position Embeddings
            (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models). Only relevant if
            `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import BertModel, BertConfig

    >>> # Initializing a BERT bert-base-uncased style configuration
    >>> configuration = BertConfig()

    >>> # Initializing a model from the bert-base-uncased style configuration
    >>> model = BertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = 'bert'

    def __init__(self,
                 vocab_size=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act='gelu',
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 pad_token_id=0,
                 position_embedding_type='absolute',
                 use_cache=True,
                 classifier_dropout=None,
                 **kwargs):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout


class BertOnnxConfig(OnnxConfig):

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([
            ('input_ids', {
                0: 'batch',
                1: 'sequence'
            }),
            ('attention_mask', {
                0: 'batch',
                1: 'sequence'
            }),
            ('token_type_ids', {
                0: 'batch',
                1: 'sequence'
            }),
        ])
@@ -1,70 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import numpy as np

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['BertForSequenceClassification']


@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
class BertForSequenceClassification(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
        # Predictor.__init__(self, *args, **kwargs)
        """initialize the sequence classification model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        import torch
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor
        self.model = get_model_predictor(
            model_dir=self.model_dir,
            model_cls=SequenceClassification,
            input_keys=[('input_ids', torch.LongTensor),
                        ('attention_mask', torch.LongTensor),
                        ('token_type_ids', torch.LongTensor)],
            output_keys=['predictions', 'probabilities', 'logits'])

        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
        with open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.id2label = {idx: name for name, idx in self.label_mapping.items()}

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
            Example:
            {
                'predictions': array([1]),  # label 0-negative 1-positive
                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # true value
            }
        """
        return self.model.predict(input)

    def postprocess(self, inputs: Dict[str, np.ndarray],
                    **kwargs) -> Dict[str, np.ndarray]:
        # N x num_classes
        probs = inputs['probabilities']
        result = {
            'probs': probs,
        }
        return result
@@ -21,21 +21,12 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

_import_structure = {
    'configuration_deberta_v2': [
        'DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config',
        'DebertaV2OnnxConfig'
    ],
    'tokenization_deberta_v2': ['DebertaV2Tokenizer'],
}

if TYPE_CHECKING:
    from .configuration_deberta_v2 import DebertaV2Config
    from .tokenization_deberta_v2 import DebertaV2Tokenizer
    from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast
    from .modeling_deberta_v2 import (
        DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
        DebertaV2ForMaskedLM,
        DebertaV2ForMultipleChoice,
        DebertaV2ForQuestionAnswering,

@@ -55,7 +46,6 @@ else:
        'DebertaV2TokenizerFast'
    ]
    _import_structure['modeling_deberta_v2'] = [
        'DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST',
        'DebertaV2ForMaskedLM',
        'DebertaV2ForMultipleChoice',
        'DebertaV2ForQuestionAnswering',
@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN

from modelscope.metainfo import Heads
from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks


@HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
class BertFillMaskHead(TorchHead):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.cls = BertOnlyMLMHead(self.config)

    def forward(self, sequence_output):
        prediction_scores = self.cls(sequence_output)
        return {OutputKeys.LOGITS: prediction_scores}

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
                     labels) -> Dict[str, torch.Tensor]:
        loss_fct = CrossEntropyLoss()  # -100 index = padding token
        masked_lm_loss = loss_fct(
            outputs.view(-1, self.config.vocab_size), labels.view(-1))
        return {OutputKeys.LOSS: masked_lm_loss}


class BertPredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
@@ -11,7 +11,7 @@ from modelscope.models.builder import HEADS
from modelscope.utils.constant import Tasks

@HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
# @HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
class BertMLMHead(BertOnlyMLMHead, TorchHead):

    def compute_loss(self, outputs: Dict[str, torch.Tensor],
@@ -1,10 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import BertForMaskedLM as BertForMaskedLMTransformer

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import \
    BertForMaskedLM as BertForMaskedLMTransformer
from modelscope.models.nlp.deberta_v2 import \
    DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer
from modelscope.models.nlp.structbert import SbertForMaskedLM
@@ -41,12 +41,9 @@ class SequenceLabelingForNamedEntityRecognition(TorchModel):
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        input_tensor = {
            'input_ids':
            torch.tensor(input['input_ids']).unsqueeze(0),
            'attention_mask':
            torch.tensor(input['attention_mask']).unsqueeze(0),
            'label_mask':
            torch.tensor(input['label_mask'], dtype=torch.bool).unsqueeze(0)
            'input_ids': input['input_ids'],
            'attention_mask': input['attention_mask'],
            'label_mask': input['label_mask'],
        }
        output = {
            'text': input['text'],
@@ -7,6 +7,7 @@ from torch import nn
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertPreTrainedModel
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.models.nlp.veco import \
    VecoForSequenceClassification as VecoForSequenceClassificationTransform

@@ -16,7 +17,10 @@ from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)

__all__ = ['SbertForSequenceClassification', 'VecoForSequenceClassification']
__all__ = [
    'SbertForSequenceClassification', 'VecoForSequenceClassification',
    'BertForSequenceClassification'
]


class SequenceClassificationBase(TorchModel):

@@ -132,7 +136,7 @@ class SbertForSequenceClassification(SequenceClassificationBase,
        label2id = parse_label_mapping(model_dir)
        if label2id is not None and len(label2id) > 0:
            num_labels = len(label2id)
            cls.id2label = {id: label for label, id in label2id.items()}
        model_args = {} if num_labels is None else {'num_labels': num_labels}
        return super(SbertPreTrainedModel,
                     SbertForSequenceClassification).from_pretrained(

@@ -206,3 +210,78 @@ class VecoForSequenceClassification(TorchModel,
            pretrained_model_name_or_path=kwargs.get('model_dir'),
            model_dir=kwargs.get('model_dir'),
            **model_args)


@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
@MODELS.register_module(
    Tasks.sentiment_classification, module_name=Models.bert)
@MODELS.register_module(Tasks.nli, module_name=Models.bert)
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
class BertForSequenceClassification(SequenceClassificationBase,
                                    BertPreTrainedModel):
    """Bert sequence classification model.

    Inherited from SequenceClassificationBase.
    """
    base_model_prefix: str = 'bert'
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r'position_ids']

    def __init__(self, config, model_dir):
        if hasattr(config, 'base_model_prefix'):
            BertForSequenceClassification.base_model_prefix = config.base_model_prefix
        super().__init__(config, model_dir)

    def build_base_model(self):
        from .bert import BertModel
        return BertModel(self.config, add_pooling_layer=True)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)

    @classmethod
    def _instantiate(cls, **kwargs):
        """Instantiate the model.

        @param kwargs: Input args.
            model_dir: The model dir used to load the checkpoint and the label information.
            num_labels: An optional arg to tell the model how many classes to initialize.
                Method will call utils.parse_label_mapping if num_labels not supplied.
                If num_labels is not found, the model will use the default setting (2 classes).
        @return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
        """
        model_dir = kwargs.get('model_dir')
        num_labels = kwargs.get('num_labels')
        if num_labels is None:
            label2id = parse_label_mapping(model_dir)
            if label2id is not None and len(label2id) > 0:
                num_labels = len(label2id)
        model_args = {} if num_labels is None else {'num_labels': num_labels}
        return super(BertPreTrainedModel,
                     BertForSequenceClassification).from_pretrained(
                         pretrained_model_name_or_path=kwargs.get('model_dir'),
                         model_dir=kwargs.get('model_dir'),
                         **model_args)
@@ -5,6 +5,8 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .information_extraction import InformationExtractionModel
    from .feature_extraction import FeatureExtractionModel
    from .fill_mask import FillMaskModel
    from .sequence_classification import SequenceClassificationModel
    from .task_model import SingleBackboneTaskModelBase
    from .token_classification import TokenClassificationModel

@@ -12,6 +14,8 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'information_extraction': ['InformationExtractionModel'],
        'feature_extraction': ['FeatureExtractionModel'],
        'fill_mask': ['FillMaskModel'],
        'sequence_classification': ['SequenceClassificationModel'],
        'task_model': ['SingleBackboneTaskModelBase'],
        'token_classification': ['TokenClassificationModel'],
@@ -0,0 +1,43 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertConfig
from modelscope.models.nlp.task_models.task_model import \
    SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['FeatureExtractionModel']


@MODELS.register_module(
    Tasks.feature_extraction, module_name=TaskModels.feature_extraction)
class FeatureExtractionModel(SingleBackboneTaskModelBase):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the feature extraction model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        self.build_backbone(self.backbone_cfg)

    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # the backbone does not need labels; only the head needs them for loss compute
        labels = input.pop(OutputKeys.LABELS, None)
        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        if labels is not None:
            input[OutputKeys.LABELS] = labels
        return {OutputKeys.TEXT_EMBEDDING: sequence_output}
@@ -0,0 +1,47 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertConfig
from modelscope.models.nlp.task_models.task_model import \
    SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['FillMaskModel']


@MODELS.register_module(Tasks.fill_mask, module_name=TaskModels.fill_mask)
class FillMaskModel(SingleBackboneTaskModelBase):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the fill mask model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        self.build_backbone(self.backbone_cfg)
        self.build_head(self.head_cfg)

    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # the backbone does not need labels; only the head needs them for loss compute
        labels = input.pop(OutputKeys.LABELS, None)
        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        outputs = self.head.forward(sequence_output)
        if labels is not None:
            input[OutputKeys.LABELS] = labels
            loss = self.compute_loss(outputs, labels)
            outputs.update(loss)
        outputs[OutputKeys.INPUT_IDS] = input[OutputKeys.INPUT_IDS]
        return outputs
@@ -26,21 +26,12 @@ class InformationExtractionModel(SingleBackboneTaskModelBase):
        """
        super().__init__(model_dir, *args, **kwargs)

        backbone_cfg = self.cfg.backbone
        head_cfg = self.cfg.head

        self.build_backbone(backbone_cfg)
        self.build_head(head_cfg)
        self.build_backbone(self.backbone_cfg)
        self.build_head(self.head_cfg)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        outputs = self.head.forward(sequence_output, input['text'],
                                    input['offsets'])
        return {OutputKeys.SPO_LIST: outputs}

    def extract_backbone_outputs(self, outputs):
        sequence_output = None
        pooled_output = None
        if hasattr(self.backbone, 'extract_sequence_outputs'):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        return sequence_output, pooled_output
@@ -11,10 +11,14 @@ from modelscope.models.nlp.task_models.task_model import \
    SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['SequenceClassificationModel']


@MODELS.register_module(
    Tasks.sentence_similarity, module_name=TaskModels.text_classification)
@MODELS.register_module(Tasks.nli, module_name=TaskModels.text_classification)
@MODELS.register_module(
    Tasks.sentiment_classification, module_name=TaskModels.text_classification)
@MODELS.register_module(

@@ -31,49 +35,36 @@ class SequenceClassificationModel(SingleBackboneTaskModelBase):
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        backbone_cfg = self.cfg.backbone
        head_cfg = self.cfg.head

        # get the num_labels from label_mapping.json
        self.id2label = {}
        self.label_path = os.path.join(model_dir, 'label_mapping.json')
        if os.path.exists(self.label_path):
            with open(self.label_path) as f:
                self.label_mapping = json.load(f)
            self.id2label = {
                idx: name
                for name, idx in self.label_mapping.items()
            }
            head_cfg['num_labels'] = len(self.label_mapping)
        # get the num_labels
        num_labels = kwargs.get('num_labels')
        if num_labels is None:
            label2id = parse_label_mapping(model_dir)
            if label2id is not None and len(label2id) > 0:
                num_labels = len(label2id)
                self.id2label = {id: label for label, id in label2id.items()}
        self.head_cfg['num_labels'] = num_labels

        self.build_backbone(backbone_cfg)
        self.build_head(head_cfg)
        self.build_backbone(self.backbone_cfg)
        self.build_head(self.head_cfg)

    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        # the backbone does not need labels; only the head needs them for loss compute
        labels = input.pop(OutputKeys.LABELS, None)
        outputs = super().forward(input)
        sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
        outputs = self.head.forward(pooled_output)
        if 'labels' in input:
            loss = self.compute_loss(outputs, input['labels'])
        if labels is not None:
            input[OutputKeys.LABELS] = labels
            loss = self.compute_loss(outputs, labels)
            outputs.update(loss)
        return outputs

    def extract_logits(self, outputs):
        return outputs[OutputKeys.LOGITS].cpu().detach()

    def extract_backbone_outputs(self, outputs):
        sequence_output = None
        pooled_output = None
        if hasattr(self.backbone, 'extract_sequence_outputs'):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        if hasattr(self.backbone, 'extract_pooled_outputs'):
            pooled_output = self.backbone.extract_pooled_outputs(outputs)
        return sequence_output, pooled_output

    def compute_loss(self, outputs, labels):
        loss = self.head.compute_loss(outputs, labels)
        return loss

    def postprocess(self, input, **kwargs):
        logits = self.extract_logits(input)
        probs = logits.softmax(-1).numpy()
@@ -74,7 +74,7 @@ class BaseTaskModel(TorchModel, ABC):
    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.cfg = ConfigDict(kwargs)
        self.config = ConfigDict(kwargs)

    def __repr__(self):
        # only log backbone and head name

@@ -397,6 +397,9 @@ class SingleBackboneTaskModelBase(BaseTaskModel):
    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.backbone_cfg = self.config.get('backbone', None)
        assert self.backbone_cfg is not None
        self.head_cfg = self.config.get('head', None)

    def build_backbone(self, cfg):
        if 'prefix' in cfg:

@@ -405,9 +408,13 @@ class SingleBackboneTaskModelBase(BaseTaskModel):
        setattr(self, cfg['prefix'], backbone)

    def build_head(self, cfg):
        if cfg is None:
            raise ValueError(
                'Head config is missing, check if this was a backbone-only model'
            )
        if 'prefix' in cfg:
            self._head_prefix = cfg['prefix']
        head = build_head(cfg)
        head = build_head(cfg, group_key=self.group_key)
        setattr(self, self._head_prefix, head)
        return head

@@ -431,8 +438,18 @@ class SingleBackboneTaskModelBase(BaseTaskModel):
        outputs = self.backbone.forward(**input)
        return outputs

    def compute_loss(self, outputs: Dict[str, Any], labels):
        raise NotImplementedError()
    def compute_loss(self, outputs, labels):
        loss = self.head.compute_loss(outputs, labels)
        return loss

    def extract_backbone_outputs(self, outputs):
        sequence_output = None
        pooled_output = None
        if hasattr(self.backbone, 'extract_sequence_outputs'):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        if hasattr(self.backbone, 'extract_pooled_outputs'):
            pooled_output = self.backbone.extract_pooled_outputs(outputs)
        return sequence_output, pooled_output


class EncoderDecoderTaskModelBase(BaseTaskModel):

@@ -453,7 +470,7 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):
    def build_encoder(self):
        encoder = build_backbone(
            self.cfg,
            self.config,
            type_name=self._encoder_key_in_cfg,
            task_name=Tasks.backbone)
        setattr(self, self._encoder_prefix, encoder)

@@ -461,7 +478,7 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):
    def build_decoder(self):
        decoder = build_backbone(
            self.cfg,
            self.config,
            type_name=self._decoder_key_in_cfg,
            task_name=Tasks.backbone)
        setattr(self, self._decoder_prefix, decoder)
@@ -31,9 +31,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
        if 'base_model_prefix' in kwargs:
            self._base_model_prefix = kwargs['base_model_prefix']

        backbone_cfg = self.cfg.backbone
        head_cfg = self.cfg.head

        # get the num_labels
        num_labels = kwargs.get('num_labels')
        if num_labels is None:

@@ -41,12 +38,12 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
            if label2id is not None and len(label2id) > 0:
                num_labels = len(label2id)
                self.id2label = {id: label for label, id in label2id.items()}
        head_cfg['num_labels'] = num_labels
        self.head_cfg['num_labels'] = num_labels

        self.build_backbone(backbone_cfg)
        self.build_head(head_cfg)
        self.build_backbone(self.backbone_cfg)
        self.build_head(self.head_cfg)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
    def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        labels = None
        if OutputKeys.LABEL in input:
            labels = input.pop(OutputKeys.LABEL)

@@ -71,10 +68,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
            sequence_output = self.backbone.extract_sequence_outputs(outputs)
        return sequence_output, pooled_output

    def compute_loss(self, outputs, labels):
        loss = self.head.compute_loss(outputs, labels)
        return loss

    def postprocess(self, input, **kwargs):
        logits = self.extract_logits(input)
        pred = torch.argmax(logits[0], dim=-1)
@@ -10,12 +10,13 @@ from torch import nn
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertPreTrainedModel
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)
from .structbert import SbertPreTrainedModel

__all__ = ['SbertForTokenClassification']

@@ -171,3 +172,49 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
            pretrained_model_name_or_path=kwargs.get('model_dir'),
            model_dir=kwargs.get('model_dir'),
            **model_args)


@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
@MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
    """Bert token classification model.

    Inherited from TokenClassification.
    """
    base_model_prefix: str = 'bert'
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r'position_ids']

    def __init__(self, config, model_dir):
        if hasattr(config, 'base_model_prefix'):
            BertForTokenClassification.base_model_prefix = config.base_model_prefix
        super().__init__(config, model_dir)

    def build_base_model(self):
        from .bert import BertModel
        return BertModel(self.config, add_pooling_layer=True)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)
@@ -417,6 +417,22 @@ TASK_OUTPUTS = {
    #   }
    Tasks.fill_mask: [OutputKeys.TEXT],

    # feature extraction result for single sample
    # {
    #   "text_embedding": [[
    #     [1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04],
    #     [6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01],
    #     [2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05]
    #    ],
    #    [
    #     [2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05],
    #     [8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05],
    #     [3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05]
    #    ]
    #   ]
    # }
    Tasks.feature_extraction: [OutputKeys.TEXT_EMBEDDING],

    # (Deprecated) dialog intent prediction result for single sample
    # {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
    #                                  1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,
@@ -52,8 +52,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                           'damo/cv_vit_object-detection_coco'),
    Tasks.image_denoising: (Pipelines.image_denoise,
                            'damo/cv_nafnet_image-denoise_sidd'),
    Tasks.text_classification: (Pipelines.sentiment_analysis,
                                'damo/bert-base-sst2'),
    Tasks.text_classification:
    (Pipelines.sentiment_classification,
     'damo/nlp_structbert_sentiment-classification_chinese-base'),
    Tasks.text_generation: (Pipelines.text_generation,
                            'damo/nlp_palm2.0_text-generation_chinese-base'),
    Tasks.zero_shot_classification:

@@ -80,6 +81,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
@@ -11,12 +11,13 @@ if TYPE_CHECKING:
    from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
    from .document_segmentation_pipeline import DocumentSegmentationPipeline
    from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
    from .feature_extraction_pipeline import FeatureExtractionPipeline
    from .fill_mask_pipeline import FillMaskPipeline
    from .fill_mask_ponet_pipeline import FillMaskPonetPipeline
    from .information_extraction_pipeline import InformationExtractionPipeline
    from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
    from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
    from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline
    from .passage_ranking_pipeline import PassageRankingPipeline
    from .sentence_embedding_pipeline import SentenceEmbeddingPipeline
    from .sequence_classification_pipeline import SequenceClassificationPipeline
    from .summarization_pipeline import SummarizationPipeline
    from .text_classification_pipeline import TextClassificationPipeline

@@ -27,8 +28,7 @@ if TYPE_CHECKING:
    from .translation_pipeline import TranslationPipeline
    from .word_segmentation_pipeline import WordSegmentationPipeline
    from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
    from .passage_ranking_pipeline import PassageRankingPipeline
    from .sentence_embedding_pipeline import SentenceEmbeddingPipeline

else:
    _import_structure = {
        'conversational_text_to_sql_pipeline':

@@ -41,16 +41,15 @@ else:
        'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
        'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
        'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
        'feature_extraction_pipeline': ['FeatureExtractionPipeline'],
        'fill_mask_pipeline': ['FillMaskPipeline'],
        'fill_mask_ponet_pipeline': ['FillMaskPoNetPipeline'],
        'information_extraction_pipeline': ['InformationExtractionPipeline'],
        'named_entity_recognition_pipeline':
        ['NamedEntityRecognitionPipeline'],
        'information_extraction_pipeline': ['InformationExtractionPipeline'],
        'pair_sentence_classification_pipeline':
        ['PairSentenceClassificationPipeline'],
        'passage_ranking_pipeline': ['PassageRankingPipeline'],
        'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'],
        'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
        'single_sentence_classification_pipeline':
        ['SingleSentenceClassificationPipeline'],
        'summarization_pipeline': ['SummarizationPipeline'],
        'text_classification_pipeline': ['TextClassificationPipeline'],
        'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],

@@ -61,8 +60,6 @@ else:
        'word_segmentation_pipeline': ['WordSegmentationPipeline'],
        'zero_shot_classification_pipeline':
        ['ZeroShotClassificationPipeline'],
        'passage_ranking_pipeline': ['PassageRankingPipeline'],
        'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline']
    }

    import sys
@@ -0,0 +1,82 @@
import os
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import NLPPreprocessor, Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['FeatureExtractionPipeline']


@PIPELINES.register_module(
    Tasks.feature_extraction, module_name=Pipelines.feature_extraction)
class FeatureExtractionPipeline(Pipeline):

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 first_sequence='sentence',
                 **kwargs):
        """Use `model` and `preprocessor` to create an nlp feature extraction pipeline for prediction.

        Args:
            model (str or Model): Supply either a local model dir which supports the feature extraction task,
                or a no-head model id from the model hub, or a torch model instance.
            preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor
                fits the model if supplied.
            first_sequence: The key to read the sentence in.
            sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.

            NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
            param will have no effect.

        Example:
            >>> from modelscope.pipelines import pipeline
            >>> pipe_ins = pipeline('feature-extraction', model='damo/nlp_structbert_feature-extraction_english-large')
            >>> input = 'Everything you love is treasure'
            >>> print(pipe_ins(input))
        """
        model = model if isinstance(model,
                                    Model) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = NLPPreprocessor(
                model.model_dir,
                padding=kwargs.pop('padding', False),
                sequence_length=kwargs.pop('sequence_length', 128))
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.preprocessor = preprocessor
        self.config = Config.from_file(
            os.path.join(model.model_dir, ModelFile.CONFIGURATION))
        self.tokenizer = preprocessor.tokenizer

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return self.model(**inputs, **forward_params)

    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model outputs containing the text embedding

        Returns:
            Dict[str, str]: the prediction results
        """
        return {
            OutputKeys.TEXT_EMBEDDING:
            inputs[OutputKeys.TEXT_EMBEDDING].tolist()
        }
@@ -10,7 +10,7 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FillMaskPreprocessor, Preprocessor
from modelscope.preprocessors import NLPPreprocessor, Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

@@ -57,7 +57,7 @@ class FillMaskPipeline(Pipeline):
            model, Model) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = FillMaskPreprocessor(
            preprocessor = NLPPreprocessor(
                fill_mask_model.model_dir,
                first_sequence=first_sequence,
                second_sequence=None,

@@ -118,7 +118,10 @@ class FillMaskPipeline(Pipeline):
        logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy()
        input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy()
        pred_ids = np.argmax(logits, axis=-1)
        model_type = self.model.config.model_type
        if hasattr(self.model.config, 'backbone'):
            model_type = self.model.config.backbone.type
        else:
            model_type = self.model.config.model_type
        process_type = model_type if model_type in self.mask_id else _type_map[
            model_type]
        rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids,
@@ -36,7 +36,7 @@ class InformationExtractionPipeline(Pipeline):
    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return super().forward(inputs, **forward_params)
            return self.model(**inputs, **forward_params)

    def postprocess(self, inputs: Dict[str, Any],
                    **postprocess_params) -> Dict[str, str]:
@@ -9,7 +9,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import NERPreprocessor, Preprocessor
from modelscope.preprocessors import (Preprocessor,
                                      TokenClassificationPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['NamedEntityRecognitionPipeline']

@@ -46,7 +47,7 @@ class NamedEntityRecognitionPipeline(Pipeline):
        model = model if isinstance(model,
                                    Model) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = NERPreprocessor(
            preprocessor = TokenClassificationPreprocessor(
                model.model_dir,
                sequence_length=kwargs.pop('sequence_length', 512))
        model.eval()
| @@ -1,59 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Union | |||
| from modelscope.models.base import Model | |||
| from ...metainfo import Pipelines | |||
| from ...preprocessors import (PairSentenceClassificationPreprocessor, | |||
| Preprocessor) | |||
| from ...utils.constant import Tasks | |||
| from ..builder import PIPELINES | |||
| from .sequence_classification_pipeline_base import \ | |||
| SequenceClassificationPipelineBase | |||
| __all__ = ['PairSentenceClassificationPipeline'] | |||
| @PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) | |||
| @PIPELINES.register_module( | |||
| Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) | |||
| class PairSentenceClassificationPipeline(SequenceClassificationPipelineBase): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Preprocessor = None, | |||
| first_sequence='first_sequence', | |||
| second_sequence='second_sequence', | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp pair sequence classification pipeline for prediction. | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported the sequence classification task, | |||
| or a model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| first_sequence: The key to read the first sentence in. | |||
| second_sequence: The key to read the second sentence in. | |||
| sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. | |||
| NOTE: Inputs of type 'tuple' or 'list' are also supported. In this scenario, the 'first_sequence' and | |||
| 'second_sequence' param will have no effect. | |||
| Example: | |||
| >>> from modelscope.pipelines import pipeline | |||
| >>> pipeline_ins = pipeline(task='nli', model='damo/nlp_structbert_nli_chinese-base') | |||
| >>> sentence1 = '四川商务职业学院和四川财经职业学院哪个好?' | |||
| >>> sentence2 = '四川商务职业学院商务管理在哪个校区?' | |||
| >>> print(pipeline_ins((sentence1, sentence2))) | |||
| >>> # Or use the dict input: | |||
| >>> print(pipeline_ins({'first_sequence': sentence1, 'second_sequence': sentence2})) | |||
| To view other examples please check tests/pipelines/test_nli.py. | |||
| """ | |||
| if preprocessor is None: | |||
| preprocessor = PairSentenceClassificationPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
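The deleted pair pipeline is absorbed by the unified SequenceClassificationPipeline below; the task/module registrations keep the old entry points intact, so existing calls keep working unchanged. For example (model id taken from the docstring above):

    from modelscope.pipelines import pipeline

    nli = pipeline(task='nli', model='damo/nlp_structbert_nli_chinese-base')
    print(nli(('四川商务职业学院和四川财经职业学院哪个好?', '四川商务职业学院商务管理在哪个校区?')))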
| @@ -1,48 +1,64 @@ | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import BertForSequenceClassification | |||
| from modelscope.models.base import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.preprocessors import (Preprocessor, | |||
| SequenceClassificationPreprocessor) | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['SequenceClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.text_classification, module_name=Pipelines.sentiment_analysis) | |||
| @PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli) | |||
| @PIPELINES.register_module( | |||
| Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) | |||
| @PIPELINES.register_module( | |||
| Tasks.text_classification, module_name=Pipelines.sentiment_classification) | |||
| class SequenceClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[BertForSequenceClassification, str], | |||
| preprocessor: SequenceClassificationPreprocessor = None, | |||
| model: Union[Model, str], | |||
| preprocessor: Preprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| """This is the base class for all the sequence classification sub-tasks. | |||
| Args: | |||
| model (BertForSequenceClassification): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| model (str or Model): A model instance or a model local dir or a model id in the model hub. | |||
| preprocessor (Preprocessor): an optional preprocessor instance; a default one is built if None. | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, BertForSequenceClassification), \ | |||
| 'model must be a single str or BertForSequenceClassification' | |||
| sc_model = model if isinstance( | |||
| model, | |||
| BertForSequenceClassification) else Model.from_pretrained(model) | |||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||
| 'model must be a single str or Model' | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| first_sequence = kwargs.pop('first_sequence', 'first_sequence') | |||
| second_sequence = kwargs.pop('second_sequence', None) | |||
| if preprocessor is None: | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| sc_model.model_dir, | |||
| first_sequence='sentence', | |||
| second_sequence=None, | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| first_sequence=first_sequence, | |||
| second_sequence=second_sequence, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| assert hasattr(self.model, 'id2label'), \ | |||
| 'id2label map should be initialized in the init function.' | |||
| assert preprocessor is not None | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.id2label = kwargs.get('id2label') | |||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | |||
| self.id2label = self.preprocessor.id2label | |||
| assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ | |||
| 'as a parameter or make sure the preprocessor has the attribute.' | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return self.model(**inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| @@ -50,20 +66,18 @@ class SequenceClassificationPipeline(Pipeline): | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): input data dict | |||
| topk (int): return topk classification result. | |||
| inputs (Dict[str, Any]): the model output containing the probabilities | |||
| topk (int): the number of top classes to return | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| # NxC np.ndarray | |||
| probs = inputs['probs'][0] | |||
| probs = inputs[OutputKeys.PROBABILITIES][0] | |||
| num_classes = probs.shape[0] | |||
| topk = min(topk, num_classes) | |||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.model.id2label[cid] for cid in cls_ids] | |||
| cls_names = [self.id2label[cid] for cid in cls_ids] | |||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||
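The top-k selection above deliberately avoids a full sort: np.argpartition finds the k largest probabilities in O(n), and only those k entries are then ordered. A standalone sketch of the same trick:

    import numpy as np

    probs = np.array([0.05, 0.30, 0.10, 0.40, 0.15])
    topk = 3
    top_indices = np.argpartition(probs, -topk)[-topk:]    # k largest, arbitrary order
    cls_ids = top_indices[np.argsort(probs[top_indices])]  # ascending by probability
    print(cls_ids.tolist())         # [4, 1, 3]
    print(probs[cls_ids].tolist())  # [0.15, 0.3, 0.4]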
| @@ -1,62 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Union | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.models.base import Model | |||
| from modelscope.outputs import OutputKeys | |||
| from ...preprocessors import Preprocessor | |||
| from ..base import Pipeline | |||
| class SequenceClassificationPipelineBase(Pipeline): | |||
| def __init__(self, model: Union[Model, str], preprocessor: Preprocessor, | |||
| **kwargs): | |||
| """This is the base class for all the sequence classification sub-tasks. | |||
| Args: | |||
| model (str or Model): A model instance or a model local dir or a model id in the model hub. | |||
| preprocessor (Preprocessor): a preprocessor instance, must not be None. | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, Model), \ | |||
| 'model must be a single str or Model' | |||
| model = model if isinstance(model, | |||
| Model) else Model.from_pretrained(model) | |||
| assert preprocessor is not None | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.id2label = kwargs.get('id2label') | |||
| if self.id2label is None and hasattr(self.preprocessor, 'id2label'): | |||
| self.id2label = self.preprocessor.id2label | |||
| assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \ | |||
| 'as a parameter or make sure the preprocessor has the attribute.' | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return self.model(**inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| topk: int = 5) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| topk (int): The topk probs to take | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| probs = inputs[OutputKeys.PROBABILITIES][0] | |||
| num_classes = probs.shape[0] | |||
| topk = min(topk, num_classes) | |||
| top_indices = np.argpartition(probs, -topk)[-topk:] | |||
| cls_ids = top_indices[np.argsort(probs[top_indices])] | |||
| probs = probs[cls_ids].tolist() | |||
| cls_names = [self.id2label[cid] for cid in cls_ids] | |||
| return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names} | |||
| @@ -1,56 +0,0 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Union | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...preprocessors import (Preprocessor, | |||
| SingleSentenceClassificationPreprocessor) | |||
| from ...utils.constant import Tasks | |||
| from ..builder import PIPELINES | |||
| from .sequence_classification_pipeline_base import \ | |||
| SequenceClassificationPipelineBase | |||
| __all__ = ['SingleSentenceClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.sentiment_classification, | |||
| module_name=Pipelines.sentiment_classification) | |||
| class SingleSentenceClassificationPipeline(SequenceClassificationPipelineBase): | |||
| def __init__(self, | |||
| model: Union[Model, str], | |||
| preprocessor: Preprocessor = None, | |||
| first_sequence='first_sequence', | |||
| **kwargs): | |||
| """Use `model` and `preprocessor` to create a nlp single sequence classification pipeline for prediction. | |||
| Args: | |||
| model (str or Model): Supply either a local model dir which supported the sequence classification task, | |||
| or a model id from the model hub, or a torch model instance. | |||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
| the model if supplied. | |||
| first_sequence: The key to read the first sentence in. | |||
| sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value. | |||
| NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence' | |||
| param will have no effect. | |||
| Example: | |||
| >>> from modelscope.pipelines import pipeline | |||
| >>> pipeline_ins = pipeline(task='sentiment-classification', | |||
| >>> model='damo/nlp_structbert_sentiment-classification_chinese-base') | |||
| >>> sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' | |||
| >>> print(pipeline_ins(sentence1)) | |||
| >>> # Or use the dict input: | |||
| >>> print(pipeline_ins({'first_sequence': sentence1})) | |||
| To view other examples please check tests/pipelines/test_sentiment-classification.py. | |||
| """ | |||
| if preprocessor is None: | |||
| preprocessor = SingleSentenceClassificationPreprocessor( | |||
| model.model_dir if isinstance(model, Model) else model, | |||
| first_sequence=first_sequence, | |||
| sequence_length=kwargs.pop('sequence_length', 512)) | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| @@ -49,7 +49,7 @@ class TokenClassificationPipeline(Pipeline): | |||
| text = inputs.pop(OutputKeys.TEXT) | |||
| with torch.no_grad(): | |||
| return { | |||
| **self.model(inputs, **forward_params), OutputKeys.TEXT: text | |||
| **self.model(**inputs, **forward_params), OutputKeys.TEXT: text | |||
| } | |||
| def postprocess(self, inputs: Dict[str, Any], | |||
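The forward fix above unpacks the preprocessed dict into keyword arguments, which is what the backbone-head model signatures expect. A trivial illustration with a stand-in model (names are illustrative):

    def fake_model(input_ids=None, attention_mask=None, **kwargs):
        return {'logits': input_ids}

    inputs = {'input_ids': [[101, 102]], 'attention_mask': [[1, 1]]}
    out = fake_model(**inputs)   # keyword expansion, as in self.model(**inputs)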
| @@ -16,17 +16,23 @@ if TYPE_CHECKING: | |||
| from .kws import WavToLists | |||
| from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) | |||
| from .nlp import ( | |||
| Tokenize, SequenceClassificationPreprocessor, | |||
| TextGenerationPreprocessor, TokenClassificationPreprocessor, | |||
| SingleSentenceClassificationPreprocessor, | |||
| PairSentenceClassificationPreprocessor, FillMaskPreprocessor, | |||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||
| TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor, | |||
| SequenceLabelingPreprocessor, RelationExtractionPreprocessor, | |||
| DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor, | |||
| PassageRankingPreprocessor, SentenceEmbeddingPreprocessor, | |||
| DocumentSegmentationPreprocessor, | |||
| FaqQuestionAnsweringPreprocessor, | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, | |||
| TokenClassificationPreprocessor, | |||
| TextErrorCorrectionPreprocessor, | |||
| TextGenerationPreprocessor, | |||
| Text2TextGenerationPreprocessor, | |||
| WordSegmentationBlankSetToLabelPreprocessor) | |||
| Tokenize, | |||
| WordSegmentationBlankSetToLabelPreprocessor, | |||
| ZeroShotClassificationPreprocessor, | |||
| ) | |||
| from .space import (DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, | |||
| DialogStateTrackingPreprocessor) | |||
| @@ -49,18 +55,22 @@ else: | |||
| 'kws': ['WavToLists'], | |||
| 'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'], | |||
| 'nlp': [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', | |||
| 'SingleSentenceClassificationPreprocessor', | |||
| 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| 'TokenClassificationPreprocessor', | |||
| 'TextErrorCorrectionPreprocessor', | |||
| 'TextGenerationPreprocessor', | |||
| 'Tokenize', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| 'ZeroShotClassificationPreprocessor', | |||
| ], | |||
| 'space': [ | |||
| 'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor', | |||
| @@ -6,32 +6,41 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .text_error_correction import TextErrorCorrectionPreprocessor | |||
| from .nlp_base import ( | |||
| Tokenize, SequenceClassificationPreprocessor, | |||
| TextGenerationPreprocessor, TokenClassificationPreprocessor, | |||
| SingleSentenceClassificationPreprocessor, | |||
| Text2TextGenerationPreprocessor, | |||
| PairSentenceClassificationPreprocessor, FillMaskPreprocessor, | |||
| ZeroShotClassificationPreprocessor, NERPreprocessor, | |||
| FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor, | |||
| RelationExtractionPreprocessor, DocumentSegmentationPreprocessor, | |||
| FillMaskPoNetPreprocessor, PassageRankingPreprocessor, | |||
| DocumentSegmentationPreprocessor, | |||
| FaqQuestionAnsweringPreprocessor, | |||
| FillMaskPoNetPreprocessor, | |||
| NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, | |||
| PassageRankingPreprocessor, | |||
| RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, | |||
| WordSegmentationBlankSetToLabelPreprocessor) | |||
| SequenceClassificationPreprocessor, | |||
| TokenClassificationPreprocessor, | |||
| TextGenerationPreprocessor, | |||
| Text2TextGenerationPreprocessor, | |||
| Tokenize, | |||
| WordSegmentationBlankSetToLabelPreprocessor, | |||
| ZeroShotClassificationPreprocessor, | |||
| ) | |||
| else: | |||
| _import_structure = { | |||
| 'nlp_base': [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassificationPreprocessor', | |||
| 'SingleSentenceClassificationPreprocessor', | |||
| 'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| 'TokenClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', | |||
| 'Tokenize', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor' | |||
| 'ZeroShotClassificationPreprocessor', | |||
| ], | |||
| 'text_error_correction': [ | |||
| 'TextErrorCorrectionPreprocessor', | |||
| @@ -2,14 +2,13 @@ | |||
| import os.path as osp | |||
| import re | |||
| import uuid | |||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||
| import numpy as np | |||
| from transformers import AutoTokenizer, BertTokenizerFast | |||
| import torch | |||
| from transformers import AutoTokenizer | |||
| from modelscope.metainfo import Models, Preprocessors | |||
| from modelscope.models.nlp.structbert import SbertTokenizerFast | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| @@ -23,24 +22,21 @@ from modelscope.utils.type_assert import type_assert | |||
| logger = get_logger() | |||
| __all__ = [ | |||
| 'Tokenize', | |||
| 'DocumentSegmentationPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', | |||
| 'NLPPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'NLPTokenizerPreprocessorBase', | |||
| 'PassageRankingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', | |||
| 'TokenClassificationPreprocessor', | |||
| 'PairSentenceClassificationPreprocessor', | |||
| 'Text2TextGenerationPreprocessor', | |||
| 'SingleSentenceClassificationPreprocessor', | |||
| 'FillMaskPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', | |||
| 'NERPreprocessor', | |||
| 'SentenceEmbeddingPreprocessor', | |||
| 'PassageRankingPreprocessor', | |||
| 'FaqQuestionAnsweringPreprocessor', | |||
| 'SequenceLabelingPreprocessor', | |||
| 'RelationExtractionPreprocessor', | |||
| 'DocumentSegmentationPreprocessor', | |||
| 'FillMaskPoNetPreprocessor', | |||
| 'TextGenerationPreprocessor', | |||
| 'Tokenize', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', | |||
| ] | |||
| @@ -48,85 +44,19 @@ __all__ = [ | |||
| class Tokenize(Preprocessor): | |||
| def __init__(self, tokenizer_name) -> None: | |||
| self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | |||
| self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | |||
| def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: | |||
| if isinstance(data, str): | |||
| data = {InputFields.text: data} | |||
| token_dict = self._tokenizer(data[InputFields.text]) | |||
| token_dict = self.tokenizer(data[InputFields.text]) | |||
| data.update(token_dict) | |||
| return data | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | |||
| class SequenceClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from easynlp.modelzoo import AutoTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) | |||
| print(f'this is the tokenizer {self.tokenizer}') | |||
| self.label2id = parse_label_mapping(self.model_dir) | |||
| @type_assert(object, (str, tuple, Dict)) | |||
| def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: | |||
| feature = super().__call__(data) | |||
| if isinstance(data, str): | |||
| new_data = {self.first_sequence: data} | |||
| elif isinstance(data, tuple): | |||
| sentence1, sentence2 = data | |||
| new_data = { | |||
| self.first_sequence: sentence1, | |||
| self.second_sequence: sentence2 | |||
| } | |||
| else: | |||
| new_data = data | |||
| # preprocess the data for the model input | |||
| rst = { | |||
| 'id': [], | |||
| 'input_ids': [], | |||
| 'attention_mask': [], | |||
| 'token_type_ids': [], | |||
| } | |||
| max_seq_length = self.sequence_length | |||
| text_a = new_data[self.first_sequence] | |||
| text_b = new_data.get(self.second_sequence, None) | |||
| feature = self.tokenizer( | |||
| text_a, | |||
| text_b, | |||
| padding='max_length', | |||
| truncation=True, | |||
| max_length=max_seq_length) | |||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return rst | |||
| class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| def __init__(self, model_dir: str, pair: bool, mode: str, **kwargs): | |||
| def __init__(self, model_dir: str, mode: str, **kwargs): | |||
| """The NLP tokenizer preprocessor base class. | |||
| Any nlp preprocessor which uses the hf tokenizer can inherit from this class. | |||
| @@ -138,7 +68,6 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| label: The label key | |||
| label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping | |||
| if this mapping is not supplied. | |||
| pair (bool): Pair sentence input or single sentence input. | |||
| mode: Run this preprocessor in either 'train'/'eval'/'inference' mode | |||
| kwargs: These kwargs will be directly fed into the tokenizer. | |||
| """ | |||
| @@ -148,7 +77,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||
| self.pair = pair | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self._mode = mode | |||
| self.label = kwargs.pop('label', OutputKeys.LABEL) | |||
| self.label2id = None | |||
| @@ -158,6 +88,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| self.label2id = parse_label_mapping(self.model_dir) | |||
| self.tokenize_kwargs = kwargs | |||
| self.tokenizer = self.build_tokenizer(model_dir) | |||
| @property | |||
| @@ -179,20 +110,38 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| @param model_dir: The local model dir. | |||
| @return: The initialized tokenizer. | |||
| """ | |||
| self.is_transformer_based_model = 'lstm' not in model_dir | |||
| # the fast tokenizer can fail under parallel inference, so use the slow version when inferring | |||
| model_type = get_model_type(model_dir) | |||
| if model_type in (Models.structbert, Models.gpt3, Models.palm, | |||
| Models.plug): | |||
| from modelscope.models.nlp.structbert import SbertTokenizer | |||
| return SbertTokenizer.from_pretrained(model_dir, use_fast=False) | |||
| from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast | |||
| return SbertTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| elif model_type == Models.veco: | |||
| from modelscope.models.nlp.veco import VecoTokenizer | |||
| return VecoTokenizer.from_pretrained(model_dir) | |||
| from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast | |||
| return VecoTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| elif model_type == Models.deberta_v2: | |||
| from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer | |||
| return DebertaV2Tokenizer.from_pretrained(model_dir) | |||
| from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast | |||
| return DebertaV2Tokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained( | |||
| model_dir) | |||
| elif not self.is_transformer_based_model: | |||
| from transformers import BertTokenizer, BertTokenizerFast | |||
| return BertTokenizer.from_pretrained( | |||
| model_dir | |||
| ) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained( | |||
| model_dir) | |||
| else: | |||
| return AutoTokenizer.from_pretrained(model_dir, use_fast=False) | |||
| return AutoTokenizer.from_pretrained( | |||
| model_dir, | |||
| use_fast=False if self._mode == ModeKeys.INFERENCE else True) | |||
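build_tokenizer now reserves the slow tokenizers for inference (per the comment above) and hands train/eval the fast variants, whose word_ids() and offset mappings the training preprocessors rely on. A condensed sketch of the mode-based choice, using the Hugging Face BERT classes (assuming the same ModeKeys semantics):

    from transformers import BertTokenizer, BertTokenizerFast

    def pick_tokenizer(model_dir, mode):
        if mode == 'inference':
            # slow tokenizer: safe under parallel inference
            return BertTokenizer.from_pretrained(model_dir)
        # fast tokenizer: provides word_ids() and offset mappings for training
        return BertTokenizerFast.from_pretrained(model_dir)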
| def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| @@ -239,7 +188,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| if len(data) == 3: | |||
| text_a, text_b, labels = data | |||
| elif len(data) == 2: | |||
| if self.pair: | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| text_a, text_b = data | |||
| else: | |||
| text_a, labels = data | |||
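Dropping the pair flag means a 2-element tuple is now disambiguated by mode: at inference it is read as a sentence pair, at train/eval as a sentence plus its label. A sketch mirroring the branch above:

    def parse_tuple(data, mode):
        text_a = text_b = labels = None
        if len(data) == 3:
            text_a, text_b, labels = data
        elif len(data) == 2:
            if mode == 'inference':
                text_a, text_b = data    # sentence pair
            else:
                text_a, labels = data    # single sentence plus label
        return text_a, text_b, labels

    assert parse_tuple(('s1', 's2'), 'inference') == ('s1', 's2', None)
    assert parse_tuple(('s1', 1), 'train') == ('s1', None, 1)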
| @@ -277,6 +226,22 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
| output[OutputKeys.LABELS] = labels | |||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.feature_extraction) | |||
| class NLPPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in MLM task. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | |||
| True) | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
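NLPPreprocessor keeps the old FillMaskPreprocessor defaults (truncation, max_length padding, token_type_ids, and sequence_length mapped onto the tokenizer's max_length) while being registered for both fill-mask and feature-extraction. A usage sketch (the model dir is illustrative):

    from modelscope.preprocessors import NLPPreprocessor

    preprocessor = NLPPreprocessor('/path/to/model_dir', sequence_length=128)
    features = preprocessor('段落中的[MASK]可以被预测出来。')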
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.passage_ranking) | |||
| class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| @@ -337,22 +302,12 @@ class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase): | |||
| Fields.nlp, module_name=Preprocessors.nli_tokenizer) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer) | |||
| class PairSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in pair sentence classification. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get( | |||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| super().__init__(model_dir, pair=True, mode=mode, **kwargs) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer) | |||
| class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in single sentence classification. | |||
| class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in sequence classification. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| @@ -360,7 +315,7 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| kwargs['padding'] = kwargs.get( | |||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
| @PREPROCESSORS.register_module( | |||
| @@ -421,7 +376,7 @@ class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| model_dir (str): model path | |||
| """ | |||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
| def __call__(self, data: Union[str, Dict], hypothesis_template: str, | |||
| candidate_labels: list) -> Dict[str, Any]: | |||
| @@ -496,14 +451,12 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| tokenizer=None, | |||
| mode=ModeKeys.INFERENCE, | |||
| **kwargs): | |||
| self.tokenizer = self.build_tokenizer( | |||
| model_dir) if tokenizer is None else tokenizer | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | |||
| False) | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
| @staticmethod | |||
| def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]: | |||
| @@ -541,20 +494,6 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| } | |||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask) | |||
| class FillMaskPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in MLM task. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get('padding', 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids', | |||
| True) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, | |||
| module_name=Preprocessors.word_segment_text_to_label_preprocessor) | |||
| @@ -592,21 +531,40 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor): | |||
| } | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.ner_tokenizer) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.token_cls_tokenizer) | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) | |||
| class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| """The tokenizer preprocessor used in normal token classification task. | |||
| """The tokenizer preprocessor used in normal NER task. | |||
| """ | |||
| def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| kwargs['truncation'] = kwargs.get('truncation', True) | |||
| kwargs['padding'] = kwargs.get( | |||
| 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') | |||
| kwargs['max_length'] = kwargs.pop('sequence_length', 128) | |||
| self.label_all_tokens = kwargs.pop('label_all_tokens', False) | |||
| super().__init__(model_dir, pair=False, mode=mode, **kwargs) | |||
| super().__init__(model_dir, mode=mode, **kwargs) | |||
| def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]: | |||
| if 'is_split_into_words' in kwargs: | |||
| self.is_split_into_words = kwargs.pop('is_split_into_words') | |||
| else: | |||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| if 'label2id' in kwargs: | |||
| kwargs.pop('label2id') | |||
| self.tokenize_kwargs = kwargs | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| @@ -618,23 +576,84 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| text_a = None | |||
| # preprocess the data for the model input | |||
| text = None | |||
| labels_list = None | |||
| if isinstance(data, str): | |||
| text_a = data | |||
| text = data | |||
| elif isinstance(data, dict): | |||
| text_a = data.get(self.first_sequence) | |||
| text = data.get(self.first_sequence) | |||
| labels_list = data.get(self.label) | |||
| if isinstance(text_a, str): | |||
| text_a = text_a.replace(' ', '').strip() | |||
| input_ids = [] | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| if self.is_split_into_words: | |||
| for offset, token in enumerate(list(data)): | |||
| subtoken_ids = self.tokenizer.encode( | |||
| token, add_special_tokens=False) | |||
| if len(subtoken_ids) == 0: | |||
| subtoken_ids = [self.tokenizer.unk_token_id] | |||
| input_ids.extend(subtoken_ids) | |||
| label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) | |||
| offset_mapping.extend([(offset, offset + 1)]) | |||
| else: | |||
| if self.tokenizer.is_fast: | |||
| encodings = self.tokenizer( | |||
| text, | |||
| add_special_tokens=False, | |||
| return_offsets_mapping=True, | |||
| **self.tokenize_kwargs) | |||
| input_ids = encodings['input_ids'] | |||
| word_ids = encodings.word_ids() | |||
| for i in range(len(word_ids)): | |||
| if word_ids[i] is None: | |||
| label_mask.append(0) | |||
| elif word_ids[i] == word_ids[i - 1]: | |||
| label_mask.append(0) | |||
| offset_mapping[-1] = ( | |||
| offset_mapping[-1][0], | |||
| encodings['offset_mapping'][i][1]) | |||
| else: | |||
| label_mask.append(1) | |||
| offset_mapping.append(encodings['offset_mapping'][i]) | |||
| else: | |||
| encodings = self.tokenizer( | |||
| text, add_special_tokens=False, **self.tokenize_kwargs) | |||
| input_ids = encodings['input_ids'] | |||
| label_mask, offset_mapping = self.get_label_mask_and_offset_mapping( | |||
| text) | |||
| if len(input_ids) >= self.sequence_length - 2: | |||
| input_ids = input_ids[:self.sequence_length - 2] | |||
| label_mask = label_mask[:self.sequence_length - 2] | |||
| input_ids = [self.tokenizer.cls_token_id | |||
| ] + input_ids + [self.tokenizer.sep_token_id] | |||
| label_mask = [0] + label_mask + [0] | |||
| attention_mask = [1] * len(input_ids) | |||
| offset_mapping = offset_mapping[:sum(label_mask)] | |||
| tokenized_inputs = self.tokenizer( | |||
| [t for t in text_a], | |||
| return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None, | |||
| is_split_into_words=True, | |||
| **self.tokenize_kwargs) | |||
| if not self.is_transformer_based_model: | |||
| input_ids = input_ids[1:-1] | |||
| attention_mask = attention_mask[1:-1] | |||
| label_mask = label_mask[1:-1] | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| input_ids = torch.tensor(input_ids).unsqueeze(0) | |||
| attention_mask = torch.tensor(attention_mask).unsqueeze(0) | |||
| label_mask = torch.tensor( | |||
| label_mask, dtype=torch.bool).unsqueeze(0) | |||
| # the token classification | |||
| output = { | |||
| 'text': text, | |||
| 'input_ids': input_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| 'offset_mapping': offset_mapping | |||
| } | |||
| # align the labels with tokenized text | |||
| if labels_list is not None: | |||
| assert self.label2id is not None | |||
| # Map that sends B-Xxx label to its I-Xxx counterpart | |||
| @@ -653,7 +672,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| b_to_i_label.append(idx) | |||
| label_row = [self.label2id[lb] for lb in labels_list] | |||
| word_ids = tokenized_inputs.word_ids() | |||
| previous_word_idx = None | |||
| label_ids = [] | |||
| for word_idx in word_ids: | |||
| @@ -668,229 +686,66 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase): | |||
| label_ids.append(-100) | |||
| previous_word_idx = word_idx | |||
| labels = label_ids | |||
| tokenized_inputs['labels'] = labels | |||
| # new code end | |||
| if self._mode == ModeKeys.INFERENCE: | |||
| tokenized_inputs[OutputKeys.TEXT] = text_a | |||
| return tokenized_inputs | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.ner_tokenizer) | |||
| class NERPreprocessor(Preprocessor): | |||
| """The tokenizer preprocessor used in normal NER task. | |||
| NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition. | |||
| """ | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||
| self.is_transformer_based_model = 'lstm' not in model_dir | |||
| if self.is_transformer_based_model: | |||
| self.tokenizer = AutoTokenizer.from_pretrained( | |||
| model_dir, use_fast=True) | |||
| else: | |||
| self.tokenizer = BertTokenizerFast.from_pretrained( | |||
| model_dir, use_fast=True) | |||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| output['labels'] = labels | |||
| return output | |||
| # preprocess the data for the model input | |||
| text = data | |||
| if self.is_split_into_words: | |||
| input_ids = [] | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| for offset, token in enumerate(list(data)): | |||
| subtoken_ids = self.tokenizer.encode( | |||
| token, add_special_tokens=False) | |||
| if len(subtoken_ids) == 0: | |||
| subtoken_ids = [self.tokenizer.unk_token_id] | |||
| input_ids.extend(subtoken_ids) | |||
| label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) | |||
| offset_mapping.extend([(offset, offset + 1)] | |||
| + [(offset + 1, offset + 1)] | |||
| * (len(subtoken_ids) - 1)) | |||
| if len(input_ids) >= self.sequence_length - 2: | |||
| input_ids = input_ids[:self.sequence_length - 2] | |||
| label_mask = label_mask[:self.sequence_length - 2] | |||
| offset_mapping = offset_mapping[:self.sequence_length - 2] | |||
| input_ids = [self.tokenizer.cls_token_id | |||
| ] + input_ids + [self.tokenizer.sep_token_id] | |||
| label_mask = [0] + label_mask + [0] | |||
| attention_mask = [1] * len(input_ids) | |||
| else: | |||
| encodings = self.tokenizer( | |||
| text, | |||
| add_special_tokens=True, | |||
| padding=True, | |||
| truncation=True, | |||
| max_length=self.sequence_length, | |||
| return_offsets_mapping=True) | |||
| input_ids = encodings['input_ids'] | |||
| attention_mask = encodings['attention_mask'] | |||
| word_ids = encodings.word_ids() | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| for i in range(len(word_ids)): | |||
| if word_ids[i] is None: | |||
| label_mask.append(0) | |||
| elif word_ids[i] == word_ids[i - 1]: | |||
| label_mask.append(0) | |||
| offset_mapping[-1] = (offset_mapping[-1][0], | |||
| encodings['offset_mapping'][i][1]) | |||
| def get_tokenizer_class(self): | |||
| tokenizer_class = self.tokenizer.__class__.__name__ | |||
| if tokenizer_class.endswith( | |||
| 'Fast') and tokenizer_class != 'PreTrainedTokenizerFast': | |||
| tokenizer_class = tokenizer_class[:-4] | |||
| return tokenizer_class | |||
| def get_label_mask_and_offset_mapping(self, text): | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| tokens = self.tokenizer.tokenize(text) | |||
| offset = 0 | |||
| if self.get_tokenizer_class() == 'BertTokenizer': | |||
| for token in tokens: | |||
| is_start = (token[:2] != '##') | |||
| if is_start: | |||
| label_mask.append(True) | |||
| else: | |||
| label_mask.append(1) | |||
| offset_mapping.append(encodings['offset_mapping'][i]) | |||
| if not self.is_transformer_based_model: | |||
| input_ids = input_ids[1:-1] | |||
| attention_mask = attention_mask[1:-1] | |||
| label_mask = label_mask[1:-1] | |||
| return { | |||
| 'text': text, | |||
| 'input_ids': input_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| 'offset_mapping': offset_mapping | |||
| } | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer) | |||
| class SequenceLabelingPreprocessor(Preprocessor): | |||
| """The tokenizer preprocessor used in normal NER task. | |||
| NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition. | |||
| """ | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||
| if 'lstm' in model_dir or 'gcnn' in model_dir: | |||
| self.tokenizer = BertTokenizerFast.from_pretrained( | |||
| model_dir, use_fast=False) | |||
| elif 'structbert' in model_dir: | |||
| self.tokenizer = SbertTokenizerFast.from_pretrained( | |||
| model_dir, use_fast=False) | |||
| else: | |||
| self.tokenizer = AutoTokenizer.from_pretrained( | |||
| model_dir, use_fast=False) | |||
| self.is_split_into_words = self.tokenizer.init_kwargs.get( | |||
| 'is_split_into_words', False) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| # preprocess the data for the model input | |||
| text = data | |||
| if self.is_split_into_words: | |||
| input_ids = [] | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| for offset, token in enumerate(list(data)): | |||
| subtoken_ids = self.tokenizer.encode( | |||
| token, add_special_tokens=False) | |||
| if len(subtoken_ids) == 0: | |||
| subtoken_ids = [self.tokenizer.unk_token_id] | |||
| input_ids.extend(subtoken_ids) | |||
| label_mask.extend([1] + [0] * (len(subtoken_ids) - 1)) | |||
| offset_mapping.extend([(offset, offset + 1)] | |||
| + [(offset + 1, offset + 1)] | |||
| * (len(subtoken_ids) - 1)) | |||
| if len(input_ids) >= self.sequence_length - 2: | |||
| input_ids = input_ids[:self.sequence_length - 2] | |||
| label_mask = label_mask[:self.sequence_length - 2] | |||
| offset_mapping = offset_mapping[:self.sequence_length - 2] | |||
| input_ids = [self.tokenizer.cls_token_id | |||
| ] + input_ids + [self.tokenizer.sep_token_id] | |||
| label_mask = [0] + label_mask + [0] | |||
| attention_mask = [1] * len(input_ids) | |||
| else: | |||
| encodings = self.tokenizer( | |||
| text, | |||
| add_special_tokens=True, | |||
| padding=True, | |||
| truncation=True, | |||
| max_length=self.sequence_length, | |||
| return_offsets_mapping=True) | |||
| input_ids = encodings['input_ids'] | |||
| attention_mask = encodings['attention_mask'] | |||
| word_ids = encodings.word_ids() | |||
| label_mask = [] | |||
| offset_mapping = [] | |||
| for i in range(len(word_ids)): | |||
| if word_ids[i] is None: | |||
| label_mask.append(0) | |||
| elif word_ids[i] == word_ids[i - 1]: | |||
| label_mask.append(0) | |||
| offset_mapping[-1] = (offset_mapping[-1][0], | |||
| encodings['offset_mapping'][i][1]) | |||
| token = token[2:] | |||
| label_mask.append(False) | |||
| start = offset + text[offset:].index(token) | |||
| end = start + len(token) | |||
| if is_start: | |||
| offset_mapping.append((start, end)) | |||
| else: | |||
| label_mask.append(1) | |||
| offset_mapping.append(encodings['offset_mapping'][i]) | |||
| offset_mapping[-1] = (offset_mapping[-1][0], end) | |||
| offset = end | |||
| elif self.get_tokenizer_class() == 'XLMRobertaTokenizer': | |||
| last_is_blank = False | |||
| for token in tokens: | |||
| is_start = (token[0] == '▁') | |||
| if is_start: | |||
| token = token[1:] | |||
| label_mask.append(True) | |||
| if len(token) == 0: | |||
| last_is_blank = True | |||
| continue | |||
| else: | |||
| label_mask.append(False) | |||
| start = offset + text[offset:].index(token) | |||
| end = start + len(token) | |||
| if last_is_blank or is_start: | |||
| offset_mapping.append((start, end)) | |||
| else: | |||
| offset_mapping[-1] = (offset_mapping[-1][0], end) | |||
| offset = end | |||
| last_is_blank = False | |||
| else: | |||
| raise NotImplementedError | |||
| if not self.is_transformer_based_model: | |||
| input_ids = input_ids[1:-1] | |||
| attention_mask = attention_mask[1:-1] | |||
| label_mask = label_mask[1:-1] | |||
| return { | |||
| 'text': text, | |||
| 'input_ids': input_ids, | |||
| 'attention_mask': attention_mask, | |||
| 'label_mask': label_mask, | |||
| 'offset_mapping': offset_mapping | |||
| } | |||
| return label_mask, offset_mapping | |||
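For slow tokenizers, get_label_mask_and_offset_mapping reconstructs which wordpieces start a new word (label_mask) and which character span each labeled token covers (offset_mapping); for a BertTokenizer the '##' prefix marks continuation pieces. An illustrative walk-through of the BERT branch (the token values assume standard WordPiece behavior):

    text = 'unaffable'
    tokens = ['una', '##ffa', '##ble']   # what a WordPiece tokenizer might emit
    label_mask, offset_mapping, offset = [], [], 0
    for token in tokens:
        is_start = not token.startswith('##')
        piece = token if is_start else token[2:]
        label_mask.append(is_start)
        start = offset + text[offset:].index(piece)
        end = start + len(piece)
        if is_start:
            offset_mapping.append((start, end))                # new word opens a span
        else:
            offset_mapping[-1] = (offset_mapping[-1][0], end)  # continuation extends it
        offset = end
    print(label_mask)      # [True, False, False]
    print(offset_mapping)  # [(0, 9)]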
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.re_tokenizer) | |||
| class RelationExtractionPreprocessor(Preprocessor): | |||
| """The tokenizer preprocessor used in normal RE task. | |||
| NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition. | |||
| """The relation extraction preprocessor used in normal RE task. | |||
| """ | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -937,7 +792,7 @@ class FaqQuestionAnsweringPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| super(FaqQuestionAnsweringPreprocessor, self).__init__( | |||
| model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs) | |||
| model_dir, mode=ModeKeys.INFERENCE, **kwargs) | |||
| import os | |||
| from transformers import BertTokenizer | |||
| @@ -1026,7 +881,7 @@ class DocumentSegmentationPreprocessor(Preprocessor): | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from transformers import BertTokenizerFast | |||
| self.tokenizer = BertTokenizerFast.from_pretrained( | |||
| model_dir, | |||
| use_fast=True, | |||
| @@ -115,6 +115,7 @@ class NLPTasks(object): | |||
| conversational_text_to_sql = 'conversational-text-to-sql' | |||
| information_extraction = 'information-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| class AudioTasks(object): | |||
| @@ -74,7 +74,6 @@ class Registry(object): | |||
| raise KeyError(f'{module_name} is already registered in ' | |||
| f'{self._name}[{group_key}]') | |||
| self._modules[group_key][module_name] = module_cls | |||
| module_cls.group_key = group_key | |||
| def register_module(self, | |||
| group_key: str = default_group, | |||
| @@ -196,6 +195,7 @@ def build_from_cfg(cfg, | |||
| if obj_cls is None: | |||
| raise KeyError(f'{obj_type} is not in the {registry.name}' | |||
| f' registry group {group_key}') | |||
| obj_cls.group_key = group_key | |||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
| obj_cls = obj_type | |||
| else: | |||
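Assigning group_key inside register_module breaks when decorators are stacked: they run bottom-up, so the last one silently overwrites the attribute and the class reports the wrong group. Moving the assignment into build_from_cfg, as above, records the group the class is actually built from. A toy reproduction of the old failure mode:

    class Registry:
        def __init__(self):
            self.modules = {}

        def register_module(self, group_key, module_name):
            def wrap(cls):
                self.modules.setdefault(group_key, {})[module_name] = cls
                cls.group_key = group_key   # old behavior: the last decorator wins
                return cls
            return wrap

    PIPELINES = Registry()

    @PIPELINES.register_module('text-classification', 'sentiment-classification')
    @PIPELINES.register_module('nli', 'nli')
    class Pipe:
        pass

    print(Pipe.group_key)  # 'text-classification', even when built from the 'nli' group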
| @@ -75,7 +75,8 @@ class MsDatasetTest(unittest.TestCase): | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| nlp_model.model_dir, | |||
| first_sequence='premise', | |||
| second_sequence=None) | |||
| second_sequence=None, | |||
| padding='max_length') | |||
| ms_ds_train = MsDataset.load( | |||
| 'xcopa', | |||
| subset_name='translation-et', | |||
| @@ -6,11 +6,9 @@ import torch | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import DebertaV2ForMaskedLM | |||
| from modelscope.models.nlp.deberta_v2 import (DebertaV2Tokenizer, | |||
| DebertaV2TokenizerFast) | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import FillMaskPipeline | |||
| from modelscope.preprocessors import FillMaskPreprocessor | |||
| from modelscope.preprocessors import NLPPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -24,7 +22,7 @@ class DeBERTaV2TaskTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_by_direct_model_download(self): | |||
| model_dir = snapshot_download(self.model_id_deberta) | |||
| preprocessor = FillMaskPreprocessor( | |||
| preprocessor = NLPPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = DebertaV2ForMaskedLM.from_pretrained(model_dir) | |||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| @@ -40,7 +38,7 @@ class DeBERTaV2TaskTest(unittest.TestCase): | |||
| # sbert | |||
| print(self.model_id_deberta) | |||
| model = Model.from_pretrained(self.model_id_deberta) | |||
| preprocessor = FillMaskPreprocessor( | |||
| preprocessor = NLPPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| @@ -0,0 +1,67 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| import numpy as np | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import FeatureExtractionModel | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import FeatureExtractionPipeline | |||
| from modelscope.preprocessors import NLPPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class FeatureExtractionTaskModelTest(unittest.TestCase, | |||
| DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.task = Tasks.feature_extraction | |||
| self.model_id = 'damo/pert_feature-extraction_base-test' | |||
| sentence1 = '测试embedding' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = NLPPreprocessor(cache_path, padding=False) | |||
| model = FeatureExtractionModel.from_pretrained(self.model_id) | |||
| pipeline1 = FeatureExtractionPipeline(model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.feature_extraction, model=model, preprocessor=tokenizer) | |||
| result = pipeline1(input=self.sentence1) | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| f'pipeline1:{np.shape(result[OutputKeys.TEXT_EMBEDDING])}') | |||
| result = pipeline2(input=self.sentence1) | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| f'pipeline1: {np.shape(result[OutputKeys.TEXT_EMBEDDING])}') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = NLPPreprocessor(model.model_dir, padding=False) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.feature_extraction, model=model, preprocessor=tokenizer) | |||
| result = pipeline_ins(input=self.sentence1) | |||
| print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.feature_extraction, model=self.model_id) | |||
| result = pipeline_ins(input=self.sentence1) | |||
| print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.feature_extraction) | |||
| result = pipeline_ins(input=self.sentence1) | |||
| print(np.shape(result[OutputKeys.TEXT_EMBEDDING])) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,13 +1,15 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM, | |||
| VecoForMaskedLM) | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import FillMaskPipeline | |||
| from modelscope.preprocessors import FillMaskPreprocessor | |||
| from modelscope.preprocessors import NLPPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||
| @@ -51,7 +53,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| # sbert | |||
| for language in ['zh']: | |||
| model_dir = snapshot_download(self.model_id_sbert[language]) | |||
| preprocessor = FillMaskPreprocessor( | |||
| preprocessor = NLPPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = StructBertForMaskedLM.from_pretrained(model_dir) | |||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| @@ -66,7 +68,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| # veco | |||
| model_dir = snapshot_download(self.model_id_veco) | |||
| preprocessor = FillMaskPreprocessor( | |||
| preprocessor = NLPPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = VecoForMaskedLM.from_pretrained(model_dir) | |||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| @@ -80,13 +82,28 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n' | |||
| ) | |||
| # bert | |||
| language = 'zh' | |||
| model_dir = snapshot_download(self.model_id_bert, revision='beta') | |||
| preprocessor = NLPPreprocessor( | |||
| model_dir, first_sequence='sentence', second_sequence=None) | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| pipeline2 = pipeline( | |||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| ori_text = self.ori_texts[language] | |||
| test_input = self.test_inputs[language] | |||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' | |||
| f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| # sbert | |||
| for language in ['zh']: | |||
| print(self.model_id_sbert[language]) | |||
| model = Model.from_pretrained(self.model_id_sbert[language]) | |||
| - preprocessor = FillMaskPreprocessor( | |||
| + preprocessor = NLPPreprocessor( | |||
| model.model_dir, | |||
| first_sequence='sentence', | |||
| second_sequence=None) | |||
| @@ -100,7 +117,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| # veco | |||
| model = Model.from_pretrained(self.model_id_veco) | |||
| - preprocessor = FillMaskPreprocessor( | |||
| + preprocessor = NLPPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| @@ -113,6 +130,18 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||
| f'{pipeline_ins(test_input)}\n') | |||
| # bert | |||
| language = 'zh' | |||
| model = Model.from_pretrained(self.model_id_bert, revision='beta') | |||
| preprocessor = NLPPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| print( | |||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| # veco | |||
| @@ -131,6 +160,16 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||
| # Bert | |||
| language = 'zh' | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, | |||
| model=self.model_id_bert, | |||
| model_revision='beta') | |||
| print( | |||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.fill_mask) | |||
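Note: the bert branches added above show the intended usage for an original BERT checkpoint in backbone + head form. A minimal standalone sketch; the model id is a placeholder for the test's self.model_id_bert, which is defined outside this excerpt:

    from modelscope.models import Model
    from modelscope.pipelines import pipeline
    from modelscope.preprocessors import NLPPreprocessor
    from modelscope.utils.constant import Tasks

    # '<bert-fill-mask-model-id>' is a placeholder for an original BERT model.
    model = Model.from_pretrained('<bert-fill-mask-model-id>', revision='beta')
    preprocessor = NLPPreprocessor(
        model.model_dir, first_sequence='sentence', second_sequence=None)
    pipe = pipeline(Tasks.fill_mask, model=model, preprocessor=preprocessor)
    print(pipe('生活的真谛是[MASK]。'))  # fills the masked position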
| @@ -7,7 +7,7 @@ from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition, | |||
| TransformerCRFForNamedEntityRecognition) | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline | |||
| - from modelscope.preprocessors import NERPreprocessor | |||
| + from modelscope.preprocessors import TokenClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -26,7 +26,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_tcrf_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.tcrf_model_id) | |||
| - tokenizer = NERPreprocessor(cache_path) | |||
| + tokenizer = TokenClassificationPreprocessor(cache_path) | |||
| model = TransformerCRFForNamedEntityRecognition( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = NamedEntityRecognitionPipeline( | |||
| @@ -43,7 +43,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_lcrf_by_direct_model_download(self): | |||
| cache_path = snapshot_download(self.lcrf_model_id) | |||
| - tokenizer = NERPreprocessor(cache_path) | |||
| + tokenizer = TokenClassificationPreprocessor(cache_path) | |||
| model = LSTMCRFForNamedEntityRecognition( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = NamedEntityRecognitionPipeline( | |||
| @@ -60,7 +60,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_tcrf_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.tcrf_model_id) | |||
| - tokenizer = NERPreprocessor(model.model_dir) | |||
| + tokenizer = TokenClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, | |||
| model=model, | |||
| @@ -70,7 +70,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_lcrf_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.lcrf_model_id) | |||
| - tokenizer = NERPreprocessor(model.model_dir) | |||
| + tokenizer = TokenClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.named_entity_recognition, | |||
| model=model, | |||
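Note: with NERPreprocessor replaced by TokenClassificationPreprocessor, a NER pipeline is assembled as sketched below; the model id and the input sentence are placeholders (the tests use self.tcrf_model_id / self.lcrf_model_id, defined outside this excerpt):

    from modelscope.models import Model
    from modelscope.pipelines import pipeline
    from modelscope.preprocessors import TokenClassificationPreprocessor
    from modelscope.utils.constant import Tasks

    # '<ner-model-id>' is a placeholder NER model id.
    model = Model.from_pretrained('<ner-model-id>')
    tokenizer = TokenClassificationPreprocessor(model.model_dir)
    pipe = pipeline(
        task=Tasks.named_entity_recognition, model=model, preprocessor=tokenizer)
    print(pipe('这是一段用于命名实体识别的示例文本'))  # placeholder input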
| @@ -5,8 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForSequenceClassification | |||
| from modelscope.pipelines import pipeline | |||
| - from modelscope.pipelines.nlp import PairSentenceClassificationPipeline | |||
| - from modelscope.preprocessors import PairSentenceClassificationPreprocessor | |||
| + from modelscope.pipelines.nlp import SequenceClassificationPipeline | |||
| + from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||
| @@ -26,9 +26,9 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| - tokenizer = PairSentenceClassificationPreprocessor(cache_path) | |||
| + tokenizer = SequenceClassificationPreprocessor(cache_path) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| - pipeline1 = PairSentenceClassificationPipeline( | |||
| + pipeline1 = SequenceClassificationPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer) | |||
| print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n' | |||
| @@ -40,7 +40,7 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| - tokenizer = PairSentenceClassificationPreprocessor(model.model_dir) | |||
| + tokenizer = SequenceClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.nli, model=model, preprocessor=tokenizer) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
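Note: after the unification, NLI runs through the shared SequenceClassificationPipeline / SequenceClassificationPreprocessor pair. A minimal sketch with a placeholder model id; sentence pairs are passed as a tuple:

    from modelscope.models import Model
    from modelscope.pipelines import pipeline
    from modelscope.preprocessors import SequenceClassificationPreprocessor
    from modelscope.utils.constant import Tasks

    # '<nli-model-id>' is a placeholder for the test's self.model_id.
    model = Model.from_pretrained('<nli-model-id>')
    tokenizer = SequenceClassificationPreprocessor(model.model_dir)
    pipe = pipeline(task=Tasks.nli, model=model, preprocessor=tokenizer)
    print(pipe(input=('第一个句子', '第二个句子')))  # pair input as a tuple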
| @@ -5,8 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForSequenceClassification | |||
| from modelscope.pipelines import pipeline | |||
| - from modelscope.pipelines.nlp import PairSentenceClassificationPipeline | |||
| - from modelscope.preprocessors import PairSentenceClassificationPreprocessor | |||
| + from modelscope.pipelines.nlp import SequenceClassificationPipeline | |||
| + from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||
| @@ -26,9 +26,9 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| - tokenizer = PairSentenceClassificationPreprocessor(cache_path) | |||
| + tokenizer = SequenceClassificationPreprocessor(cache_path) | |||
| model = SbertForSequenceClassification.from_pretrained(cache_path) | |||
| - pipeline1 = PairSentenceClassificationPipeline( | |||
| + pipeline1 = SequenceClassificationPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.sentence_similarity, model=model, preprocessor=tokenizer) | |||
| @@ -43,7 +43,7 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| - tokenizer = PairSentenceClassificationPreprocessor(model.model_dir) | |||
| + tokenizer = SequenceClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_similarity, | |||
| model=model, | |||
| @@ -6,8 +6,8 @@ from modelscope.models import Model | |||
| from modelscope.models.nlp.task_models.sequence_classification import \ | |||
| SequenceClassificationModel | |||
| from modelscope.pipelines import pipeline | |||
| - from modelscope.pipelines.nlp import SingleSentenceClassificationPipeline | |||
| - from modelscope.preprocessors import SingleSentenceClassificationPreprocessor | |||
| + from modelscope.pipelines.nlp import SequenceClassificationPipeline | |||
| + from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -17,23 +17,21 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| - self.task = Tasks.sentiment_classification | |||
| + self.task = Tasks.text_classification | |||
| self.model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
| sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| - cache_path = snapshot_download(self.model_id) | |||
| - tokenizer = SingleSentenceClassificationPreprocessor(cache_path) | |||
| + cache_path = snapshot_download(self.model_id, revision='beta') | |||
| + tokenizer = SequenceClassificationPreprocessor(cache_path) | |||
| model = SequenceClassificationModel.from_pretrained( | |||
| - self.model_id, num_labels=2) | |||
| - pipeline1 = SingleSentenceClassificationPipeline( | |||
| + self.model_id, num_labels=2, revision='beta') | |||
| + pipeline1 = SequenceClassificationPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| - Tasks.sentiment_classification, | |||
| - model=model, | |||
| - preprocessor=tokenizer) | |||
| + Tasks.text_classification, model=model, preprocessor=tokenizer) | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| f'pipeline1:{pipeline1(input=self.sentence1)}') | |||
| print(f'sentence1: {self.sentence1}\n' | |||
| @@ -41,10 +39,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| - model = Model.from_pretrained(self.model_id) | |||
| - tokenizer = SingleSentenceClassificationPreprocessor(model.model_dir) | |||
| + model = Model.from_pretrained(self.model_id, revision='beta') | |||
| + tokenizer = SequenceClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| - task=Tasks.sentiment_classification, | |||
| + task=Tasks.text_classification, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| @@ -54,14 +52,17 @@ class SentimentClassificationTaskModelTest(unittest.TestCase, | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| - task=Tasks.sentiment_classification, model=self.model_id) | |||
| + task=Tasks.text_classification, | |||
| + model=self.model_id, | |||
| + model_revision='beta') | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| self.assertTrue( | |||
| isinstance(pipeline_ins.model, SequenceClassificationModel)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| - pipeline_ins = pipeline(task=Tasks.sentiment_classification) | |||
| + pipeline_ins = pipeline( | |||
| + task=Tasks.text_classification, model_revision='beta') | |||
| print(pipeline_ins(input=self.sentence1)) | |||
| self.assertTrue( | |||
| isinstance(pipeline_ins.model, SequenceClassificationModel)) | |||
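Note: these tests capture the dispatch change — the sentiment model is now requested under the broader Tasks.text_classification task key, with model_revision='beta' during the transition. A minimal sketch using the model id and input sentence from the test above:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(
        task=Tasks.text_classification,
        model='damo/nlp_structbert_sentiment-classification_chinese-base',
        model_revision='beta')
    print(pipe(input='启动的时候很大声音，然后就会听到1.2秒的卡察的声音，类似齿轮摩擦的声音'))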
| @@ -12,6 +12,7 @@ from modelscope.utils.test_utils import test_level | |||
| class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| sentence1 = 'i like this wonderful place' | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/bert-base-sst2' | |||
| @@ -46,7 +47,8 @@ class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| task=Tasks.text_classification, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| - self.predict(pipeline_ins) | |||
| + print(f'sentence1: {self.sentence1}\n' | |||
| + f'pipeline1:{pipeline_ins(input=self.sentence1)}') | |||
| # @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| @unittest.skip('nlp model does not support tensor input, skipped') | |||
| @@ -32,6 +32,82 @@ class NLPPreprocessorTest(unittest.TestCase): | |||
| output['attention_mask'], | |||
| [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) | |||
| def test_token_classification_tokenize(self): | |||
| with self.subTest(tokenizer_type='bert'): | |||
| cfg = dict( | |||
| type='token-cls-tokenizer', | |||
| model_dir='bert-base-cased', | |||
| label2id={ | |||
| 'O': 0, | |||
| 'B': 1, | |||
| 'I': 2 | |||
| }) | |||
| preprocessor = build_preprocessor(cfg, Fields.nlp) | |||
| input = 'Do not meddle in the affairs of wizards, ' \ | |||
| 'for they are subtle and quick to anger.' | |||
| output = preprocessor(input) | |||
| self.assertTrue(InputFields.text in output) | |||
| self.assertEqual(output['input_ids'].tolist()[0], [ | |||
| 101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, | |||
| 1116, 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, | |||
| 119, 102 | |||
| ]) | |||
| self.assertEqual(output['attention_mask'].tolist()[0], [ | |||
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |||
| 1 | |||
| ]) | |||
| self.assertEqual(output['label_mask'].tolist()[0], [ | |||
| False, True, True, True, False, True, True, True, True, True, | |||
| False, True, True, True, True, True, True, True, True, True, | |||
| True, False | |||
| ]) | |||
| self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6), | |||
| (7, 13), (14, 16), | |||
| (17, 20), (21, 28), | |||
| (29, 31), (32, 39), | |||
| (39, 40), (41, 44), | |||
| (45, 49), (50, 53), | |||
| (54, 60), (61, 64), | |||
| (65, 70), (71, 73), | |||
| (74, 79), (79, 80)]) | |||
| with self.subTest(tokenizer_type='roberta'): | |||
| cfg = dict( | |||
| type='token-cls-tokenizer', | |||
| model_dir='xlm-roberta-base', | |||
| label2id={ | |||
| 'O': 0, | |||
| 'B': 1, | |||
| 'I': 2 | |||
| }) | |||
| preprocessor = build_preprocessor(cfg, Fields.nlp) | |||
| input = 'Do not meddle in the affairs of wizards, ' \ | |||
| 'for they are subtle and quick to anger.' | |||
| output = preprocessor(input) | |||
| self.assertTrue(InputFields.text in output) | |||
| self.assertEqual(output['input_ids'].tolist()[0], [ | |||
| 0, 984, 959, 128, 19298, 23, 70, 103086, 7, 111, 6, 44239, | |||
| 99397, 4, 100, 1836, 621, 1614, 17991, 136, 63773, 47, 348, 56, | |||
| 5, 2 | |||
| ]) | |||
| self.assertEqual(output['attention_mask'].tolist()[0], [ | |||
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |||
| 1, 1, 1, 1, 1 | |||
| ]) | |||
| self.assertEqual(output['label_mask'].tolist()[0], [ | |||
| False, True, True, True, False, True, True, True, False, True, | |||
| True, False, False, False, True, True, True, True, False, True, | |||
| True, True, True, False, False, False | |||
| ]) | |||
| self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6), | |||
| (7, 13), (14, 16), | |||
| (17, 20), (21, 28), | |||
| (29, 31), (32, 40), | |||
| (41, 44), (45, 49), | |||
| (50, 53), (54, 60), | |||
| (61, 64), (65, 70), | |||
| (71, 73), (74, 80)]) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
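Note: in the asserted output, label_mask is True exactly at positions that carry a word-level label (special tokens and sub-token continuations are False), and offset_mapping holds one character span per labeled word. A minimal sketch of mapping spans back to surface text, using the first offsets asserted in the bert sub-test:

    # Character spans from the test above map straight back into the input text.
    text = ('Do not meddle in the affairs of wizards, '
            'for they are subtle and quick to anger.')
    offset_mapping = [(0, 2), (3, 6), (7, 13)]  # first three spans asserted above
    words = [text[start:end] for start, end in offset_mapping]
    print(words)  # ['Do', 'not', 'meddle']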
| @@ -30,7 +30,7 @@ class AstScaningTest(unittest.TestCase): | |||
| def test_ast_scaning_class(self): | |||
| astScaner = AstScaning() | |||
| pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp', | |||
| - 'sequence_classification_pipeline.py') | |||
| + 'text_generation_pipeline.py') | |||
| output = astScaner.generate_ast(pipeline_file) | |||
| self.assertTrue(output['imports'] is not None) | |||
| self.assertTrue(output['from_imports'] is not None) | |||
| @@ -40,14 +40,12 @@ class AstScaningTest(unittest.TestCase): | |||
| self.assertIsInstance(imports, dict) | |||
| self.assertIsInstance(from_imports, dict) | |||
| self.assertIsInstance(decorators, list) | |||
| - self.assertListEqual( | |||
| - list(set(imports.keys()) - set(['typing', 'numpy'])), []) | |||
| - self.assertEqual(len(from_imports.keys()), 9) | |||
| + self.assertListEqual(list(set(imports.keys()) - set(['torch'])), []) | |||
| + self.assertEqual(len(from_imports.keys()), 7) | |||
| self.assertTrue(from_imports['modelscope.metainfo'] is not None) | |||
| self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines']) | |||
| - self.assertEqual( | |||
| - decorators, | |||
| - [('PIPELINES', 'text-classification', 'sentiment-analysis')]) | |||
| + self.assertEqual(decorators, | |||
| + [('PIPELINES', 'text-generation', 'text-generation')]) | |||
| def test_files_scaning_method(self): | |||
| fileScaner = FilesAstScaning() | |||
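Note: the updated assertion reflects the tuple shape the scanner emits for a registration decorator: (registry name, group key, module name). A hedged reconstruction of the registration in text_generation_pipeline.py that would produce that tuple (class body elided):

    from modelscope.metainfo import Pipelines
    from modelscope.pipelines.base import Pipeline
    from modelscope.pipelines.builder import PIPELINES
    from modelscope.utils.constant import Tasks

    # group_key comes from the task; module_name from Pipelines.
    @PIPELINES.register_module(
        Tasks.text_generation, module_name=Pipelines.text_generation)
    class TextGenerationPipeline(Pipeline):
        ...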