
[to #42322933] BERT with sequence classification / token classification / fill mask refactor

1. Add support for the original BERT model (rather than the EasyNLP backbone-prefix variant).
2. Support BERT in backbone+head form for sequence classification, fill mask, and token classification.
3. Unify the pipelines of the sequence classification sub-tasks into a single class.
4. Support the backbone+head form for fill mask.
5. Unify the preprocessors of the token classification sub-tasks (NER, word segmentation, part of speech) into TokenClassificationPreprocessor.
6. Unify the preprocessors of the sequence classification sub-tasks (single-sentence classification, sentence-pair classification) into SequenceClassificationPreprocessor.
7. Move where the registry assigns cls's group_key: previously, with multiple stacked decorators, the group_key was overwritten and obj_cls carried incorrect group_key information (see the sketch after this list).
8. With the backbone+head form, adjust registrations where the group_key used to share the module's name. In modelscope/pipelines/nlp/sequence_classification_pipeline.py, the original

   @PIPELINES.register_module(
       Tasks.sentiment_classification, module_name=Pipelines.sentiment_classification)

   becomes

   @PIPELINES.register_module(
       Tasks.text_classification, module_name=Pipelines.sentiment_classification)

   The corresponding configuration.json changes accordingly; this better reflects the relationship between a task and its pipelines (sub-tasks).
9. Other related changes to support the above.
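A minimal sketch of the stacking problem from item 7, using a simplified registry rather than the actual modelscope implementation (all names here are illustrative):

class Registry:

    def __init__(self):
        self._modules = {}  # group_key -> {module_name: cls}

    def register_module(self, group_key, module_name):

        def _register(cls):
            # Correct: record the group per registration, inside the registry.
            self._modules.setdefault(group_key, {})[module_name] = cls
            # Buggy variant: writing the group onto the class itself,
            #     cls.group_key = group_key
            # lets the outermost of several stacked decorators overwrite the
            # value assigned by the inner ones.
            return cls

        return _register

MODELS = Registry()

@MODELS.register_module('text-classification', module_name='bert')
@MODELS.register_module('sentiment-classification', module_name='bert')
class BertForSequenceClassification:
    pass

# With the buggy variant, group_key would end up as 'text-classification' only
# (the outer decorator runs last); the registry-side bookkeeping keeps both.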
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10041463
Branch: master
zhangzhicheng.zzc, yingda.chen · 3 years ago
Commit: d721fabb34
50 changed files with 3347 additions and 837 deletions
  1. modelscope/metainfo.py (+9, -2)
  2. modelscope/models/builder.py (+6, -3)
  3. modelscope/models/nlp/__init__.py (+14, -8)
  4. modelscope/models/nlp/backbones/bert.py (+7, -0)
  5. modelscope/models/nlp/bert/__init__.py (+60, -0)
  6. modelscope/models/nlp/bert/configuration_bert.py (+162, -0)
  7. modelscope/models/nlp/bert/modeling_bert.py (+2040, -0)
  8. modelscope/models/nlp/bert_for_sequence_classification.py (+0, -70)
  9. modelscope/models/nlp/deberta_v2/__init__.py (+0, -10)
  10. modelscope/models/nlp/heads/fill_mask_head.py (+101, -0)
  11. modelscope/models/nlp/heads/torch_pretrain_head.py (+1, -1)
  12. modelscope/models/nlp/masked_language.py (+2, -3)
  13. modelscope/models/nlp/nncrf_for_named_entity_recognition.py (+3, -6)
  14. modelscope/models/nlp/sequence_classification.py (+81, -2)
  15. modelscope/models/nlp/task_models/__init__.py (+4, -0)
  16. modelscope/models/nlp/task_models/feature_extraction.py (+43, -0)
  17. modelscope/models/nlp/task_models/fill_mask.py (+47, -0)
  18. modelscope/models/nlp/task_models/information_extraction.py (+3, -12)
  19. modelscope/models/nlp/task_models/sequence_classification.py (+20, -29)
  20. modelscope/models/nlp/task_models/task_model.py (+23, -6)
  21. modelscope/models/nlp/task_models/token_classification.py (+4, -11)
  22. modelscope/models/nlp/token_classification.py (+48, -1)
  23. modelscope/outputs.py (+16, -0)
  24. modelscope/pipelines/builder.py (+5, -2)
  25. modelscope/pipelines/nlp/__init__.py (+8, -11)
  26. modelscope/pipelines/nlp/feature_extraction_pipeline.py (+82, -0)
  27. modelscope/pipelines/nlp/fill_mask_pipeline.py (+6, -3)
  28. modelscope/pipelines/nlp/information_extraction_pipeline.py (+1, -1)
  29. modelscope/pipelines/nlp/named_entity_recognition_pipeline.py (+3, -2)
  30. modelscope/pipelines/nlp/pair_sentence_classification_pipeline.py (+0, -59)
  31. modelscope/pipelines/nlp/sequence_classification_pipeline.py (+43, -29)
  32. modelscope/pipelines/nlp/sequence_classification_pipeline_base.py (+0, -62)
  33. modelscope/pipelines/nlp/single_sentence_classification_pipeline.py (+0, -56)
  34. modelscope/pipelines/nlp/token_classification_pipeline.py (+1, -1)
  35. modelscope/preprocessors/__init__.py (+29, -19)
  36. modelscope/preprocessors/nlp/__init__.py (+27, -18)
  37. modelscope/preprocessors/nlp/nlp_base.py (+215, -360)
  38. modelscope/utils/constant.py (+1, -0)
  39. modelscope/utils/registry.py (+1, -1)
  40. tests/msdatasets/test_ms_dataset.py (+2, -1)
  41. tests/pipelines/test_deberta_tasks.py (+3, -5)
  42. tests/pipelines/test_feature_extraction.py (+67, -0)
  43. tests/pipelines/test_fill_mask.py (+44, -5)
  44. tests/pipelines/test_named_entity_recognition.py (+5, -5)
  45. tests/pipelines/test_nli.py (+5, -5)
  46. tests/pipelines/test_sentence_similarity.py (+5, -5)
  47. tests/pipelines/test_sentiment_classification.py (+16, -15)
  48. tests/pipelines/test_text_classification.py (+3, -1)
  49. tests/preprocessors/test_nlp.py (+76, -0)
  50. tests/utils/test_ast.py (+5, -7)

modelscope/metainfo.py (+9, -2)

@@ -91,17 +91,22 @@ class TaskModels(object):
text_classification = 'text-classification'
token_classification = 'token-classification'
information_extraction = 'information-extraction'
fill_mask = 'fill-mask'
feature_extraction = 'feature-extraction'


class Heads(object):
# nlp heads

# text cls
text_classification = 'text-classification'
# mlm
# fill mask
fill_mask = 'fill-mask'
bert_mlm = 'bert-mlm'
# roberta mlm
roberta_mlm = 'roberta-mlm'
# token cls
token_classification = 'token-classification'
# extraction
information_extraction = 'information-extraction'


@@ -203,6 +208,7 @@ class Pipelines(object):
passage_ranking = 'passage-ranking'
relation_extraction = 'relation-extraction'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -306,6 +312,7 @@ class Preprocessors(object):
table_question_answering_preprocessor = 'table-question-answering-preprocessor'
re_tokenizer = 're-tokenizer'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


modelscope/models/builder.py (+6, -3)

@@ -37,13 +37,16 @@ def build_backbone(cfg: ConfigDict,
cfg, BACKBONES, group_key=field, default_args=default_args)


def build_head(cfg: ConfigDict, default_args: dict = None):
def build_head(cfg: ConfigDict,
group_key: str = None,
default_args: dict = None):
""" build head given config dict

Args:
cfg (:obj:`ConfigDict`): config dict for head object.
group_key (str, optional): The task group to look the head up under; defaults to cfg's type name.
default_args (dict, optional): Default initialization arguments.
"""

if group_key is None:
group_key = cfg[TYPE_NAME]
return build_from_cfg(
cfg, HEADS, group_key=cfg[TYPE_NAME], default_args=default_args)
cfg, HEADS, group_key=group_key, default_args=default_args)
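For reference, the intended call pattern with the new parameter would look roughly like this (a sketch; head_cfg stands in for the head section of a model's configuration.json, with hypothetical values):

from modelscope.models.builder import build_head

head_cfg = {'type': 'text-classification', 'num_labels': 2}  # hypothetical head config

# Explicit group: look the head up under the model's task group.
head = build_head(head_cfg, group_key='text-classification')

# Backward compatible: the group falls back to head_cfg['type'] as before.
head = build_head(head_cfg)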

modelscope/models/nlp/__init__.py (+14, -8)

@@ -6,7 +6,6 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .backbones import SbertModel
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .bert_for_sequence_classification import BertForSequenceClassification
from .bert_for_document_segmentation import BertForDocumentSegmentation
from .csanmt_for_translation import CsanmtForTranslation
from .heads import SequenceClassificationHead
@@ -20,12 +19,15 @@ if TYPE_CHECKING:
from .palm_v2 import PalmForTextGeneration
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering
from .star_text_to_sql import StarForTextToSql
from .sequence_classification import VecoForSequenceClassification, SbertForSequenceClassification
from .sequence_classification import (VecoForSequenceClassification,
SbertForSequenceClassification,
BertForSequenceClassification)
from .space import SpaceForDialogIntent
from .space import SpaceForDialogModeling
from .space import SpaceForDialogStateTracking
from .table_question_answering import TableQuestionAnswering
from .task_models import (InformationExtractionModel,
from .task_models import (FeatureExtractionModel,
InformationExtractionModel,
SequenceClassificationModel,
SingleBackboneTaskModelBase,
TokenClassificationModel)
@@ -37,7 +39,6 @@ else:
_import_structure = {
'backbones': ['SbertModel'],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
'bert_for_sequence_classification': ['BertForSequenceClassification'],
'bert_for_document_segmentation': ['BertForDocumentSegmentation'],
'csanmt_for_translation': ['CsanmtForTranslation'],
'heads': ['SequenceClassificationHead'],
@@ -54,15 +55,20 @@ else:
'palm_v2': ['PalmForTextGeneration'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
'star_text_to_sql': ['StarForTextToSql'],
'sequence_classification':
['VecoForSequenceClassification', 'SbertForSequenceClassification'],
'sequence_classification': [
'VecoForSequenceClassification', 'SbertForSequenceClassification',
'BertForSequenceClassification'
],
'space': [
'SpaceForDialogIntent', 'SpaceForDialogModeling',
'SpaceForDialogStateTracking'
],
'task_models': [
'InformationExtractionModel', 'SequenceClassificationModel',
'SingleBackboneTaskModelBase', 'TokenClassificationModel'
'FeatureExtractionModel',
'InformationExtractionModel',
'SequenceClassificationModel',
'SingleBackboneTaskModelBase',
'TokenClassificationModel',
],
'token_classification': ['SbertForTokenClassification'],
'table_question_answering': ['TableQuestionAnswering'],


modelscope/models/nlp/backbones/bert.py (+7, -0)

@@ -0,0 +1,7 @@
from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES
from modelscope.models.nlp.bert import BertModel
from modelscope.utils.constant import Fields

BACKBONES.register_module(
group_key=Fields.nlp, module_name=Models.bert, module_cls=BertModel)

modelscope/models/nlp/bert/__init__.py (+60, -0)

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMHeadModel,
BertModel,
BertPreTrainedModel,
load_tf_weights_in_bert,
)

from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
from .tokenization_bert_fast import BertTokenizerFast

else:
_import_structure = {
'configuration_bert':
['BERT_PRETRAINED_CONFIG_ARCHIVE_MAP', 'BertConfig', 'BertOnnxConfig'],
'tokenization_bert':
['BasicTokenizer', 'BertTokenizer', 'WordpieceTokenizer'],
}
_import_structure['tokenization_bert_fast'] = ['BertTokenizerFast']

_import_structure['modeling_bert'] = [
'BERT_PRETRAINED_MODEL_ARCHIVE_LIST',
'BertForMaskedLM',
'BertForMultipleChoice',
'BertForNextSentencePrediction',
'BertForPreTraining',
'BertForQuestionAnswering',
'BertForSequenceClassification',
'BertForTokenClassification',
'BertLayer',
'BertLMHeadModel',
'BertModel',
'BertPreTrainedModel',
'load_tf_weights_in_bert',
]

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
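With the lazy module installed, the heavy submodules load only on first attribute access; a quick illustrative check:

import modelscope.models.nlp.bert as bert  # cheap: no submodule is loaded yet

config = bert.BertConfig()      # first access lazily imports configuration_bert
model = bert.BertModel(config)  # and this one lazily imports modeling_bert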

modelscope/models/nlp/bert/configuration_bert.py (+162, -0)

@@ -0,0 +1,162 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """
from collections import OrderedDict
from typing import Mapping

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfig

from modelscope.utils.logger import get_logger

logger = get_logger(__name__)


class BertConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`BertModel`] or a [`TFBertModel`]. It is used to instantiate a BERT model
according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar
configuration to that of the BERT
[bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to
control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.


Args:
vocab_size (`int`, *optional*, defaults to 30522):
Vocabulary size of the BERT model. Defines the number of different
tokens that can be represented by the `inputs_ids` passed when
calling [`BertModel`] or [`TFBertModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the
Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward)
layer in the Transformer encoder.
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the
encoder and pooler. If string, `"gelu"`, `"relu"`, `"silu"` and
`"gelu_new"` are supported.
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the
embeddings, encoder, and pooler.
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or
1024 or 2048).
type_vocab_size (`int`, *optional*, defaults to 2):
The vocabulary size of the `token_type_ids` passed when calling
[`BertModel`] or [`TFBertModel`].
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`,
`"relative_key"`, `"relative_key_query"`. For positional embeddings
use `"absolute"`. For more information on `"relative_key"`, please
refer to [Self-Attention with Relative Position Representations
(Shaw et al.)](https://arxiv.org/abs/1803.02155). For more
information on `"relative_key_query"`, please refer to *Method 4* in
[Improve Transformer Models with Better Relative Position Embeddings
(Huang et al.)](https://arxiv.org/abs/2009.13658).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.

Examples:

```python
>>> from transformers import BertModel, BertConfig

>>> # Initializing a BERT bert-base-uncased style configuration
>>> configuration = BertConfig()

>>> # Initializing a model from the bert-base-uncased style configuration
>>> model = BertModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = 'bert'

def __init__(self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type='absolute',
use_cache=True,
classifier_dropout=None,
**kwargs):
super().__init__(pad_token_id=pad_token_id, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout


class BertOnnxConfig(OnnxConfig):

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([
('input_ids', {
0: 'batch',
1: 'sequence'
}),
('attention_mask', {
0: 'batch',
1: 'sequence'
}),
('token_type_ids', {
0: 'batch',
1: 'sequence'
}),
])

modelscope/models/nlp/bert/modeling_bert.py (+2040, -0)
(Diff suppressed because it is too large.)


modelscope/models/nlp/bert_for_sequence_classification.py (+0, -70)

@@ -1,70 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
import numpy as np

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['BertForSequenceClassification']


@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
class BertForSequenceClassification(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
# Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
# Predictor.__init__(self, *args, **kwargs)
"""initialize the sequence classification model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""

super().__init__(model_dir, *args, **kwargs)
import torch
from easynlp.appzoo import SequenceClassification
from easynlp.core.predictor import get_model_predictor
self.model = get_model_predictor(
model_dir=self.model_dir,
model_cls=SequenceClassification,
input_keys=[('input_ids', torch.LongTensor),
('attention_mask', torch.LongTensor),
('token_type_ids', torch.LongTensor)],
output_keys=['predictions', 'probabilities', 'logits'])

self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {idx: name for name, idx in self.label_mapping.items()}

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'predictions': array([1]), # label 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
return self.model.predict(input)

def postprocess(self, inputs: Dict[str, np.ndarray],
**kwargs) -> Dict[str, np.ndarray]:
# N x num_classes
probs = inputs['probabilities']
result = {
'probs': probs,
}

return result

modelscope/models/nlp/deberta_v2/__init__.py (+0, -10)

@@ -21,21 +21,12 @@ from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

_import_structure = {
'configuration_deberta_v2': [
'DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP', 'DebertaV2Config',
'DebertaV2OnnxConfig'
],
'tokenization_deberta_v2': ['DebertaV2Tokenizer'],
}

if TYPE_CHECKING:
from .configuration_deberta_v2 import DebertaV2Config
from .tokenization_deberta_v2 import DebertaV2Tokenizer
from .tokenization_deberta_v2_fast import DebertaV2TokenizerFast

from .modeling_deberta_v2 import (
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaV2ForMaskedLM,
DebertaV2ForMultipleChoice,
DebertaV2ForQuestionAnswering,
@@ -55,7 +46,6 @@ else:
'DebertaV2TokenizerFast'
]
_import_structure['modeling_deberta_v2'] = [
'DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST',
'DebertaV2ForMaskedLM',
'DebertaV2ForMultipleChoice',
'DebertaV2ForQuestionAnswering',


modelscope/models/nlp/heads/fill_mask_head.py (+101, -0)

@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN

from modelscope.metainfo import Heads
from modelscope.models.base import TorchHead
from modelscope.models.builder import HEADS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks


@HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
class BertFillMaskHead(TorchHead):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = BertOnlyMLMHead(self.config)

def forward(self, sequence_output):
prediction_scores = self.cls(sequence_output)
return {OutputKeys.LOGITS: prediction_scores}

def compute_loss(self, outputs: Dict[str, torch.Tensor],
labels) -> Dict[str, torch.Tensor]:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(
outputs.view(-1, self.config.vocab_size), labels.view(-1))
return {OutputKeys.LOSS: masked_lm_loss}


class BertPredictionHeadTransform(nn.Module):

def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


class BertLMPredictionHead(nn.Module):

def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states


class BertOnlyMLMHead(nn.Module):

def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)

def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
prediction_scores = self.predictions(sequence_output)
return prediction_scores
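A rough usage sketch for the new head, assuming TorchHead keeps the passed-in config as self.config (the config kwarg and the values here are illustrative, not verified against the TorchHead API):

import torch

from modelscope.models.builder import build_head
from modelscope.models.nlp.bert import BertConfig

bert_config = BertConfig()
head = build_head({'type': 'bert-mlm', 'config': bert_config}, group_key='fill-mask')

sequence_output = torch.randn(1, 16, bert_config.hidden_size)
outputs = head(sequence_output)                      # {'logits': (1, 16, vocab_size)}
labels = torch.randint(0, bert_config.vocab_size, (1, 16))
loss = head.compute_loss(outputs['logits'], labels)  # {'loss': tensor(...)}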

modelscope/models/nlp/heads/torch_pretrain_head.py (+1, -1)

@@ -11,7 +11,7 @@ from modelscope.models.builder import HEADS
from modelscope.utils.constant import Tasks


@HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
# @HEADS.register_module(Tasks.fill_mask, module_name=Heads.bert_mlm)
class BertMLMHead(BertOnlyMLMHead, TorchHead):

def compute_loss(self, outputs: Dict[str, torch.Tensor],


modelscope/models/nlp/masked_language.py (+2, -3)

@@ -1,10 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from transformers import BertForMaskedLM as BertForMaskedLMTransformer

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import \
BertForMaskedLM as BertForMaskedLMTransformer
from modelscope.models.nlp.deberta_v2 import \
DebertaV2ForMaskedLM as DebertaV2ForMaskedLMTransformer
from modelscope.models.nlp.structbert import SbertForMaskedLM


modelscope/models/nlp/nncrf_for_named_entity_recognition.py (+3, -6)

@@ -41,12 +41,9 @@ class SequenceLabelingForNamedEntityRecognition(TorchModel):

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
input_tensor = {
'input_ids':
torch.tensor(input['input_ids']).unsqueeze(0),
'attention_mask':
torch.tensor(input['attention_mask']).unsqueeze(0),
'label_mask':
torch.tensor(input['label_mask'], dtype=torch.bool).unsqueeze(0)
'input_ids': input['input_ids'],
'attention_mask': input['attention_mask'],
'label_mask': input['label_mask'],
}
output = {
'text': input['text'],


modelscope/models/nlp/sequence_classification.py (+81, -2)

@@ -7,6 +7,7 @@ from torch import nn
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertPreTrainedModel
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.models.nlp.veco import \
VecoForSequenceClassification as VecoForSequenceClassificationTransform
@@ -16,7 +17,10 @@ from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)

__all__ = ['SbertForSequenceClassification', 'VecoForSequenceClassification']
__all__ = [
'SbertForSequenceClassification', 'VecoForSequenceClassification',
'BertForSequenceClassification'
]


class SequenceClassificationBase(TorchModel):
@@ -132,7 +136,7 @@ class SbertForSequenceClassification(SequenceClassificationBase,
label2id = parse_label_mapping(model_dir)
if label2id is not None and len(label2id) > 0:
num_labels = len(label2id)
cls.id2label = {id: label for label, id in label2id.items()}
model_args = {} if num_labels is None else {'num_labels': num_labels}
return super(SbertPreTrainedModel,
SbertForSequenceClassification).from_pretrained(
@@ -206,3 +210,78 @@ class VecoForSequenceClassification(TorchModel,
pretrained_model_name_or_path=kwargs.get('model_dir'),
model_dir=kwargs.get('model_dir'),
**model_args)


@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
@MODELS.register_module(
Tasks.sentiment_classification, module_name=Models.bert)
@MODELS.register_module(Tasks.nli, module_name=Models.bert)
@MODELS.register_module(Tasks.text_classification, module_name=Models.bert)
class BertForSequenceClassification(SequenceClassificationBase,
BertPreTrainedModel):
"""Bert sequence classification model.

Inherited from SequenceClassificationBase.
"""
base_model_prefix: str = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']

def __init__(self, config, model_dir):
if hasattr(config, 'base_model_prefix'):
BertForSequenceClassification.base_model_prefix = config.base_model_prefix
super().__init__(config, model_dir)

def build_base_model(self):
from .bert import BertModel
return BertModel(self.config, add_pooling_layer=True)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs):
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.

@param kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
@return: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""

model_dir = kwargs.get('model_dir')
num_labels = kwargs.get('num_labels')
if num_labels is None:
label2id = parse_label_mapping(model_dir)
if label2id is not None and len(label2id) > 0:
num_labels = len(label2id)

model_args = {} if num_labels is None else {'num_labels': num_labels}
return super(BertPreTrainedModel,
BertForSequenceClassification).from_pretrained(
pretrained_model_name_or_path=kwargs.get('model_dir'),
model_dir=kwargs.get('model_dir'),
**model_args)

modelscope/models/nlp/task_models/__init__.py (+4, -0)

@@ -5,6 +5,8 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .information_extraction import InformationExtractionModel
from .feature_extraction import FeatureExtractionModel
from .fill_mask import FillMaskModel
from .sequence_classification import SequenceClassificationModel
from .task_model import SingleBackboneTaskModelBase
from .token_classification import TokenClassificationModel
@@ -12,6 +14,8 @@ if TYPE_CHECKING:
else:
_import_structure = {
'information_extraction': ['InformationExtractionModel'],
'feature_extraction': ['FeatureExtractionModel'],
'fill_mask': ['FillMaskModel'],
'sequence_classification': ['SequenceClassificationModel'],
'task_model': ['SingleBackboneTaskModelBase'],
'token_classification': ['TokenClassificationModel'],


modelscope/models/nlp/task_models/feature_extraction.py (+43, -0)

@@ -0,0 +1,43 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertConfig
from modelscope.models.nlp.task_models.task_model import \
SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['FeatureExtractionModel']


@MODELS.register_module(
Tasks.feature_extraction, module_name=TaskModels.feature_extraction)
class FeatureExtractionModel(SingleBackboneTaskModelBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the fill mask model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

self.build_backbone(self.backbone_cfg)

def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:

# the backbone does not need labels; only the head needs them to compute the loss
labels = input.pop(OutputKeys.LABELS, None)

outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
if labels is not None:
input[OutputKeys.LABELS] = labels

return {OutputKeys.TEXT_EMBEDDING: sequence_output}

modelscope/models/nlp/task_models/fill_mask.py (+47, -0)

@@ -0,0 +1,47 @@
from typing import Any, Dict

import numpy as np

from modelscope.metainfo import TaskModels
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertConfig
from modelscope.models.nlp.task_models.task_model import \
SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['FillMaskModel']


@MODELS.register_module(Tasks.fill_mask, module_name=TaskModels.fill_mask)
class FillMaskModel(SingleBackboneTaskModelBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the fill mask model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

self.build_backbone(self.backbone_cfg)
self.build_head(self.head_cfg)

def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:

# the backbone does not need labels; only the head needs them to compute the loss
labels = input.pop(OutputKeys.LABELS, None)

outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(sequence_output)

if labels is not None:
input[OutputKeys.LABELS] = labels
loss = self.compute_loss(outputs, labels)
outputs.update(loss)
outputs[OutputKeys.INPUT_IDS] = input[OutputKeys.INPUT_IDS]
return outputs

modelscope/models/nlp/task_models/information_extraction.py (+3, -12)

@@ -26,21 +26,12 @@ class InformationExtractionModel(SingleBackboneTaskModelBase):
"""
super().__init__(model_dir, *args, **kwargs)

backbone_cfg = self.cfg.backbone
head_cfg = self.cfg.head
self.build_backbone(backbone_cfg)
self.build_head(head_cfg)
self.build_backbone(self.backbone_cfg)
self.build_head(self.head_cfg)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(sequence_output, input['text'],
input['offsets'])
return {OutputKeys.SPO_LIST: outputs}

def extract_backbone_outputs(self, outputs):
sequence_output = None
pooled_output = None
if hasattr(self.backbone, 'extract_sequence_outputs'):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
return sequence_output, pooled_output

modelscope/models/nlp/task_models/sequence_classification.py (+20, -29)

@@ -11,10 +11,14 @@ from modelscope.models.nlp.task_models.task_model import \
SingleBackboneTaskModelBase
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping

__all__ = ['SequenceClassificationModel']


@MODELS.register_module(
Tasks.sentence_similarity, module_name=TaskModels.text_classification)
@MODELS.register_module(Tasks.nli, module_name=TaskModels.text_classification)
@MODELS.register_module(
Tasks.sentiment_classification, module_name=TaskModels.text_classification)
@MODELS.register_module(
@@ -31,49 +35,36 @@ class SequenceClassificationModel(SingleBackboneTaskModelBase):
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

backbone_cfg = self.cfg.backbone
head_cfg = self.cfg.head

# get the num_labels from label_mapping.json
self.id2label = {}
self.label_path = os.path.join(model_dir, 'label_mapping.json')
if os.path.exists(self.label_path):
with open(self.label_path) as f:
self.label_mapping = json.load(f)
self.id2label = {
idx: name
for name, idx in self.label_mapping.items()
}
head_cfg['num_labels'] = len(self.label_mapping)
# get the num_labels
num_labels = kwargs.get('num_labels')
if num_labels is None:
label2id = parse_label_mapping(model_dir)
if label2id is not None and len(label2id) > 0:
num_labels = len(label2id)
self.id2label = {id: label for label, id in label2id.items()}
self.head_cfg['num_labels'] = num_labels

self.build_backbone(backbone_cfg)
self.build_head(head_cfg)
self.build_backbone(self.backbone_cfg)
self.build_head(self.head_cfg)

def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
# the backbone does not need labels; only the head needs them to compute the loss
labels = input.pop(OutputKeys.LABELS, None)

outputs = super().forward(input)
sequence_output, pooled_output = self.extract_backbone_outputs(outputs)
outputs = self.head.forward(pooled_output)
if 'labels' in input:
loss = self.compute_loss(outputs, input['labels'])
if labels is not None:
input[OutputKeys.LABELS] = labels
loss = self.compute_loss(outputs, labels)
outputs.update(loss)
return outputs

def extract_logits(self, outputs):
return outputs[OutputKeys.LOGITS].cpu().detach()

def extract_backbone_outputs(self, outputs):
sequence_output = None
pooled_output = None
if hasattr(self.backbone, 'extract_sequence_outputs'):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
if hasattr(self.backbone, 'extract_pooled_outputs'):
pooled_output = self.backbone.extract_pooled_outputs(outputs)
return sequence_output, pooled_output

def compute_loss(self, outputs, labels):
loss = self.head.compute_loss(outputs, labels)
return loss

def postprocess(self, input, **kwargs):
logits = self.extract_logits(input)
probs = logits.softmax(-1).numpy()


modelscope/models/nlp/task_models/task_model.py (+23, -6)

@@ -74,7 +74,7 @@ class BaseTaskModel(TorchModel, ABC):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.cfg = ConfigDict(kwargs)
self.config = ConfigDict(kwargs)

def __repr__(self):
# only log backbone and head name
@@ -397,6 +397,9 @@ class SingleBackboneTaskModelBase(BaseTaskModel):

def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.backbone_cfg = self.config.get('backbone', None)
assert self.backbone_cfg is not None
self.head_cfg = self.config.get('head', None)

def build_backbone(self, cfg):
if 'prefix' in cfg:
@@ -405,9 +408,13 @@ class SingleBackboneTaskModelBase(BaseTaskModel):
setattr(self, cfg['prefix'], backbone)

def build_head(self, cfg):
if cfg is None:
raise ValueError(
'Head config is missing, check if this was a backbone-only model'
)
if 'prefix' in cfg:
self._head_prefix = cfg['prefix']
head = build_head(cfg)
head = build_head(cfg, group_key=self.group_key)
setattr(self, self._head_prefix, head)
return head

@@ -431,8 +438,18 @@ class SingleBackboneTaskModelBase(BaseTaskModel):
outputs = self.backbone.forward(**input)
return outputs

def compute_loss(self, outputs: Dict[str, Any], labels):
raise NotImplementedError()
def compute_loss(self, outputs, labels):
loss = self.head.compute_loss(outputs, labels)
return loss

def extract_backbone_outputs(self, outputs):
sequence_output = None
pooled_output = None
if hasattr(self.backbone, 'extract_sequence_outputs'):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
if hasattr(self.backbone, 'extract_pooled_outputs'):
pooled_output = self.backbone.extract_pooled_outputs(outputs)
return sequence_output, pooled_output


class EncoderDecoderTaskModelBase(BaseTaskModel):
@@ -453,7 +470,7 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):

def build_encoder(self):
encoder = build_backbone(
self.cfg,
self.config,
type_name=self._encoder_key_in_cfg,
task_name=Tasks.backbone)
setattr(self, self._encoder_prefix, encoder)
@@ -461,7 +478,7 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):

def build_decoder(self):
decoder = build_backbone(
self.cfg,
self.config,
type_name=self._decoder_key_in_cfg,
task_name=Tasks.backbone)
setattr(self, self._decoder_prefix, decoder)


modelscope/models/nlp/task_models/token_classification.py (+4, -11)

@@ -31,9 +31,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
if 'base_model_prefix' in kwargs:
self._base_model_prefix = kwargs['base_model_prefix']

backbone_cfg = self.cfg.backbone
head_cfg = self.cfg.head

# get the num_labels
num_labels = kwargs.get('num_labels')
if num_labels is None:
@@ -41,12 +38,12 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
if label2id is not None and len(label2id) > 0:
num_labels = len(label2id)
self.id2label = {id: label for label, id in label2id.items()}
head_cfg['num_labels'] = num_labels
self.head_cfg['num_labels'] = num_labels

self.build_backbone(backbone_cfg)
self.build_head(head_cfg)
self.build_backbone(self.backbone_cfg)
self.build_head(self.head_cfg)

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
def forward(self, **input: Dict[str, Any]) -> Dict[str, np.ndarray]:
labels = None
if OutputKeys.LABEL in input:
labels = input.pop(OutputKeys.LABEL)
@@ -71,10 +68,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
sequence_output = self.backbone.extract_sequence_outputs(outputs)
return sequence_output, pooled_output

def compute_loss(self, outputs, labels):
loss = self.head.compute_loss(outputs, labels)
return loss

def postprocess(self, input, **kwargs):
logits = self.extract_logits(input)
pred = torch.argmax(logits[0], dim=-1)


modelscope/models/nlp/token_classification.py (+48, -1)

@@ -10,12 +10,13 @@ from torch import nn
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.bert import BertPreTrainedModel
from modelscope.models.nlp.structbert import SbertPreTrainedModel
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.tensor_utils import (torch_nested_detach,
torch_nested_numpify)
from .structbert import SbertPreTrainedModel

__all__ = ['SbertForTokenClassification']

@@ -171,3 +172,49 @@ class SbertForTokenClassification(TokenClassification, SbertPreTrainedModel):
pretrained_model_name_or_path=kwargs.get('model_dir'),
model_dir=kwargs.get('model_dir'),
**model_args)


@MODELS.register_module(Tasks.word_segmentation, module_name=Models.bert)
@MODELS.register_module(Tasks.token_classification, module_name=Models.bert)
class BertForTokenClassification(TokenClassification, BertPreTrainedModel):
"""Bert token classification model.

Inherited from TokenClassification.
"""
base_model_prefix: str = 'bert'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r'position_ids']

def __init__(self, config, model_dir):
if hasattr(config, 'base_model_prefix'):
BertForTokenClassification.base_model_prefix = config.base_model_prefix
super().__init__(config, model_dir)

def build_base_model(self):
from .bert import BertModel
return BertModel(self.config, add_pooling_layer=True)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs):
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs)

modelscope/outputs.py (+16, -0)

@@ -417,6 +417,22 @@ TASK_OUTPUTS = {
# }
Tasks.fill_mask: [OutputKeys.TEXT],

# feature extraction result for single sample
# {
# "text_embedding": [[
# [1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04],
# [6.45841064e-05, 1.15997791e-04, 5.11605394e-05, 9.87020373e-01],
# [2.66957268e-05, 4.72324500e-05, 9.74208378e-05, 4.18022355e-05]
# ],
# [
# [2.97343540e-05, 5.81317654e-05, 5.44203431e-05, 6.28319322e-05],
# [8.24327726e-05, 4.66077945e-05, 5.32869453e-05, 4.16190960e-05],
# [3.61441926e-05, 3.38475402e-05, 3.44323053e-05, 5.70138109e-05]
# ]
# ]
# }
Tasks.feature_extraction: [OutputKeys.TEXT_EMBEDDING],

# (Deprecated) dialog intent prediction result for single sample
# {'output': {'prediction': array([2.62349960e-03, 4.12110658e-03, 4.12748595e-05, 3.77560973e-05,
# 1.08599677e-04, 1.72710388e-05, 2.95618793e-05, 1.93638436e-04,


modelscope/pipelines/builder.py (+5, -2)

@@ -52,8 +52,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vit_object-detection_coco'),
Tasks.image_denoising: (Pipelines.image_denoise,
'damo/cv_nafnet_image-denoise_sidd'),
Tasks.text_classification: (Pipelines.sentiment_analysis,
'damo/bert-base-sst2'),
Tasks.text_classification:
(Pipelines.sentiment_classification,
'damo/nlp_structbert_sentiment-classification_chinese-base'),
Tasks.text_generation: (Pipelines.text_generation,
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.zero_shot_classification:
@@ -80,6 +81,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.feature_extraction: (Pipelines.feature_extraction,
'damo/pert_feature-extraction_base-test'),
Tasks.action_recognition: (Pipelines.action_recognition,
'damo/cv_TAdaConv_action-recognition'),
Tasks.action_detection: (Pipelines.action_detection,
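With the new default, building the pipeline by task alone would resolve to the sentiment classification pipeline and the structbert model above; an illustrative sketch (output keys follow OutputKeys.SCORES / OutputKeys.LABELS):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_classification)
print(pipe('这家餐厅的服务非常好'))  # e.g. {'scores': [...], 'labels': [...]}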


modelscope/pipelines/nlp/__init__.py (+8, -11)

@@ -11,12 +11,13 @@ if TYPE_CHECKING:
from .dialog_state_tracking_pipeline import DialogStateTrackingPipeline
from .document_segmentation_pipeline import DocumentSegmentationPipeline
from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
from .feature_extraction_pipeline import FeatureExtractionPipeline
from .fill_mask_pipeline import FillMaskPipeline
from .fill_mask_ponet_pipeline import FillMaskPonetPipeline
from .information_extraction_pipeline import InformationExtractionPipeline
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
from .single_sentence_classification_pipeline import SingleSentenceClassificationPipeline
from .passage_ranking_pipeline import PassageRankingPipeline
from .sentence_embedding_pipeline import SentenceEmbeddingPipeline
from .sequence_classification_pipeline import SequenceClassificationPipeline
from .summarization_pipeline import SummarizationPipeline
from .text_classification_pipeline import TextClassificationPipeline
@@ -27,8 +28,7 @@ if TYPE_CHECKING:
from .translation_pipeline import TranslationPipeline
from .word_segmentation_pipeline import WordSegmentationPipeline
from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
from .passage_ranking_pipeline import PassageRankingPipeline
from .sentence_embedding_pipeline import SentenceEmbeddingPipeline

else:
_import_structure = {
'conversational_text_to_sql_pipeline':
@@ -41,16 +41,15 @@ else:
'dialog_state_tracking_pipeline': ['DialogStateTrackingPipeline'],
'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
'feature_extraction_pipeline': ['FeatureExtractionPipeline'],
'fill_mask_pipeline': ['FillMaskPipeline'],
'fill_mask_ponet_pipeline': ['FillMaskPoNetPipeline'],
'information_extraction_pipeline': ['InformationExtractionPipeline'],
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'information_extraction_pipeline': ['InformationExtractionPipeline'],
'pair_sentence_classification_pipeline':
['PairSentenceClassificationPipeline'],
'passage_ranking_pipeline': ['PassageRankingPipeline'],
'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline'],
'sequence_classification_pipeline': ['SequenceClassificationPipeline'],
'single_sentence_classification_pipeline':
['SingleSentenceClassificationPipeline'],
'summarization_pipeline': ['SummarizationPipeline'],
'text_classification_pipeline': ['TextClassificationPipeline'],
'text_error_correction_pipeline': ['TextErrorCorrectionPipeline'],
@@ -61,8 +60,6 @@ else:
'word_segmentation_pipeline': ['WordSegmentationPipeline'],
'zero_shot_classification_pipeline':
['ZeroShotClassificationPipeline'],
'passage_ranking_pipeline': ['PassageRankingPipeline'],
'sentence_embedding_pipeline': ['SentenceEmbeddingPipeline']
}

import sys


modelscope/pipelines/nlp/feature_extraction_pipeline.py (+82, -0)

@@ -0,0 +1,82 @@
import os
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import NLPPreprocessor, Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['FeatureExtractionPipeline']


@PIPELINES.register_module(
Tasks.feature_extraction, module_name=Pipelines.feature_extraction)
class FeatureExtractionPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
first_sequence='sentence',
**kwargs):
"""Use `model` and `preprocessor` to create a nlp feature extraction pipeline for prediction

Args:
model (str or Model): Supply either a local model dir which supports the feature extraction task, or a
no-head model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the sentence in.
sequence_length: Max sequence length in the user's custom scenario. 128 will be used as a default value.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipe_ins = pipeline('feature_extraction', model='damo/nlp_structbert_feature-extraction_english-large')
>>> input = 'Everything you love is treasure'
>>> print(pipe_ins(input))


"""
model = model if isinstance(model,
Model) else Model.from_pretrained(model)

if preprocessor is None:
preprocessor = NLPPreprocessor(
model.model_dir,
padding=kwargs.pop('padding', False),
sequence_length=kwargs.pop('sequence_length', 128))
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

self.preprocessor = preprocessor
self.config = Config.from_file(
os.path.join(model.model_dir, ModelFile.CONFIGURATION))
self.tokenizer = preprocessor.tokenizer

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(**inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): the model outputs containing the text embedding

Returns:
Dict[str, str]: the prediction results
"""

return {
OutputKeys.TEXT_EMBEDDING:
inputs[OutputKeys.TEXT_EMBEDDING].tolist()
}

modelscope/pipelines/nlp/fill_mask_pipeline.py (+6, -3)

@@ -10,7 +10,7 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FillMaskPreprocessor, Preprocessor
from modelscope.preprocessors import NLPPreprocessor, Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

@@ -57,7 +57,7 @@ class FillMaskPipeline(Pipeline):
model, Model) else Model.from_pretrained(model)

if preprocessor is None:
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
fill_mask_model.model_dir,
first_sequence=first_sequence,
second_sequence=None,
@@ -118,7 +118,10 @@ class FillMaskPipeline(Pipeline):
logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy()
input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy()
pred_ids = np.argmax(logits, axis=-1)
model_type = self.model.config.model_type
if hasattr(self.model.config, 'backbone'):
model_type = self.model.config.backbone.type
else:
model_type = self.model.config.model_type
process_type = model_type if model_type in self.mask_id else _type_map[
model_type]
rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids,


modelscope/pipelines/nlp/information_extraction_pipeline.py (+1, -1)

@@ -36,7 +36,7 @@ class InformationExtractionPipeline(Pipeline):
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)
return self.model(**inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, str]:


modelscope/pipelines/nlp/named_entity_recognition_pipeline.py (+3, -2)

@@ -9,7 +9,8 @@ from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import NERPreprocessor, Preprocessor
from modelscope.preprocessors import (Preprocessor,
TokenClassificationPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['NamedEntityRecognitionPipeline']
@@ -46,7 +47,7 @@ class NamedEntityRecognitionPipeline(Pipeline):
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = NERPreprocessor(
preprocessor = TokenClassificationPreprocessor(
model.model_dir,
sequence_length=kwargs.pop('sequence_length', 512))
model.eval()


modelscope/pipelines/nlp/pair_sentence_classification_pipeline.py (+0, -59)

@@ -1,59 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Union

from modelscope.models.base import Model
from ...metainfo import Pipelines
from ...preprocessors import (PairSentenceClassificationPreprocessor,
Preprocessor)
from ...utils.constant import Tasks
from ..builder import PIPELINES
from .sequence_classification_pipeline_base import \
SequenceClassificationPipelineBase

__all__ = ['PairSentenceClassificationPipeline']


@PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli)
@PIPELINES.register_module(
Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity)
class PairSentenceClassificationPipeline(SequenceClassificationPipelineBase):

def __init__(self,
model: Union[Model, str],
preprocessor: Preprocessor = None,
first_sequence='first_sequence',
second_sequence='second_sequence',
**kwargs):
"""Use `model` and `preprocessor` to create a nlp pair sequence classification pipeline for prediction.

Args:
model (str or Model): Supply either a local model dir which supported the sequence classification task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key to read the first sentence in.
second_sequence: The key to read the second sentence in.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

NOTE: Inputs of type 'tuple' or 'list' are also supported. In this scenario, the 'first_sequence' and
'second_sequence' param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='nli', model='damo/nlp_structbert_nli_chinese-base')
>>> sentence1 = '四川商务职业学院和四川财经职业学院哪个好?'
>>> sentence2 = '四川商务职业学院商务管理在哪个校区?'
>>> print(pipeline_ins((sentence1, sentence2)))
>>> # Or use the dict input:
>>> print(pipeline_ins({'first_sequence': sentence1, 'second_sequence': sentence2}))

To view other examples please check tests/pipelines/test_nli.py.
"""
if preprocessor is None:
preprocessor = PairSentenceClassificationPreprocessor(
model.model_dir if isinstance(model, Model) else model,
first_sequence=first_sequence,
second_sequence=second_sequence,
sequence_length=kwargs.pop('sequence_length', 512))
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

modelscope/pipelines/nlp/sequence_classification_pipeline.py (+43, -29)

@@ -1,48 +1,64 @@
from typing import Any, Dict, Union

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp import BertForSequenceClassification
from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.preprocessors import (Preprocessor,
SequenceClassificationPreprocessor)
from modelscope.utils.constant import Tasks

__all__ = ['SequenceClassificationPipeline']


@PIPELINES.register_module(
Tasks.text_classification, module_name=Pipelines.sentiment_analysis)
@PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli)
@PIPELINES.register_module(
Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity)
@PIPELINES.register_module(
Tasks.text_classification, module_name=Pipelines.sentiment_classification)
class SequenceClassificationPipeline(Pipeline):

def __init__(self,
model: Union[BertForSequenceClassification, str],
preprocessor: SequenceClassificationPreprocessor = None,
model: Union[Model, str],
preprocessor: Preprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
"""This is the base class for all the sequence classification sub-tasks.

Args:
model (BertForSequenceClassification): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
model (str or Model): A model instance, a local model dir, or a model id in the model hub.
preprocessor (Preprocessor): A preprocessor instance; must not be None.
"""
assert isinstance(model, str) or isinstance(model, BertForSequenceClassification), \
'model must be a single str or BertForSequenceClassification'
sc_model = model if isinstance(
model,
BertForSequenceClassification) else Model.from_pretrained(model)
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
first_sequence = kwargs.pop('first_sequence', 'first_sequence')
second_sequence = kwargs.pop('second_sequence', None)

if preprocessor is None:
preprocessor = SequenceClassificationPreprocessor(
sc_model.model_dir,
first_sequence='sentence',
second_sequence=None,
model.model_dir if isinstance(model, Model) else model,
first_sequence=first_sequence,
second_sequence=second_sequence,
sequence_length=kwargs.pop('sequence_length', 512))
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

assert hasattr(self.model, 'id2label'), \
'id2label map should be initialized in the init function.'
assert preprocessor is not None
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.id2label = kwargs.get('id2label')
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
self.id2label = self.preprocessor.id2label
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
'as a parameter or make sure the preprocessor has the attribute.'

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(**inputs, **forward_params)

def postprocess(self,
inputs: Dict[str, Any],
@@ -50,20 +66,18 @@ class SequenceClassificationPipeline(Pipeline):
"""process the prediction results

Args:
inputs (Dict[str, Any]): input data dict
topk (int): return topk classification result.

inputs (Dict[str, Any]): The forward outputs, containing the probabilities.
topk (int): The number of top probabilities to take.
Returns:
Dict[str, str]: the prediction results
"""
# NxC np.ndarray
probs = inputs['probs'][0]
probs = inputs[OutputKeys.PROBABILITIES][0]
num_classes = probs.shape[0]
topk = min(topk, num_classes)
top_indices = np.argpartition(probs, -topk)[-topk:]
cls_ids = top_indices[np.argsort(probs[top_indices])]
probs = probs[cls_ids].tolist()

cls_names = [self.model.id2label[cid] for cid in cls_ids]

cls_names = [self.id2label[cid] for cid in cls_ids]
return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names}

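For reference, the top-k selection in the `postprocess` above can be reproduced in isolation. A minimal sketch with plain numpy; the probability row and the id2label mapping below are made up for illustration:

import numpy as np

probs = np.array([0.1, 0.7, 0.2])  # one probability row, as from a 3-class model
id2label = {0: 'negative', 1: 'neutral', 2: 'positive'}  # hypothetical mapping
topk = min(2, probs.shape[0])
# argpartition finds the topk ids in O(C); argsort then orders them by probability
top_indices = np.argpartition(probs, -topk)[-topk:]
cls_ids = top_indices[np.argsort(probs[top_indices])]
print(probs[cls_ids].tolist())             # [0.2, 0.7]
print([id2label[cid] for cid in cls_ids])  # ['positive', 'neutral']

Note the scores come back in ascending order, matching the pipeline code above.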
+ 0
- 62
modelscope/pipelines/nlp/sequence_classification_pipeline_base.py

@@ -1,62 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, Union

import numpy as np
import torch

from modelscope.models.base import Model
from modelscope.outputs import OutputKeys
from ...preprocessors import Preprocessor
from ..base import Pipeline


class SequenceClassificationPipelineBase(Pipeline):

def __init__(self, model: Union[Model, str], preprocessor: Preprocessor,
**kwargs):
"""This is the base class for all the sequence classification sub-tasks.

Args:
model (str or Model): A model instance, a local model dir, or a model id in the model hub.
preprocessor (Preprocessor): A preprocessor instance; must not be None.
"""
assert isinstance(model, str) or isinstance(model, Model), \
'model must be a single str or Model'
model = model if isinstance(model,
Model) else Model.from_pretrained(model)
assert preprocessor is not None
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.id2label = kwargs.get('id2label')
if self.id2label is None and hasattr(self.preprocessor, 'id2label'):
self.id2label = self.preprocessor.id2label
assert self.id2label is not None, 'Cannot convert id to the original label, please pass in the mapping ' \
'as a parameter or make sure the preprocessor has the attribute.'

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(**inputs, **forward_params)

def postprocess(self,
inputs: Dict[str, Any],
topk: int = 5) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): The forward outputs, containing the probabilities.
topk (int): The number of top probabilities to take.
Returns:
Dict[str, str]: the prediction results
"""

probs = inputs[OutputKeys.PROBABILITIES][0]
num_classes = probs.shape[0]
topk = min(topk, num_classes)
top_indices = np.argpartition(probs, -topk)[-topk:]
cls_ids = top_indices[np.argsort(probs[top_indices])]
probs = probs[cls_ids].tolist()

cls_names = [self.id2label[cid] for cid in cls_ids]
return {OutputKeys.SCORES: probs, OutputKeys.LABELS: cls_names}

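The id2label resolution in the deleted base class (kept verbatim in the unified pipeline above) checks the kwargs first, then falls back to the preprocessor attribute. A minimal sketch with made-up values:

class ToyPreprocessor:
    id2label = {0: 'negative', 1: 'positive'}  # made-up mapping

kwargs = {}
preprocessor = ToyPreprocessor()
id2label = kwargs.get('id2label')
if id2label is None and hasattr(preprocessor, 'id2label'):
    id2label = preprocessor.id2label
assert id2label is not None, 'pass id2label or use a preprocessor that has it'
print(id2label[1])  # positive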
+ 0
- 56
modelscope/pipelines/nlp/single_sentence_classification_pipeline.py

@@ -1,56 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Union

from ...metainfo import Pipelines
from ...models import Model
from ...preprocessors import (Preprocessor,
SingleSentenceClassificationPreprocessor)
from ...utils.constant import Tasks
from ..builder import PIPELINES
from .sequence_classification_pipeline_base import \
SequenceClassificationPipelineBase

__all__ = ['SingleSentenceClassificationPipeline']


@PIPELINES.register_module(
Tasks.sentiment_classification,
module_name=Pipelines.sentiment_classification)
class SingleSentenceClassificationPipeline(SequenceClassificationPipelineBase):

def __init__(self,
model: Union[Model, str],
preprocessor: Preprocessor = None,
first_sequence='first_sequence',
**kwargs):
"""Use `model` and `preprocessor` to create a nlp single sequence classification pipeline for prediction.

Args:
model (str or Model): Supply either a local model dir which supports the sequence classification task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance; please make sure the preprocessor fits
the model if supplied.
first_sequence: The key to read the first sentence in.
sequence_length: Max sequence length in the user's custom scenario. 512 will be used as a default value.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='sentiment-classification',
>>> model='damo/nlp_structbert_sentiment-classification_chinese-base')
>>> sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音'
>>> print(pipeline_ins(sentence1))
>>> # Or use the dict input:
>>> print(pipeline_ins({'first_sequence': sentence1}))

To view other examples please check tests/pipelines/test_sentiment-classification.py.
"""
if preprocessor is None:
preprocessor = SingleSentenceClassificationPreprocessor(
model.model_dir if isinstance(model, Model) else model,
first_sequence=first_sequence,
sequence_length=kwargs.pop('sequence_length', 512))
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

+ 1
- 1
modelscope/pipelines/nlp/token_classification_pipeline.py

@@ -49,7 +49,7 @@ class TokenClassificationPipeline(Pipeline):
text = inputs.pop(OutputKeys.TEXT)
with torch.no_grad():
return {
**self.model(inputs, **forward_params), OutputKeys.TEXT: text
**self.model(**inputs, **forward_params), OutputKeys.TEXT: text
}

def postprocess(self, inputs: Dict[str, Any],


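The one-line fix above matters: the preprocessor returns a dict keyed by the model's forward arguments, so it must be unpacked with `**`. A stand-alone sketch; the argument names are stand-ins:

def forward(input_ids=None, attention_mask=None, label_mask=None):
    return {'logits': input_ids}

inputs = {'input_ids': [[101, 102]], 'attention_mask': [[1, 1]]}
# forward(inputs) would bind the whole dict to input_ids;
# forward(**inputs) binds each key to its matching keyword argument
print(forward(**inputs)['logits'])  # [[101, 102]]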
+ 29
- 19
modelscope/preprocessors/__init__.py

@@ -16,17 +16,23 @@ if TYPE_CHECKING:
from .kws import WavToLists
from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
from .nlp import (
Tokenize, SequenceClassificationPreprocessor,
TextGenerationPreprocessor, TokenClassificationPreprocessor,
SingleSentenceClassificationPreprocessor,
PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
ZeroShotClassificationPreprocessor, NERPreprocessor,
TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
SequenceLabelingPreprocessor, RelationExtractionPreprocessor,
DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor,
PassageRankingPreprocessor, SentenceEmbeddingPreprocessor,
DocumentSegmentationPreprocessor,
FaqQuestionAnsweringPreprocessor,
FillMaskPoNetPreprocessor,
NLPPreprocessor,
NLPTokenizerPreprocessorBase,
PassageRankingPreprocessor,
RelationExtractionPreprocessor,
SentenceEmbeddingPreprocessor,
SequenceClassificationPreprocessor,
TokenClassificationPreprocessor,
TextErrorCorrectionPreprocessor,
TextGenerationPreprocessor,
Text2TextGenerationPreprocessor,
WordSegmentationBlankSetToLabelPreprocessor)
Tokenize,
WordSegmentationBlankSetToLabelPreprocessor,
ZeroShotClassificationPreprocessor,
)
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -49,18 +55,22 @@ else:
'kws': ['WavToLists'],
'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'],
'nlp': [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
'SingleSentenceClassificationPreprocessor',
'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
'DocumentSegmentationPreprocessor',
'FaqQuestionAnsweringPreprocessor',
'FillMaskPoNetPreprocessor',
'NLPPreprocessor',
'NLPTokenizerPreprocessorBase',
'PassageRankingPreprocessor',
'RelationExtractionPreprocessor',
'SentenceEmbeddingPreprocessor',
'SequenceClassificationPreprocessor',
'TokenClassificationPreprocessor',
'TextErrorCorrectionPreprocessor',
'TextGenerationPreprocessor',
'Tokenize',
'Text2TextGenerationPreprocessor',
'WordSegmentationBlankSetToLabelPreprocessor',
'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
'ZeroShotClassificationPreprocessor',
],
'space': [
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',


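This diff, and the matching one for modelscope/preprocessors/nlp/__init__.py just below, keep the TYPE_CHECKING import list and the `_import_structure` dict in sync and alphabetically sorted. The runtime half of the pattern resolves a symbol to its owning submodule only on first access; a rough sketch of the idea (this is not ModelScope's actual `LazyImportModule` implementation):

import importlib
import types

class LazyModule(types.ModuleType):

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map each exported symbol to the submodule that defines it
        self._symbol_to_module = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        # import the owning submodule lazily, on first attribute access
        submodule = importlib.import_module(
            '.' + self._symbol_to_module[symbol], self.__name__)
        return getattr(submodule, symbol)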
+ 27
- 18
modelscope/preprocessors/nlp/__init__.py

@@ -6,32 +6,41 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .text_error_correction import TextErrorCorrectionPreprocessor
from .nlp_base import (
Tokenize, SequenceClassificationPreprocessor,
TextGenerationPreprocessor, TokenClassificationPreprocessor,
SingleSentenceClassificationPreprocessor,
Text2TextGenerationPreprocessor,
PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
ZeroShotClassificationPreprocessor, NERPreprocessor,
FaqQuestionAnsweringPreprocessor, SequenceLabelingPreprocessor,
RelationExtractionPreprocessor, DocumentSegmentationPreprocessor,
FillMaskPoNetPreprocessor, PassageRankingPreprocessor,
DocumentSegmentationPreprocessor,
FaqQuestionAnsweringPreprocessor,
FillMaskPoNetPreprocessor,
NLPPreprocessor,
NLPTokenizerPreprocessorBase,
PassageRankingPreprocessor,
RelationExtractionPreprocessor,
SentenceEmbeddingPreprocessor,
WordSegmentationBlankSetToLabelPreprocessor)
SequenceClassificationPreprocessor,
TokenClassificationPreprocessor,
TextGenerationPreprocessor,
Text2TextGenerationPreprocessor,
Tokenize,
WordSegmentationBlankSetToLabelPreprocessor,
ZeroShotClassificationPreprocessor,
)

else:
_import_structure = {
'nlp_base': [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
'SingleSentenceClassificationPreprocessor',
'PairSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'SentenceEmbeddingPreprocessor', 'PassageRankingPreprocessor',
'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
'DocumentSegmentationPreprocessor',
'FaqQuestionAnsweringPreprocessor',
'FillMaskPoNetPreprocessor',
'NLPPreprocessor',
'NLPTokenizerPreprocessorBase',
'PassageRankingPreprocessor',
'RelationExtractionPreprocessor',
'SentenceEmbeddingPreprocessor',
'SequenceClassificationPreprocessor',
'TokenClassificationPreprocessor',
'TextGenerationPreprocessor',
'Tokenize',
'Text2TextGenerationPreprocessor',
'WordSegmentationBlankSetToLabelPreprocessor',
'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
'ZeroShotClassificationPreprocessor',
],
'text_error_correction': [
'TextErrorCorrectionPreprocessor',


+ 215
- 360
modelscope/preprocessors/nlp/nlp_base.py

@@ -2,14 +2,13 @@

import os.path as osp
import re
import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
from transformers import AutoTokenizer, BertTokenizerFast
import torch
from transformers import AutoTokenizer

from modelscope.metainfo import Models, Preprocessors
from modelscope.models.nlp.structbert import SbertTokenizerFast
from modelscope.outputs import OutputKeys
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
@@ -23,24 +22,21 @@ from modelscope.utils.type_assert import type_assert
logger = get_logger()

__all__ = [
'Tokenize',
'DocumentSegmentationPreprocessor',
'FaqQuestionAnsweringPreprocessor',
'NLPPreprocessor',
'FillMaskPoNetPreprocessor',
'NLPTokenizerPreprocessorBase',
'PassageRankingPreprocessor',
'RelationExtractionPreprocessor',
'SentenceEmbeddingPreprocessor',
'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor',
'TokenClassificationPreprocessor',
'PairSentenceClassificationPreprocessor',
'Text2TextGenerationPreprocessor',
'SingleSentenceClassificationPreprocessor',
'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor',
'NERPreprocessor',
'SentenceEmbeddingPreprocessor',
'PassageRankingPreprocessor',
'FaqQuestionAnsweringPreprocessor',
'SequenceLabelingPreprocessor',
'RelationExtractionPreprocessor',
'DocumentSegmentationPreprocessor',
'FillMaskPoNetPreprocessor',
'TextGenerationPreprocessor',
'Tokenize',
'WordSegmentationBlankSetToLabelPreprocessor',
'ZeroShotClassificationPreprocessor',
]


@@ -48,85 +44,19 @@ __all__ = [
class Tokenize(Preprocessor):

def __init__(self, tokenizer_name) -> None:
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(data, str):
data = {InputFields.text: data}
token_dict = self._tokenizer(data[InputFields.text])
token_dict = self.tokenizer(data[InputFields.text])
data.update(token_dict)
return data


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer)
class SequenceClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from easynlp.modelzoo import AutoTokenizer
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
print(f'this is the tokenizer {self.tokenizer}')
self.label2id = parse_label_mapping(self.model_dir)

@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
feature = super().__call__(data)
if isinstance(data, str):
new_data = {self.first_sequence: data}
elif isinstance(data, tuple):
sentence1, sentence2 = data
new_data = {
self.first_sequence: sentence1,
self.second_sequence: sentence2
}
else:
new_data = data

# preprocess the data for the model input

rst = {
'id': [],
'input_ids': [],
'attention_mask': [],
'token_type_ids': [],
}

max_seq_length = self.sequence_length

text_a = new_data[self.first_sequence]
text_b = new_data.get(self.second_sequence, None)

feature = self.tokenizer(
text_a,
text_b,
padding='max_length',
truncation=True,
max_length=max_seq_length)

rst['id'].append(new_data.get('id', str(uuid.uuid4())))
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
return rst


class NLPTokenizerPreprocessorBase(Preprocessor):

def __init__(self, model_dir: str, pair: bool, mode: str, **kwargs):
def __init__(self, model_dir: str, mode: str, **kwargs):
"""The NLP tokenizer preprocessor base class.

Any nlp preprocessor which uses the hf tokenizer can inherit from this class.
@@ -138,7 +68,6 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
label: The label key
label2id: An optional label2id mapping, the class will try to call utils.parse_label_mapping
if this mapping is not supplied.
pair (bool): Pair sentence input or single sentence input.
mode: Run this preprocessor in either 'train'/'eval'/'inference' mode
kwargs: These kwargs will be directly fed into the tokenizer.
"""
@@ -148,7 +77,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.pair = pair
self.sequence_length = kwargs.pop('sequence_length', 128)

self._mode = mode
self.label = kwargs.pop('label', OutputKeys.LABEL)
self.label2id = None
@@ -158,6 +88,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
self.label2id = parse_label_mapping(self.model_dir)

self.tokenize_kwargs = kwargs

self.tokenizer = self.build_tokenizer(model_dir)

@property
@@ -179,20 +110,38 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
@param model_dir: The local model dir.
@return: The initialized tokenizer.
"""

self.is_transformer_based_model = 'lstm' not in model_dir
# fast tokenizers may cause parallel inference to fail
model_type = get_model_type(model_dir)
if model_type in (Models.structbert, Models.gpt3, Models.palm,
Models.plug):
from modelscope.models.nlp.structbert import SbertTokenizer
return SbertTokenizer.from_pretrained(model_dir, use_fast=False)
from modelscope.models.nlp.structbert import SbertTokenizer, SbertTokenizerFast
return SbertTokenizer.from_pretrained(
model_dir
) if self._mode == ModeKeys.INFERENCE else SbertTokenizerFast.from_pretrained(
model_dir)
elif model_type == Models.veco:
from modelscope.models.nlp.veco import VecoTokenizer
return VecoTokenizer.from_pretrained(model_dir)
from modelscope.models.nlp.veco import VecoTokenizer, VecoTokenizerFast
return VecoTokenizer.from_pretrained(
model_dir
) if self._mode == ModeKeys.INFERENCE else VecoTokenizerFast.from_pretrained(
model_dir)
elif model_type == Models.deberta_v2:
from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer
return DebertaV2Tokenizer.from_pretrained(model_dir)
from modelscope.models.nlp.deberta_v2 import DebertaV2Tokenizer, DebertaV2TokenizerFast
return DebertaV2Tokenizer.from_pretrained(
model_dir
) if self._mode == ModeKeys.INFERENCE else DebertaV2TokenizerFast.from_pretrained(
model_dir)
elif not self.is_transformer_based_model:
from transformers import BertTokenizer, BertTokenizerFast
return BertTokenizer.from_pretrained(
model_dir
) if self._mode == ModeKeys.INFERENCE else BertTokenizerFast.from_pretrained(
model_dir)
else:
return AutoTokenizer.from_pretrained(model_dir, use_fast=False)
return AutoTokenizer.from_pretrained(
model_dir,
use_fast=False if self._mode == ModeKeys.INFERENCE else True)

def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]:
"""process the raw input data
@@ -239,7 +188,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
if len(data) == 3:
text_a, text_b, labels = data
elif len(data) == 2:
if self.pair:
if self._mode == ModeKeys.INFERENCE:
text_a, text_b = data
else:
text_a, labels = data
@@ -277,6 +226,22 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
output[OutputKeys.LABELS] = labels


@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask)
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.feature_extraction)
class NLPPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in MLM task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
True)
super().__init__(model_dir, mode=mode, **kwargs)


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.passage_ranking)
class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase):
@@ -337,22 +302,12 @@ class PassageRankingPreprocessor(NLPTokenizerPreprocessorBase):
Fields.nlp, module_name=Preprocessors.nli_tokenizer)
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer)
class PairSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in pair sentence classification.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get(
'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, pair=True, mode=mode, **kwargs)


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer)
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer)
class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in single sentence classification.
class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in sequence classification.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
@@ -360,7 +315,7 @@ class SingleSentenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
kwargs['padding'] = kwargs.get(
'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)
super().__init__(model_dir, mode=mode, **kwargs)


@PREPROCESSORS.register_module(
@@ -421,7 +376,7 @@ class ZeroShotClassificationPreprocessor(NLPTokenizerPreprocessorBase):
model_dir (str): model path
"""
self.sequence_length = kwargs.pop('sequence_length', 512)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)
super().__init__(model_dir, mode=mode, **kwargs)

def __call__(self, data: Union[str, Dict], hypothesis_template: str,
candidate_labels: list) -> Dict[str, Any]:
@@ -496,14 +451,12 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
tokenizer=None,
mode=ModeKeys.INFERENCE,
**kwargs):
self.tokenizer = self.build_tokenizer(
model_dir) if tokenizer is None else tokenizer
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
False)
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)
super().__init__(model_dir, mode=mode, **kwargs)

@staticmethod
def get_roberta_tokenizer_dir(model_dir: str) -> Optional[str]:
@@ -541,20 +494,6 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
}


@PREPROCESSORS.register_module(Fields.nlp, module_name=Preprocessors.fill_mask)
class FillMaskPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in MLM task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
True)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)


@PREPROCESSORS.register_module(
Fields.nlp,
module_name=Preprocessors.word_segment_text_to_label_preprocessor)
@@ -592,21 +531,40 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.ner_tokenizer)
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.token_cls_tokenizer)
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer)
class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in normal token classification task.
"""The tokenizer preprocessor used in normal NER task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get(
'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
self.label_all_tokens = kwargs.pop('label_all_tokens', False)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)
super().__init__(model_dir, mode=mode, **kwargs)

def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]:
if 'is_split_into_words' in kwargs:
self.is_split_into_words = kwargs.pop('is_split_into_words')
else:
self.is_split_into_words = self.tokenizer.init_kwargs.get(
'is_split_into_words', False)
if 'label2id' in kwargs:
kwargs.pop('label2id')
self.tokenize_kwargs = kwargs

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -618,23 +576,84 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
Dict[str, Any]: the preprocessed data
"""

text_a = None
# preprocess the data for the model input
text = None
labels_list = None
if isinstance(data, str):
text_a = data
text = data
elif isinstance(data, dict):
text_a = data.get(self.first_sequence)
text = data.get(self.first_sequence)
labels_list = data.get(self.label)

if isinstance(text_a, str):
text_a = text_a.replace(' ', '').strip()
input_ids = []
label_mask = []
offset_mapping = []
if self.is_split_into_words:
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
label_mask.extend([1] + [0] * (len(subtoken_ids) - 1))
offset_mapping.extend([(offset, offset + 1)])
else:
if self.tokenizer.is_fast:
encodings = self.tokenizer(
text,
add_special_tokens=False,
return_offsets_mapping=True,
**self.tokenize_kwargs)
input_ids = encodings['input_ids']
word_ids = encodings.word_ids()
for i in range(len(word_ids)):
if word_ids[i] is None:
label_mask.append(0)
elif word_ids[i] == word_ids[i - 1]:
label_mask.append(0)
offset_mapping[-1] = (
offset_mapping[-1][0],
encodings['offset_mapping'][i][1])
else:
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])
else:
encodings = self.tokenizer(
text, add_special_tokens=False, **self.tokenize_kwargs)
input_ids = encodings['input_ids']
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
text)

if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
offset_mapping = offset_mapping[:sum(label_mask)]

tokenized_inputs = self.tokenizer(
[t for t in text_a],
return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
is_split_into_words=True,
**self.tokenize_kwargs)
if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]

if self._mode == ModeKeys.INFERENCE:
input_ids = torch.tensor(input_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)
label_mask = torch.tensor(
label_mask, dtype=torch.bool).unsqueeze(0)

# assemble the preprocessed inputs for token classification
output = {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}

# align the labels with tokenized text
if labels_list is not None:
assert self.label2id is not None
# Map that sends B-Xxx label to its I-Xxx counterpart
@@ -653,7 +672,6 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
b_to_i_label.append(idx)

label_row = [self.label2id[lb] for lb in labels_list]
word_ids = tokenized_inputs.word_ids()
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
@@ -668,229 +686,66 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
label_ids.append(-100)
previous_word_idx = word_idx
labels = label_ids
tokenized_inputs['labels'] = labels
# new code end

if self._mode == ModeKeys.INFERENCE:
tokenized_inputs[OutputKeys.TEXT] = text_a
return tokenized_inputs


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.ner_tokenizer)
class NERPreprocessor(Preprocessor):
"""The tokenizer preprocessor used in normal NER task.

NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
"""

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)
self.is_transformer_based_model = 'lstm' not in model_dir
if self.is_transformer_based_model:
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=True)
else:
self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir, use_fast=True)
self.is_split_into_words = self.tokenizer.init_kwargs.get(
'is_split_into_words', False)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
output['labels'] = labels
return output

# preprocess the data for the model input
text = data
if self.is_split_into_words:
input_ids = []
label_mask = []
offset_mapping = []
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
label_mask.extend([1] + [0] * (len(subtoken_ids) - 1))
offset_mapping.extend([(offset, offset + 1)]
+ [(offset + 1, offset + 1)]
* (len(subtoken_ids) - 1))
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
offset_mapping = offset_mapping[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
else:
encodings = self.tokenizer(
text,
add_special_tokens=True,
padding=True,
truncation=True,
max_length=self.sequence_length,
return_offsets_mapping=True)
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']
word_ids = encodings.word_ids()
label_mask = []
offset_mapping = []
for i in range(len(word_ids)):
if word_ids[i] is None:
label_mask.append(0)
elif word_ids[i] == word_ids[i - 1]:
label_mask.append(0)
offset_mapping[-1] = (offset_mapping[-1][0],
encodings['offset_mapping'][i][1])
def get_tokenizer_class(self):
tokenizer_class = self.tokenizer.__class__.__name__
if tokenizer_class.endswith(
'Fast') and tokenizer_class != 'PreTrainedTokenizerFast':
tokenizer_class = tokenizer_class[:-4]
return tokenizer_class

def get_label_mask_and_offset_mapping(self, text):
label_mask = []
offset_mapping = []
tokens = self.tokenizer.tokenize(text)
offset = 0
if self.get_tokenizer_class() == 'BertTokenizer':
for token in tokens:
is_start = (token[:2] != '##')
if is_start:
label_mask.append(True)
else:
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])

if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
return {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sequence_labeling_tokenizer)
class SequenceLabelingPreprocessor(Preprocessor):
"""The tokenizer preprocessor used in normal NER task.

NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
"""

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)

if 'lstm' in model_dir or 'gcnn' in model_dir:
self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir, use_fast=False)
elif 'structbert' in model_dir:
self.tokenizer = SbertTokenizerFast.from_pretrained(
model_dir, use_fast=False)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)
self.is_split_into_words = self.tokenizer.init_kwargs.get(
'is_split_into_words', False)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""

# preprocess the data for the model input
text = data
if self.is_split_into_words:
input_ids = []
label_mask = []
offset_mapping = []
for offset, token in enumerate(list(data)):
subtoken_ids = self.tokenizer.encode(
token, add_special_tokens=False)
if len(subtoken_ids) == 0:
subtoken_ids = [self.tokenizer.unk_token_id]
input_ids.extend(subtoken_ids)
label_mask.extend([1] + [0] * (len(subtoken_ids) - 1))
offset_mapping.extend([(offset, offset + 1)]
+ [(offset + 1, offset + 1)]
* (len(subtoken_ids) - 1))
if len(input_ids) >= self.sequence_length - 2:
input_ids = input_ids[:self.sequence_length - 2]
label_mask = label_mask[:self.sequence_length - 2]
offset_mapping = offset_mapping[:self.sequence_length - 2]
input_ids = [self.tokenizer.cls_token_id
] + input_ids + [self.tokenizer.sep_token_id]
label_mask = [0] + label_mask + [0]
attention_mask = [1] * len(input_ids)
else:
encodings = self.tokenizer(
text,
add_special_tokens=True,
padding=True,
truncation=True,
max_length=self.sequence_length,
return_offsets_mapping=True)
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']
word_ids = encodings.word_ids()
label_mask = []
offset_mapping = []
for i in range(len(word_ids)):
if word_ids[i] is None:
label_mask.append(0)
elif word_ids[i] == word_ids[i - 1]:
label_mask.append(0)
offset_mapping[-1] = (offset_mapping[-1][0],
encodings['offset_mapping'][i][1])
token = token[2:]
label_mask.append(False)
start = offset + text[offset:].index(token)
end = start + len(token)
if is_start:
offset_mapping.append((start, end))
else:
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])
offset_mapping[-1] = (offset_mapping[-1][0], end)
offset = end
elif self.get_tokenizer_class() == 'XLMRobertaTokenizer':
last_is_blank = False
for token in tokens:
is_start = (token[0] == '▁')
if is_start:
token = token[1:]
label_mask.append(True)
if len(token) == 0:
last_is_blank = True
continue
else:
label_mask.append(False)
start = offset + text[offset:].index(token)
end = start + len(token)
if last_is_blank or is_start:
offset_mapping.append((start, end))
else:
offset_mapping[-1] = (offset_mapping[-1][0], end)
offset = end
last_is_blank = False
else:
raise NotImplementedError

if not self.is_transformer_based_model:
input_ids = input_ids[1:-1]
attention_mask = attention_mask[1:-1]
label_mask = label_mask[1:-1]
return {
'text': text,
'input_ids': input_ids,
'attention_mask': attention_mask,
'label_mask': label_mask,
'offset_mapping': offset_mapping
}
return label_mask, offset_mapping


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.re_tokenizer)
class RelationExtractionPreprocessor(Preprocessor):
"""The tokenizer preprocessor used in normal RE task.

NOTE: This preprocessor may be merged with the TokenClassificationPreprocessor in the next edition.
"""The relation extraction preprocessor used in normal RE task.
"""

def __init__(self, model_dir: str, *args, **kwargs):
@@ -937,7 +792,7 @@ class FaqQuestionAnsweringPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
super(FaqQuestionAnsweringPreprocessor, self).__init__(
model_dir, pair=False, mode=ModeKeys.INFERENCE, **kwargs)
model_dir, mode=ModeKeys.INFERENCE, **kwargs)
import os
from transformers import BertTokenizer

@@ -1026,7 +881,7 @@ class DocumentSegmentationPreprocessor(Preprocessor):
"""

super().__init__(*args, **kwargs)
from transformers import BertTokenizerFast
self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir,
use_fast=True,


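The `get_label_mask_and_offset_mapping` branch for BERT-style tokenizers above marks the first subtoken of each word and merges trailing '##' pieces back into one character span. A self-contained re-implementation of that loop; the text and token list are made up for illustration:

text = 'unbelievable story'
tokens = ['un', '##believ', '##able', 'story']  # what a wordpiece tokenizer might emit

label_mask = []
offset_mapping = []
offset = 0
for token in tokens:
    is_start = not token.startswith('##')
    piece = token if is_start else token[2:]
    label_mask.append(is_start)
    start = offset + text[offset:].index(piece)
    end = start + len(piece)
    if is_start:
        offset_mapping.append((start, end))
    else:
        # extend the current word's span with this subtoken
        offset_mapping[-1] = (offset_mapping[-1][0], end)
    offset = end

print(label_mask)      # [True, False, False, True]
print(offset_mapping)  # [(0, 12), (13, 18)]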
+ 1
- 0
modelscope/utils/constant.py

@@ -115,6 +115,7 @@ class NLPTasks(object):
conversational_text_to_sql = 'conversational-text-to-sql'
information_extraction = 'information-extraction'
document_segmentation = 'document-segmentation'
feature_extraction = 'feature-extraction'


class AudioTasks(object):


+ 1
- 1
modelscope/utils/registry.py

@@ -74,7 +74,6 @@ class Registry(object):
raise KeyError(f'{module_name} is already registered in '
f'{self._name}[{group_key}]')
self._modules[group_key][module_name] = module_cls
module_cls.group_key = group_key

def register_module(self,
group_key: str = default_group,
@@ -196,6 +195,7 @@ def build_from_cfg(cfg,
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {registry.name}'
f' registry group {group_key}')
obj_cls.group_key = group_key
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
obj_cls = obj_type
else:

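The registry.py change above moves the `group_key` assignment from registration time to build time. The reason: decorators apply bottom-up, so with stacked `register_module` calls the class would keep only the group of the outermost decorator. A toy sketch of both behaviors (this simplified registry only mimics the real one):

_modules = {}

def register_module(group_key, module_name):
    def decorate(cls):
        _modules.setdefault(group_key, {})[module_name] = cls
        cls.group_key = group_key  # old behavior: the last decorator to run wins
        return cls
    return decorate

@register_module('text-classification', 'sentiment-classification')
@register_module('nli', 'nli')
class ToyPipeline:
    pass

print(ToyPipeline.group_key)  # text-classification, even when looked up under nli

def build_from_cfg(group_key, module_name):
    cls = _modules[group_key][module_name]
    cls.group_key = group_key  # new behavior: tag with the group actually requested
    return cls()

build_from_cfg('nli', 'nli')
print(ToyPipeline.group_key)  # nli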

+ 2
- 1
tests/msdatasets/test_ms_dataset.py

@@ -75,7 +75,8 @@ class MsDatasetTest(unittest.TestCase):
preprocessor = SequenceClassificationPreprocessor(
nlp_model.model_dir,
first_sequence='premise',
second_sequence=None)
second_sequence=None,
padding='max_length')
ms_ds_train = MsDataset.load(
'xcopa',
subset_name='translation-et',


+ 3
- 5
tests/pipelines/test_deberta_tasks.py

@@ -6,11 +6,9 @@ import torch
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import DebertaV2ForMaskedLM
from modelscope.models.nlp.deberta_v2 import (DebertaV2Tokenizer,
DebertaV2TokenizerFast)
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FillMaskPipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.preprocessors import NLPPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

@@ -24,7 +22,7 @@ class DeBERTaV2TaskTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id_deberta)
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = DebertaV2ForMaskedLM.from_pretrained(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
@@ -40,7 +38,7 @@ class DeBERTaV2TaskTest(unittest.TestCase):
# sbert
print(self.model_id_deberta)
model = Model.from_pretrained(self.model_id_deberta)
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
task=Tasks.fill_mask, model=model, preprocessor=preprocessor)


+ 67
- 0
tests/pipelines/test_feature_extraction.py

@@ -0,0 +1,67 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import FeatureExtractionModel
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FeatureExtractionPipeline
from modelscope.preprocessors import NLPPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class FeatureExtractionTaskModelTest(unittest.TestCase,
DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.feature_extraction
self.model_id = 'damo/pert_feature-extraction_base-test'

sentence1 = '测试embedding'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = NLPPreprocessor(cache_path, padding=False)
model = FeatureExtractionModel.from_pretrained(self.model_id)
pipeline1 = FeatureExtractionPipeline(model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.feature_extraction, model=model, preprocessor=tokenizer)
result = pipeline1(input=self.sentence1)

print(f'sentence1: {self.sentence1}\n'
f'pipeline1:{np.shape(result[OutputKeys.TEXT_EMBEDDING])}')
result = pipeline2(input=self.sentence1)
print(f'sentence1: {self.sentence1}\n'
f'pipeline1: {np.shape(result[OutputKeys.TEXT_EMBEDDING])}')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = NLPPreprocessor(model.model_dir, padding=False)
pipeline_ins = pipeline(
task=Tasks.feature_extraction, model=model, preprocessor=tokenizer)
result = pipeline_ins(input=self.sentence1)
print(np.shape(result[OutputKeys.TEXT_EMBEDDING]))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.feature_extraction, model=self.model_id)
result = pipeline_ins(input=self.sentence1)
print(np.shape(result[OutputKeys.TEXT_EMBEDDING]))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.feature_extraction)
result = pipeline_ins(input=self.sentence1)
print(np.shape(result[OutputKeys.TEXT_EMBEDDING]))


if __name__ == '__main__':
unittest.main()

+ 44
- 5
tests/pipelines/test_fill_mask.py

@@ -1,13 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from regex import R

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import (BertForMaskedLM, StructBertForMaskedLM,
VecoForMaskedLM)
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import FillMaskPipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.preprocessors import NLPPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
@@ -51,7 +53,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
# sbert
for language in ['zh']:
model_dir = snapshot_download(self.model_id_sbert[language])
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = StructBertForMaskedLM.from_pretrained(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
@@ -66,7 +68,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

# veco
model_dir = snapshot_download(self.model_id_veco)
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = VecoForMaskedLM.from_pretrained(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
@@ -80,13 +82,28 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
)

# bert
language = 'zh'
model_dir = snapshot_download(self.model_id_bert, revision='beta')
preprocessor = NLPPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = Model.from_pretrained(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language]
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):

# sbert
for language in ['zh']:
print(self.model_id_sbert[language])
model = Model.from_pretrained(self.model_id_sbert[language])
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model.model_dir,
first_sequence='sentence',
second_sequence=None)
@@ -100,7 +117,7 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):

# veco
model = Model.from_pretrained(self.model_id_veco)
preprocessor = FillMaskPreprocessor(
preprocessor = NLPPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
@@ -113,6 +130,18 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')

# bert
language = 'zh'
model = Model.from_pretrained(self.model_id_bert, revision='beta')
preprocessor = NLPPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
pipeline_ins.model, f'fill_mask_bert_{language}'
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
# veco
@@ -131,6 +160,16 @@ class FillMaskTest(unittest.TestCase, DemoCompatibilityCheck):
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

# Bert
language = 'zh'
pipeline_ins = pipeline(
task=Tasks.fill_mask,
model=self.model_id_bert,
model_revision='beta')
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.fill_mask)


+ 5
- 5
tests/pipelines/test_named_entity_recognition.py

@@ -7,7 +7,7 @@ from modelscope.models.nlp import (LSTMCRFForNamedEntityRecognition,
TransformerCRFForNamedEntityRecognition)
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import NamedEntityRecognitionPipeline
from modelscope.preprocessors import NERPreprocessor
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -26,7 +26,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download(self):
cache_path = snapshot_download(self.tcrf_model_id)
tokenizer = NERPreprocessor(cache_path)
tokenizer = TokenClassificationPreprocessor(cache_path)
model = TransformerCRFForNamedEntityRecognition(
cache_path, tokenizer=tokenizer)
pipeline1 = NamedEntityRecognitionPipeline(
@@ -43,7 +43,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_lcrf_by_direct_model_download(self):
cache_path = snapshot_download(self.lcrf_model_id)
tokenizer = NERPreprocessor(cache_path)
tokenizer = TokenClassificationPreprocessor(cache_path)
model = LSTMCRFForNamedEntityRecognition(
cache_path, tokenizer=tokenizer)
pipeline1 = NamedEntityRecognitionPipeline(
@@ -60,7 +60,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_tcrf_with_model_from_modelhub(self):
model = Model.from_pretrained(self.tcrf_model_id)
tokenizer = NERPreprocessor(model.model_dir)
tokenizer = TokenClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition,
model=model,
@@ -70,7 +70,7 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_lcrf_with_model_from_modelhub(self):
model = Model.from_pretrained(self.lcrf_model_id)
tokenizer = NERPreprocessor(model.model_dir)
tokenizer = TokenClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition,
model=model,


+ 5
- 5
tests/pipelines/test_nli.py

@@ -5,8 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import PairSentenceClassificationPipeline
from modelscope.preprocessors import PairSentenceClassificationPreprocessor
from modelscope.pipelines.nlp import SequenceClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
@@ -26,9 +26,9 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = PairSentenceClassificationPreprocessor(cache_path)
tokenizer = SequenceClassificationPreprocessor(cache_path)
model = SbertForSequenceClassification.from_pretrained(cache_path)
pipeline1 = PairSentenceClassificationPipeline(
pipeline1 = SequenceClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer)
print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
@@ -40,7 +40,7 @@ class NLITest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = PairSentenceClassificationPreprocessor(model.model_dir)
tokenizer = SequenceClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.nli, model=model, preprocessor=tokenizer)
print(pipeline_ins(input=(self.sentence1, self.sentence2)))


+ 5
- 5
tests/pipelines/test_sentence_similarity.py

@@ -5,8 +5,8 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForSequenceClassification
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import PairSentenceClassificationPipeline
from modelscope.preprocessors import PairSentenceClassificationPreprocessor
from modelscope.pipelines.nlp import SequenceClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.regress_test_utils import MsRegressTool
@@ -26,9 +26,9 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)
tokenizer = PairSentenceClassificationPreprocessor(cache_path)
tokenizer = SequenceClassificationPreprocessor(cache_path)
model = SbertForSequenceClassification.from_pretrained(cache_path)
pipeline1 = PairSentenceClassificationPipeline(
pipeline1 = SequenceClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
@@ -43,7 +43,7 @@ class SentenceSimilarityTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = PairSentenceClassificationPreprocessor(model.model_dir)
tokenizer = SequenceClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.sentence_similarity,
model=model,


+ 16
- 15
tests/pipelines/test_sentiment_classification.py

@@ -6,8 +6,8 @@ from modelscope.models import Model
from modelscope.models.nlp.task_models.sequence_classification import \
SequenceClassificationModel
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import SingleSentenceClassificationPipeline
from modelscope.preprocessors import SingleSentenceClassificationPreprocessor
from modelscope.pipelines.nlp import SequenceClassificationPipeline
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -17,23 +17,21 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.sentiment_classification
self.task = Tasks.text_classification
self.model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'

sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = SingleSentenceClassificationPreprocessor(cache_path)
cache_path = snapshot_download(self.model_id, revision='beta')
tokenizer = SequenceClassificationPreprocessor(cache_path)
model = SequenceClassificationModel.from_pretrained(
self.model_id, num_labels=2)
pipeline1 = SingleSentenceClassificationPipeline(
self.model_id, num_labels=2, revision='beta')
pipeline1 = SequenceClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.sentiment_classification,
model=model,
preprocessor=tokenizer)
Tasks.text_classification, model=model, preprocessor=tokenizer)
print(f'sentence1: {self.sentence1}\n'
f'pipeline1:{pipeline1(input=self.sentence1)}')
print(f'sentence1: {self.sentence1}\n'
@@ -41,10 +39,10 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = SingleSentenceClassificationPreprocessor(model.model_dir)
model = Model.from_pretrained(self.model_id, revision='beta')
tokenizer = SequenceClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.sentiment_classification,
task=Tasks.text_classification,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence1))
@@ -54,14 +52,17 @@ class SentimentClassificationTaskModelTest(unittest.TestCase,
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.sentiment_classification, model=self.model_id)
task=Tasks.text_classification,
model=self.model_id,
model_revision='beta')
print(pipeline_ins(input=self.sentence1))
self.assertTrue(
isinstance(pipeline_ins.model, SequenceClassificationModel))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.sentiment_classification)
pipeline_ins = pipeline(
task=Tasks.text_classification, model_revision='beta')
print(pipeline_ins(input=self.sentence1))
self.assertTrue(
isinstance(pipeline_ins.model, SequenceClassificationModel))
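These tests now build the sentiment pipeline under the broader Tasks.text_classification key while the concrete component keeps its own module name, so the registry resolves a sub-task pipeline from the parent task. Below is a toy stand-in for that lookup, using plain dicts; the real modelscope registry has considerably more machinery (defaults, building, validation).

# Toy sketch of the task -> sub-task registration pattern these tests
# exercise; hypothetical simplified code, not the modelscope registry.
PIPELINES = {}  # {task: {module_name: cls}}

def register_module(task, module_name):
    def decorator(cls):
        PIPELINES.setdefault(task, {})[module_name] = cls
        return cls
    return decorator

@register_module('text-classification', module_name='sentiment-classification')
class SentimentClassificationPipeline:
    pass

# A model's configuration.json names the concrete sub-task pipeline, so
# pipeline('text-classification', model=...) can look up the class:
cls = PIPELINES['text-classification']['sentiment-classification']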


+3 -1   tests/pipelines/test_text_classification.py

@@ -12,6 +12,7 @@ from modelscope.utils.test_utils import test_level


class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
+sentence1 = 'i like this wonderful place'

def setUp(self) -> None:
self.model_id = 'damo/bert-base-sst2'
@@ -46,7 +47,8 @@ class SequenceClassificationTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.text_classification,
model=model,
preprocessor=preprocessor)
-self.predict(pipeline_ins)
+print(f'sentence1: {self.sentence1}\n'
+f'pipeline1:{pipeline_ins(input=self.sentence1)}')

# @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('nlp model does not support tensor input, skipped')


+76 -0   tests/preprocessors/test_nlp.py

@@ -32,6 +32,82 @@ class NLPPreprocessorTest(unittest.TestCase):
output['attention_mask'],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

def test_token_classification_tokenize(self):
with self.subTest(tokenizer_type='bert'):
cfg = dict(
type='token-cls-tokenizer',
model_dir='bert-base-cased',
label2id={
'O': 0,
'B': 1,
'I': 2
})
preprocessor = build_preprocessor(cfg, Fields.nlp)
input = 'Do not meddle in the affairs of wizards, ' \
'for they are subtle and quick to anger.'
output = preprocessor(input)
self.assertTrue(InputFields.text in output)
self.assertEqual(output['input_ids'].tolist()[0], [
101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678,
1116, 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470,
119, 102
])
self.assertEqual(output['attention_mask'].tolist()[0], [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1
])
self.assertEqual(output['label_mask'].tolist()[0], [
False, True, True, True, False, True, True, True, True, True,
False, True, True, True, True, True, True, True, True, True,
True, False
])
self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6),
(7, 13), (14, 16),
(17, 20), (21, 28),
(29, 31), (32, 39),
(39, 40), (41, 44),
(45, 49), (50, 53),
(54, 60), (61, 64),
(65, 70), (71, 73),
(74, 79), (79, 80)])

with self.subTest(tokenizer_type='roberta'):
cfg = dict(
type='token-cls-tokenizer',
model_dir='xlm-roberta-base',
label2id={
'O': 0,
'B': 1,
'I': 2
})
preprocessor = build_preprocessor(cfg, Fields.nlp)
input = 'Do not meddle in the affairs of wizards, ' \
'for they are subtle and quick to anger.'
output = preprocessor(input)
self.assertTrue(InputFields.text in output)
self.assertEqual(output['input_ids'].tolist()[0], [
0, 984, 959, 128, 19298, 23, 70, 103086, 7, 111, 6, 44239,
99397, 4, 100, 1836, 621, 1614, 17991, 136, 63773, 47, 348, 56,
5, 2
])
self.assertEqual(output['attention_mask'].tolist()[0], [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1
])
self.assertEqual(output['label_mask'].tolist()[0], [
False, True, True, True, False, True, True, True, False, True,
True, False, False, False, True, True, True, True, False, True,
True, True, True, False, False, False
])
self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6),
(7, 13), (14, 16),
(17, 20), (21, 28),
(29, 31), (32, 40),
(41, 44), (45, 49),
(50, 53), (54, 60),
(61, 64), (65, 70),
(71, 73), (74, 80)])


if __name__ == '__main__':
unittest.main()
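In the assertions above, label_mask is True exactly on the first subtoken of each word and False on special tokens and subword continuations ('meddle' -> 'me' + '##ddle' yields True, False), which lets offset_mapping stay at word granularity. One way to reproduce that pattern with a fast HuggingFace tokenizer's word_ids(); this is an illustration, not necessarily the preprocessor's own code.

# Sketch: derive a first-subtoken label mask with a fast tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('bert-base-cased')
enc = tok('Do not meddle in the affairs of wizards, '
          'for they are subtle and quick to anger.')

label_mask, prev = [], None
for word_id in enc.word_ids():
    # None -> special token ([CLS]/[SEP]); a repeated id -> a subword
    # continuation such as '##ddle'; only first subtokens take a label.
    label_mask.append(word_id is not None and word_id != prev)
    prev = word_id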

+5 -7   tests/utils/test_ast.py

@@ -30,7 +30,7 @@ class AstScaningTest(unittest.TestCase):
def test_ast_scaning_class(self):
astScaner = AstScaning()
pipeline_file = os.path.join(MODELSCOPE_PATH, 'pipelines', 'nlp',
-'sequence_classification_pipeline.py')
+'text_generation_pipeline.py')
output = astScaner.generate_ast(pipeline_file)
self.assertTrue(output['imports'] is not None)
self.assertTrue(output['from_imports'] is not None)
@@ -40,14 +40,12 @@ class AstScaningTest(unittest.TestCase):
self.assertIsInstance(imports, dict)
self.assertIsInstance(from_imports, dict)
self.assertIsInstance(decorators, list)
-self.assertListEqual(
-list(set(imports.keys()) - set(['typing', 'numpy'])), [])
-self.assertEqual(len(from_imports.keys()), 9)
+self.assertListEqual(list(set(imports.keys()) - set(['torch'])), [])
+self.assertEqual(len(from_imports.keys()), 7)
self.assertTrue(from_imports['modelscope.metainfo'] is not None)
self.assertEqual(from_imports['modelscope.metainfo'], ['Pipelines'])
-self.assertEqual(
-decorators,
-[('PIPELINES', 'text-classification', 'sentiment-analysis')])
+self.assertEqual(decorators,
+[('PIPELINES', 'text-generation', 'text-generation')])

def test_files_scaning_method(self):
fileScaner = FilesAstScaning()
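The expectations above treat each @PIPELINES.register_module(...) decorator as a (registry, task, module_name) triple recovered without importing the module. A hedged sketch of that kind of static scan with the standard-library ast module follows; it needs Python 3.9+ for ast.unparse, and the real AstScaning additionally resolves attribute expressions such as Tasks.text_generation to their string values.

# Sketch of decorator scanning with stdlib ast; not AstScaning itself.
import ast

def scan_register_module(path: str):
    with open(path) as f:
        tree = ast.parse(f.read())
    triples = []
    for node in ast.walk(tree):
        for dec in getattr(node, 'decorator_list', []):
            if (isinstance(dec, ast.Call)
                    and isinstance(dec.func, ast.Attribute)
                    and dec.func.attr == 'register_module'):
                registry = ast.unparse(dec.func.value)  # e.g. 'PIPELINES'
                task = ast.unparse(dec.args[0]) if dec.args else None
                module = next((ast.unparse(k.value) for k in dec.keywords
                               if k.arg == 'module_name'), None)
                triples.append((registry, task, module))
    return triples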

