shichen.fsc · yingda.chen · 3 years ago
parent · commit a8fd9c4afe
14 changed files with 2475 additions and 230 deletions
  1. modelscope/metainfo.py (+3, -0)
  2. modelscope/models/nlp/__init__.py (+2, -0)
  3. modelscope/models/nlp/ponet/__init__.py (+41, -0)
  4. modelscope/models/nlp/ponet/configuration_ponet.py (+117, -0)
  5. modelscope/models/nlp/ponet/modeling_ponet.py (+1591, -0)
  6. modelscope/models/nlp/ponet/tokenization_ponet.py (+155, -0)
  7. modelscope/models/nlp/ponet_for_masked_language.py (+53, -0)
  8. modelscope/pipelines/nlp/__init__.py (+2, -0)
  9. modelscope/pipelines/nlp/fill_mask_ponet_pipeline.py (+136, -0)
  10. modelscope/preprocessors/__init__.py (+4, -4)
  11. modelscope/preprocessors/nlp.py (+303, -3)
  12. modelscope/preprocessors/slp.py (+0, -223)
  13. modelscope/utils/nlp/nlp_utils.py (+20, -0)
  14. tests/pipelines/test_fill_mask_ponet.py (+48, -0)

+3 -0  modelscope/metainfo.py

@@ -62,6 +62,7 @@ class Models(object):
gpt3 = 'gpt3'
plug = 'plug'
bert_for_ds = 'bert-for-document-segmentation'
ponet = 'ponet'

# audio models
sambert_hifigan = 'sambert-hifigan'
@@ -179,6 +180,7 @@ class Pipelines(object):
sentiment_classification = 'sentiment-classification'
text_classification = 'text-classification'
fill_mask = 'fill-mask'
fill_mask_ponet = 'fill-mask-ponet'
csanmt_translation = 'csanmt-translation'
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
@@ -281,6 +283,7 @@ class Preprocessors(object):
sequence_labeling_tokenizer = 'sequence-labeling-tokenizer'
word_segment_text_to_label_preprocessor = 'word-segment-text-to-label-preprocessor'
fill_mask = 'fill-mask'
fill_mask_ponet = 'fill-mask-ponet'
faq_question_answering_preprocessor = 'faq-question-answering-preprocessor'
conversational_text_to_sql = 'conversational-text-to-sql'
re_tokenizer = 're-tokenizer'


+2 -0  modelscope/models/nlp/__init__.py

@@ -13,6 +13,7 @@ if TYPE_CHECKING:
from .gpt3 import GPT3ForTextGeneration
from .masked_language import (StructBertForMaskedLM, VecoForMaskedLM,
BertForMaskedLM, DebertaV2ForMaskedLM)
from .ponet_for_masked_language import PoNetForMaskedLM
from .nncrf_for_named_entity_recognition import (
TransformerCRFForNamedEntityRecognition,
LSTMCRFForNamedEntityRecognition)
@@ -46,6 +47,7 @@ else:
'TransformerCRFForNamedEntityRecognition',
'LSTMCRFForNamedEntityRecognition'
],
'ponet_for_masked_language': ['PoNetForMaskedLM'],
'palm_v2': ['PalmForTextGeneration'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
'star_text_to_sql': ['StarForTextToSql'],


+41 -0  modelscope/models/nlp/ponet/__init__.py

@@ -0,0 +1,41 @@
# Copyright 2021-2022 The Alibaba DAMO Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .configuration_ponet import PoNetConfig
from .modeling_ponet import (PoNetForMaskedLM, PoNetModel,
PoNetPreTrainedModel)
from .tokenization_ponet import PoNetTokenizer
else:
_import_structure = {
'configuration_ponet': ['PoNetConfig'],
'modeling_ponet':
['PoNetForMaskedLM', 'PoNetModel', 'PoNetPreTrainedModel'],
'tokenization_ponet': ['PoNetTokenizer'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
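
For readers unfamiliar with the pattern above, here is a simplified, generic sketch of lazy submodule loading via a module-level __getattr__ (PEP 562). It is an illustration only, not modelscope's LazyImportModule; the package name lazy_pkg, the submodule heavy_module and the class HeavyClass are hypothetical.

# lazy_pkg/__init__.py -- generic sketch of lazy loading (hypothetical names).
import importlib

_import_structure = {'heavy_module': ['HeavyClass']}
_attr_to_module = {
    attr: mod
    for mod, attrs in _import_structure.items() for attr in attrs
}


def __getattr__(name):
    # Called only when 'name' is not already defined in this module, so the
    # heavy submodule is imported on first attribute access.
    if name in _attr_to_module:
        module = importlib.import_module('.' + _attr_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')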

+117 -0  modelscope/models/nlp/ponet/configuration_ponet.py

@@ -0,0 +1,117 @@
# Copyright 2021-2022 The Alibaba DAMO Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PoNet model configuration, mainly copied from :class:`~transformers.BertConfig` """
from transformers import PretrainedConfig

from modelscope.utils import logger as logging

logger = logging.get_logger(__name__)


class PoNetConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration
of a :class:`~modelscope.models.nlp.ponet.PoNetModel`.
It is used to instantiate a PoNet model according to the specified arguments.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the PoNet model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~modelscope.models.nlp.ponet.PoNetModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling
:class:`~modelscope.models.nlp.ponet.PoNetModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.
clsgsepg (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use a trick to make sure that segment and local information will not leak.
"""
model_type = 'ponet'

def __init__(self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
position_embedding_type='absolute',
use_cache=True,
classifier_dropout=None,
clsgsepg=True,
**kwargs):
super().__init__(pad_token_id=pad_token_id, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
self.clsgsepg = clsgsepg
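
A minimal usage sketch for the configuration class above (assumptions: a modelscope build containing this commit and transformers installed; the output directory name is arbitrary):

from modelscope.models.nlp.ponet import PoNetConfig

# Defaults mirror the docstring above; clsgsepg toggles the leak-prevention trick.
config = PoNetConfig(vocab_size=30522, hidden_size=768, clsgsepg=True)
print(config.model_type)  # 'ponet'
config.save_pretrained('./ponet_config_demo')  # standard PretrainedConfig serialization
reloaded = PoNetConfig.from_pretrained('./ponet_config_demo')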

+1591 -0  modelscope/models/nlp/ponet/modeling_ponet.py (diff suppressed because the file is too large)


+155 -0  modelscope/models/nlp/ponet/tokenization_ponet.py

@@ -0,0 +1,155 @@
# Copyright 2021-2022 The Alibaba DAMO Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for PoNet """

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
from transformers.models.bert.tokenization_bert import BertTokenizer

from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': ModelFile.VOCAB_FILE}

PRETRAINED_VOCAB_FILES_MAP = {'vocab_file': {}}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'nlp_ponet_fill-mask_chinese-base': 512,
'nlp_ponet_fill-mask_english-base': 512,
}

PRETRAINED_INIT_CONFIGURATION = {
'nlp_ponet_fill-mask_chinese-base': {
'do_lower_case': True
},
'nlp_ponet_fill-mask_english-base': {
'do_lower_case': True
},
}


class PoNetTokenizer(BertTokenizer):
r"""
Construct a PoNet tokenizer, based on :class:`~transformers.BertTokenizer`.

This tokenizer inherits from :class:`~transformers.BertTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.

Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""

vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION

def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

Args:
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or
batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.

- PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:

- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta).
return_attention_mask: (optional) Set to False to avoid returning
attention mask (default: set to model specifics)
"""
# Load from model defaults
if return_attention_mask is None:
return_attention_mask = 'attention_mask' in self.model_input_names

required_input = encoded_inputs[self.model_input_names[0]]

if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)

if max_length is not None and pad_to_multiple_of is not None and (
max_length % pad_to_multiple_of != 0):
max_length = (
(max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(
required_input) != max_length

if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == 'right':
if return_attention_mask:
encoded_inputs['attention_mask'] = [1] * len(
required_input) + [0] * difference
if 'token_type_ids' in encoded_inputs:
encoded_inputs['token_type_ids'] = (
encoded_inputs['token_type_ids']
+ [self.pad_token_type_id] * difference)
if 'special_tokens_mask' in encoded_inputs:
encoded_inputs['special_tokens_mask'] = encoded_inputs[
'special_tokens_mask'] + [1] * difference
if 'segment_ids' in encoded_inputs:
encoded_inputs[
'segment_ids'] = encoded_inputs['segment_ids'] + [
encoded_inputs['segment_ids'][-1] + 1
] * difference # noqa *
encoded_inputs[self.model_input_names[
0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == 'left':
if return_attention_mask:
encoded_inputs['attention_mask'] = [0] * difference + [
1
] * len(required_input)
if 'token_type_ids' in encoded_inputs:
encoded_inputs['token_type_ids'] = [
self.pad_token_type_id
] * difference + encoded_inputs['token_type_ids']
if 'segment_ids' in encoded_inputs:
encoded_inputs['segment_ids'] = [encoded_inputs['segment_ids'][-1] + 1] * difference + \
encoded_inputs['segment_ids'] # noqa *
if 'special_tokens_mask' in encoded_inputs:
encoded_inputs['special_tokens_mask'] = [
1
] * difference + encoded_inputs['special_tokens_mask']
encoded_inputs[self.model_input_names[
0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError('Invalid padding strategy:'
+ str(self.padding_side))
elif return_attention_mask and 'attention_mask' not in encoded_inputs:
encoded_inputs['attention_mask'] = [1] * len(required_input)

return encoded_inputs
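
The segment_ids handling is the only part of _pad that differs from the stock BertTokenizer behaviour. The standalone sketch below (not the tokenizer itself) mirrors the right-side padding branch: regular keys are padded with pad ids or zeros, while segment_ids are extended with a fresh segment index (last id + 1) so padding never merges into the final real segment.

def pad_right(encoded, max_length, pad_token_id=0, pad_token_type_id=0):
    """Illustrative re-statement of the right-padding rule above."""
    diff = max_length - len(encoded['input_ids'])
    padded = dict(encoded)
    padded['attention_mask'] = [1] * len(encoded['input_ids']) + [0] * diff
    padded['token_type_ids'] = encoded['token_type_ids'] + [pad_token_type_id] * diff
    padded['segment_ids'] = encoded['segment_ids'] + [encoded['segment_ids'][-1] + 1] * diff
    padded['input_ids'] = encoded['input_ids'] + [pad_token_id] * diff
    return padded


example = {'input_ids': [101, 2023, 102], 'token_type_ids': [0, 0, 0], 'segment_ids': [0, 1, 2]}
print(pad_right(example, max_length=6)['segment_ids'])  # [0, 1, 2, 3, 3, 3]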

+53 -0  modelscope/models/nlp/ponet_for_masked_language.py

@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.ponet import \
PoNetForMaskedLM as PoNetForMaskedLMTransformer
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks

__all__ = ['PoNetForMaskedLM']


@MODELS.register_module(Tasks.fill_mask, module_name=Models.ponet)
class PoNetForMaskedLM(TorchModel, PoNetForMaskedLMTransformer):
"""PoNet for MLM model.'.

Inherited from ponet.PoNetForMaskedLM and TorchModel, so this class can be registered into Model sets.
"""

def __init__(self, config, model_dir):
super(TorchModel, self).__init__(model_dir)
PoNetForMaskedLMTransformer.__init__(self, config)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
segment_ids=None,
position_ids=None,
head_mask=None,
labels=None):
output = PoNetForMaskedLMTransformer.forward(
self,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
segment_ids=segment_ids,
position_ids=position_ids,
head_mask=head_mask,
labels=labels)
output[OutputKeys.INPUT_IDS] = input_ids
return output

@classmethod
def _instantiate(cls, **kwargs):
model_dir = kwargs.get('model_dir')
return super(PoNetForMaskedLMTransformer,
PoNetForMaskedLM).from_pretrained(
pretrained_model_name_or_path=model_dir,
model_dir=model_dir)
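
Because the class is registered under Tasks.fill_mask with module_name=Models.ponet, it can be resolved through the generic model loader. A hedged sketch (assumptions: a build containing this commit and network access to the ModelScope hub):

from modelscope.models import Model

model = Model.from_pretrained('damo/nlp_ponet_fill-mask_chinese-base')
print(type(model).__name__)  # expected: PoNetForMaskedLM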

+2 -0  modelscope/pipelines/nlp/__init__.py

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .document_segmentation_pipeline import DocumentSegmentationPipeline
from .faq_question_answering_pipeline import FaqQuestionAnsweringPipeline
from .fill_mask_pipeline import FillMaskPipeline
from .fill_mask_ponet_pipeline import FillMaskPonetPipeline
from .information_extraction_pipeline import InformationExtractionPipeline
from .named_entity_recognition_pipeline import NamedEntityRecognitionPipeline
from .pair_sentence_classification_pipeline import PairSentenceClassificationPipeline
@@ -36,6 +37,7 @@ else:
'document_segmentation_pipeline': ['DocumentSegmentationPipeline'],
'faq_question_answering_pipeline': ['FaqQuestionAnsweringPipeline'],
'fill_mask_pipeline': ['FillMaskPipeline'],
'fill_mask_ponet_pipeline': ['FillMaskPonetPipeline'],
'named_entity_recognition_pipeline':
['NamedEntityRecognitionPipeline'],
'information_extraction_pipeline': ['InformationExtractionPipeline'],


+136 -0  modelscope/pipelines/nlp/fill_mask_ponet_pipeline.py

@@ -0,0 +1,136 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import FillMaskPoNetPreprocessor, Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['FillMaskPonetPipeline']
_type_map = {'ponet': 'bert'}


@PIPELINES.register_module(
Tasks.fill_mask, module_name=Pipelines.fill_mask_ponet)
class FillMaskPonetPipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
first_sequence='sentence',
**kwargs):
"""Use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction

Args:
model (str or Model): Supply either a local model dir that supports the fill-mask task,
or a fill-mask model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
first_sequence: The key under which to read the input sentence.

NOTE: Inputs of type 'str' are also supported. In this scenario, the 'first_sequence'
param will have no effect.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(
'fill-mask', model='damo/nlp_ponet_fill-mask_english-base')
>>> input = 'Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].'
>>> print(pipeline_ins(input))

NOTE2: Please pay attention to the model's special tokens.
If a bert-based model (bert, structbert, etc.) is used, the mask token is '[MASK]'.
If an xlm-roberta-based model (xlm-roberta, veco, etc.) is used, the mask token is '<mask>'.
To view other examples please check tests/pipelines/test_fill_mask.py.
"""
fill_mask_model = model if isinstance(
model, Model) else Model.from_pretrained(model)

self.config = Config.from_file(
os.path.join(fill_mask_model.model_dir, ModelFile.CONFIGURATION))

if preprocessor is None:
preprocessor = FillMaskPoNetPreprocessor(
fill_mask_model.model_dir,
first_sequence=first_sequence,
second_sequence=None,
sequence_length=kwargs.pop('sequence_length', 512))

fill_mask_model.eval()
super().__init__(
model=fill_mask_model, preprocessor=preprocessor, **kwargs)

self.preprocessor = preprocessor

self.tokenizer = preprocessor.tokenizer
self.mask_id = {'roberta': 250001, 'bert': 103}

self.rep_map = {
'bert': {
'[unused0]': '',
'[PAD]': '',
'[unused1]': '',
r' +': ' ',
'[SEP]': '',
'[unused2]': '',
'[CLS]': '',
'[UNK]': ''
},
'roberta': {
r' +': ' ',
'<mask>': '<q>',
'<pad>': '',
'<s>': '',
'</s>': '',
'<unk>': ' '
}
}

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return self.model(inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

Args:
inputs (Dict[str, Tensor]): the model outputs, containing the logits and input ids

Returns:
Dict[str, str]: the prediction results
"""
logits = inputs[OutputKeys.LOGITS].detach().cpu().numpy()
input_ids = inputs[OutputKeys.INPUT_IDS].detach().cpu().numpy()
pred_ids = np.argmax(logits, axis=-1)
model_type = self.model.config.model_type
process_type = model_type if model_type in self.mask_id else _type_map[
model_type]
rst_ids = np.where(input_ids == self.mask_id[process_type], pred_ids,
input_ids)

def rep_tokens(string, rep_map):
for k, v in rep_map.items():
string = string.replace(k, v)
return string.strip()

pred_strings = []
for ids in rst_ids: # batch
if 'language' in self.config.model and self.config.model.language == 'zh':
pred_string = self.tokenizer.convert_ids_to_tokens(ids)
pred_string = ''.join(pred_string)
else:
pred_string = self.tokenizer.decode(ids)
pred_string = rep_tokens(pred_string, self.rep_map[process_type])
pred_strings.append(pred_string)

return {OutputKeys.TEXT: pred_strings}
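
A usage sketch mirroring the docstring above and tests/pipelines/test_fill_mask_ponet.py (assumption: access to the ModelScope hub). The PoNet checkpoints here are bert-style, so the mask token is '[MASK]':

from modelscope.pipelines import pipeline

pipeline_ins = pipeline('fill-mask', model='damo/nlp_ponet_fill-mask_english-base')
result = pipeline_ins('Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK].')
print(result)  # a dict whose OutputKeys.TEXT entry holds the reconstructed sentences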

+4 -4  modelscope/preprocessors/__init__.py

@@ -22,8 +22,8 @@ if TYPE_CHECKING:
PairSentenceClassificationPreprocessor, FillMaskPreprocessor,
ZeroShotClassificationPreprocessor, NERPreprocessor,
TextErrorCorrectionPreprocessor, FaqQuestionAnsweringPreprocessor,
SequenceLabelingPreprocessor, RelationExtractionPreprocessor)
from .slp import DocumentSegmentationPreprocessor
SequenceLabelingPreprocessor, RelationExtractionPreprocessor,
DocumentSegmentationPreprocessor, FillMaskPoNetPreprocessor)
from .space import (DialogIntentPredictionPreprocessor,
DialogModelingPreprocessor,
DialogStateTrackingPreprocessor)
@@ -52,9 +52,9 @@ else:
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor',
'FaqQuestionAnsweringPreprocessor', 'SequenceLabelingPreprocessor',
'RelationExtractionPreprocessor'
'RelationExtractionPreprocessor',
'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
],
'slp': ['DocumentSegmentationPreprocessor'],
'space': [
'DialogIntentPredictionPreprocessor', 'DialogModelingPreprocessor',
'DialogStateTrackingPreprocessor', 'InputFeatures'


+303 -3  modelscope/preprocessors/nlp.py

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
import re
import uuid
from typing import Any, Dict, Iterable, Optional, Tuple, Union

@@ -11,13 +12,17 @@ from transformers import AutoTokenizer, BertTokenizerFast
from modelscope.metainfo import Models, Preprocessors
from modelscope.models.nlp.structbert import SbertTokenizerFast
from modelscope.outputs import OutputKeys
from modelscope.utils.config import ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys
from modelscope.utils.config import Config, ConfigFields
from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.logger import get_logger
from modelscope.utils.nlp.nlp_utils import import_external_nltk_data
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

logger = get_logger()

__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
@@ -25,7 +30,8 @@ __all__ = [
'SingleSentenceClassificationPreprocessor', 'FillMaskPreprocessor',
'ZeroShotClassificationPreprocessor', 'NERPreprocessor',
'TextErrorCorrectionPreprocessor', 'FaqQuestionAnsweringPreprocessor',
'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor'
'SequenceLabelingPreprocessor', 'RelationExtractionPreprocessor',
'DocumentSegmentationPreprocessor', 'FillMaskPoNetPreprocessor'
]


@@ -903,3 +909,297 @@ class FaqQuestionAnsweringPreprocessor(Preprocessor):
max_length = self.MAX_LEN
return self.tokenizer.batch_encode_plus(
sentence_list, padding=True, max_length=max_length)


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.document_segmentation)
class DocumentSegmentationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, config, *args, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir,
use_fast=True,
)
self.question_column_name = 'labels'
self.context_column_name = 'sentences'
self.example_id_column_name = 'example_id'
self.label_to_id = {'B-EOP': 0, 'O': 1}
self.target_specical_ids = set()
self.target_specical_ids.add(self.tokenizer.eos_token_id)
self.max_seq_length = config.max_position_embeddings
self.label_list = ['B-EOP', 'O']

def __call__(self, examples) -> Dict[str, Any]:
questions = examples[self.question_column_name]
contexts = examples[self.context_column_name]
example_ids = examples[self.example_id_column_name]
num_examples = len(questions)

sentences = []
for sentence_list in contexts:
sentence_list = [_ + '[EOS]' for _ in sentence_list]
sentences.append(sentence_list)

try:
tokenized_examples = self.tokenizer(
sentences,
is_split_into_words=True,
add_special_tokens=False,
return_token_type_ids=True,
return_attention_mask=True,
)
except Exception as e:
logger.error(e)
return {}

segment_ids = []
token_seq_labels = []
for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_labels = questions[example_index]
example_labels = [
self.label_to_id[_] if _ in self.label_to_id else -100
for _ in example_labels
]
example_token_labels = []
segment_id = []
cur_seg_id = 1
for token_index in range(len(example_input_ids)):
if example_input_ids[token_index] in self.target_specical_ids:
example_token_labels.append(example_labels[cur_seg_id - 1])
segment_id.append(cur_seg_id)
cur_seg_id += 1
else:
example_token_labels.append(-100)
segment_id.append(cur_seg_id)

segment_ids.append(segment_id)
token_seq_labels.append(example_token_labels)

tokenized_examples['segment_ids'] = segment_ids
tokenized_examples['token_seq_labels'] = token_seq_labels

new_segment_ids = []
new_token_seq_labels = []
new_input_ids = []
new_token_type_ids = []
new_attention_mask = []
new_example_ids = []
new_sentences = []

for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_token_type_ids = tokenized_examples['token_type_ids'][
example_index]
example_attention_mask = tokenized_examples['attention_mask'][
example_index]
example_segment_ids = tokenized_examples['segment_ids'][
example_index]
example_token_seq_labels = tokenized_examples['token_seq_labels'][
example_index]
example_sentences = contexts[example_index]
example_id = example_ids[example_index]
example_total_num_sentences = len(questions[example_index])
example_total_num_tokens = len(
tokenized_examples['input_ids'][example_index])
accumulate_length = [
i for i, x in enumerate(tokenized_examples['input_ids']
[example_index])
if x == self.tokenizer.eos_token_id
]
samples_boundary = []
left_index = 0
sent_left_index = 0
sent_i = 0

# for sent_i, length in enumerate(accumulate_length):
while sent_i < len(accumulate_length):
length = accumulate_length[sent_i]
right_index = length + 1
sent_right_index = sent_i + 1
if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens:
samples_boundary.append([left_index, right_index])

sample_input_ids = [
self.tokenizer.cls_token_id
] + example_input_ids[left_index:right_index]
sample_input_ids = sample_input_ids[:self.max_seq_length]

sample_token_type_ids = [
0
] + example_token_type_ids[left_index:right_index]
sample_token_type_ids = sample_token_type_ids[:self.
max_seq_length]

sample_attention_mask = [
1
] + example_attention_mask[left_index:right_index]
sample_attention_mask = sample_attention_mask[:self.
max_seq_length]

sample_segment_ids = [
0
] + example_segment_ids[left_index:right_index]
sample_segment_ids = sample_segment_ids[:self.
max_seq_length]

sample_token_seq_labels = [
-100
] + example_token_seq_labels[left_index:right_index]
sample_token_seq_labels = sample_token_seq_labels[:self.
max_seq_length]

if sent_right_index - 1 == sent_left_index:
left_index = right_index
sample_input_ids[-1] = self.tokenizer.eos_token_id
sample_token_seq_labels[-1] = -100
else:
left_index = accumulate_length[sent_i - 1] + 1
if sample_token_seq_labels[-1] != -100:
sample_token_seq_labels[-1] = -100

if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens:
sample_sentences = example_sentences[
sent_left_index:sent_right_index]
sent_left_index = sent_right_index
sent_i += 1
else:
sample_sentences = example_sentences[
sent_left_index:sent_right_index - 1]
sent_left_index = sent_right_index - 1

if (len([_ for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences) - 1 and (len([
_
for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences):
tmp = []
for w_i, w, l in zip(
sample_input_ids,
self.tokenizer.decode(sample_input_ids).split(
' '), sample_token_seq_labels):
tmp.append((w_i, w, l))
while len(sample_input_ids) < self.max_seq_length:
sample_input_ids.append(self.tokenizer.pad_token_id)
sample_token_type_ids.append(0)
sample_attention_mask.append(0)
sample_segment_ids.append(example_total_num_sentences
+ 1)
sample_token_seq_labels.append(-100)

new_input_ids.append(sample_input_ids)
new_token_type_ids.append(sample_token_type_ids)
new_attention_mask.append(sample_attention_mask)
new_segment_ids.append(sample_segment_ids)
new_token_seq_labels.append(sample_token_seq_labels)
new_example_ids.append(example_id)
new_sentences.append(sample_sentences)
else:
sent_i += 1
continue

output_samples = {}

output_samples['input_ids'] = new_input_ids
output_samples['token_type_ids'] = new_token_type_ids
output_samples['attention_mask'] = new_attention_mask

output_samples['segment_ids'] = new_segment_ids
output_samples['example_id'] = new_example_ids
output_samples['labels'] = new_token_seq_labels
output_samples['sentences'] = new_sentences

return output_samples


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.fill_mask_ponet)
class FillMaskPoNetPreprocessor(NLPTokenizerPreprocessorBase):
"""The tokenizer preprocessor used in MLM task.
"""

def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
kwargs['truncation'] = kwargs.get('truncation', True)
kwargs['padding'] = kwargs.get('padding', 'max_length')
kwargs['max_length'] = kwargs.pop('sequence_length', 512)
kwargs['return_token_type_ids'] = kwargs.get('return_token_type_ids',
True)
super().__init__(model_dir, pair=False, mode=mode, **kwargs)

self.cfg = Config.from_file(
osp.join(model_dir, ModelFile.CONFIGURATION))
self.language = self.cfg.model.get('language', 'en')
if self.language == 'en':
from nltk.tokenize import sent_tokenize
import_external_nltk_data(
osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt')
elif self.language in ['zh', 'cn']:

def sent_tokenize(para):
para = re.sub(r'([。!!?\?])([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa *
para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2',
para) # noqa *
para = para.rstrip()
return [_ for _ in para.split('\n') if _]
else:
raise NotImplementedError

self.sent_tokenize = sent_tokenize
self.max_length = kwargs['max_length']

def __call__(self, data: Union[str, Tuple, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str, tuple or dict): a single sentence, or a pair [sentence1, sentence2]
sentence1 (str): a sentence
Example:
'you are so handsome.'
sentence2 (str): a sentence
Example:
'you are so beautiful.'
Returns:
Dict[str, Any]: the preprocessed data
"""

text_a, text_b, labels = self.parse_text_and_label(data)
output = self.tokenizer(
text_a,
text_b,
return_tensors='pt' if self._mode == ModeKeys.INFERENCE else None,
**self.tokenize_kwargs)
max_seq_length = self.max_length

if text_b is None:
segment_ids = []
seg_lens = list(
map(
len,
self.tokenizer(
self.sent_tokenize(text_a),
add_special_tokens=False,
truncation=True)['input_ids']))
segment_id = [0] + sum(
[[i] * sl for i, sl in enumerate(seg_lens, start=1)], [])
segment_id = segment_id[:max_seq_length - 1]
segment_ids.append(segment_id + [segment_id[-1] + 1]
* (max_seq_length - len(segment_id)))
output['segment_ids'] = segment_ids

output = {
k: np.array(v) if isinstance(v, list) else v
for k, v in output.items()
}

self.labels_to_id(labels, output)
return output
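
The sentence-level segment_ids built above can be illustrated in isolation (a sketch, not the preprocessor itself): position 0 (the [CLS] slot) gets id 0, every token gets the 1-based index of the sentence it belongs to, and padding continues with one extra trailing id.

def build_segment_ids(sentence_token_lengths, max_seq_length):
    segment_id = [0] + sum(
        [[i] * sl for i, sl in enumerate(sentence_token_lengths, start=1)], [])
    segment_id = segment_id[:max_seq_length - 1]
    return segment_id + [segment_id[-1] + 1] * (max_seq_length - len(segment_id))


# e.g. two sentences of 3 and 2 sub-tokens, padded to length 10
print(build_segment_ids([3, 2], 10))  # [0, 1, 1, 1, 2, 2, 3, 3, 3, 3]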

+0 -223  modelscope/preprocessors/slp.py

@@ -1,223 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from transformers import BertTokenizerFast

from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields
from modelscope.utils.hub import get_model_type, parse_label_mapping
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['DocumentSegmentationPreprocessor']


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.document_segmentation)
class DocumentSegmentationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, config, *args, **kwargs):
"""preprocess the data

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

self.tokenizer = BertTokenizerFast.from_pretrained(
model_dir,
use_fast=True,
)
self.question_column_name = 'labels'
self.context_column_name = 'sentences'
self.example_id_column_name = 'example_id'
self.label_to_id = {'B-EOP': 0, 'O': 1}
self.target_specical_ids = set()
self.target_specical_ids.add(self.tokenizer.eos_token_id)
self.max_seq_length = config.max_position_embeddings
self.label_list = ['B-EOP', 'O']

def __call__(self, examples) -> Dict[str, Any]:
questions = examples[self.question_column_name]
contexts = examples[self.context_column_name]
example_ids = examples[self.example_id_column_name]
num_examples = len(questions)

sentences = []
for sentence_list in contexts:
sentence_list = [_ + '[EOS]' for _ in sentence_list]
sentences.append(sentence_list)

try:
tokenized_examples = self.tokenizer(
sentences,
is_split_into_words=True,
add_special_tokens=False,
return_token_type_ids=True,
return_attention_mask=True,
)
except Exception as e:
print(str(e))
return {}

segment_ids = []
token_seq_labels = []
for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_labels = questions[example_index]
example_labels = [
self.label_to_id[_] if _ in self.label_to_id else -100
for _ in example_labels
]
example_token_labels = []
segment_id = []
cur_seg_id = 1
for token_index in range(len(example_input_ids)):
if example_input_ids[token_index] in self.target_specical_ids:
example_token_labels.append(example_labels[cur_seg_id - 1])
segment_id.append(cur_seg_id)
cur_seg_id += 1
else:
example_token_labels.append(-100)
segment_id.append(cur_seg_id)

segment_ids.append(segment_id)
token_seq_labels.append(example_token_labels)

tokenized_examples['segment_ids'] = segment_ids
tokenized_examples['token_seq_labels'] = token_seq_labels

new_segment_ids = []
new_token_seq_labels = []
new_input_ids = []
new_token_type_ids = []
new_attention_mask = []
new_example_ids = []
new_sentences = []

for example_index in range(num_examples):
example_input_ids = tokenized_examples['input_ids'][example_index]
example_token_type_ids = tokenized_examples['token_type_ids'][
example_index]
example_attention_mask = tokenized_examples['attention_mask'][
example_index]
example_segment_ids = tokenized_examples['segment_ids'][
example_index]
example_token_seq_labels = tokenized_examples['token_seq_labels'][
example_index]
example_sentences = contexts[example_index]
example_id = example_ids[example_index]
example_total_num_sentences = len(questions[example_index])
example_total_num_tokens = len(
tokenized_examples['input_ids'][example_index])
accumulate_length = [
i for i, x in enumerate(tokenized_examples['input_ids']
[example_index])
if x == self.tokenizer.eos_token_id
]
samples_boundary = []
left_index = 0
sent_left_index = 0
sent_i = 0

# for sent_i, length in enumerate(accumulate_length):
while sent_i < len(accumulate_length):
length = accumulate_length[sent_i]
right_index = length + 1
sent_right_index = sent_i + 1
if right_index - left_index >= self.max_seq_length - 1 or right_index == example_total_num_tokens:
samples_boundary.append([left_index, right_index])

sample_input_ids = [
self.tokenizer.cls_token_id
] + example_input_ids[left_index:right_index]
sample_input_ids = sample_input_ids[:self.max_seq_length]

sample_token_type_ids = [
0
] + example_token_type_ids[left_index:right_index]
sample_token_type_ids = sample_token_type_ids[:self.
max_seq_length]

sample_attention_mask = [
1
] + example_attention_mask[left_index:right_index]
sample_attention_mask = sample_attention_mask[:self.
max_seq_length]

sample_segment_ids = [
0
] + example_segment_ids[left_index:right_index]
sample_segment_ids = sample_segment_ids[:self.
max_seq_length]

sample_token_seq_labels = [
-100
] + example_token_seq_labels[left_index:right_index]
sample_token_seq_labels = sample_token_seq_labels[:self.
max_seq_length]

if sent_right_index - 1 == sent_left_index:
left_index = right_index
sample_input_ids[-1] = self.tokenizer.eos_token_id
sample_token_seq_labels[-1] = -100
else:
left_index = accumulate_length[sent_i - 1] + 1
if sample_token_seq_labels[-1] != -100:
sample_token_seq_labels[-1] = -100

if sent_right_index - 1 == sent_left_index or right_index == example_total_num_tokens:
sample_sentences = example_sentences[
sent_left_index:sent_right_index]
sent_left_index = sent_right_index
sent_i += 1
else:
sample_sentences = example_sentences[
sent_left_index:sent_right_index - 1]
sent_left_index = sent_right_index - 1

if (len([_ for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences) - 1 and (len([
_
for _ in sample_token_seq_labels if _ != -100
])) != len(sample_sentences):
tmp = []
for w_i, w, l in zip(
sample_input_ids,
self.tokenizer.decode(sample_input_ids).split(
' '), sample_token_seq_labels):
tmp.append((w_i, w, l))
while len(sample_input_ids) < self.max_seq_length:
sample_input_ids.append(self.tokenizer.pad_token_id)
sample_token_type_ids.append(0)
sample_attention_mask.append(0)
sample_segment_ids.append(example_total_num_sentences
+ 1)
sample_token_seq_labels.append(-100)

new_input_ids.append(sample_input_ids)
new_token_type_ids.append(sample_token_type_ids)
new_attention_mask.append(sample_attention_mask)
new_segment_ids.append(sample_segment_ids)
new_token_seq_labels.append(sample_token_seq_labels)
new_example_ids.append(example_id)
new_sentences.append(sample_sentences)
else:
sent_i += 1
continue

output_samples = {}

output_samples['input_ids'] = new_input_ids
output_samples['token_type_ids'] = new_token_type_ids
output_samples['attention_mask'] = new_attention_mask

output_samples['segment_ids'] = new_segment_ids
output_samples['example_id'] = new_example_ids
output_samples['labels'] = new_token_seq_labels
output_samples['sentences'] = new_sentences

return output_samples

+20 -0  modelscope/utils/nlp/nlp_utils.py

@@ -1,3 +1,4 @@
import os.path as osp
from typing import List

from modelscope.outputs import OutputKeys
@@ -41,3 +42,22 @@ def tracking_and_print_dialog_states(
print(json.dumps(result))

history_states.extend([result[OutputKeys.OUTPUT], {}])


def import_external_nltk_data(nltk_data_dir, package_name):
"""import external nltk_data, and extract nltk zip package.

Args:
nltk_data_dir (str): external nltk_data dir path, e.g. /home/xx/nltk_data
package_name (str): nltk package name, e.g. tokenizers/punkt
"""
import nltk
nltk.data.path.append(nltk_data_dir)

filepath = osp.join(nltk_data_dir, package_name + '.zip')
zippath = osp.join(nltk_data_dir, package_name)
packagepath = osp.dirname(zippath)
if not osp.exists(zippath):
import zipfile
with zipfile.ZipFile(filepath) as zf:
zf.extractall(osp.join(packagepath))
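
A hedged usage sketch (assumptions: a model directory that ships an nltk_data folder containing tokenizers/punkt.zip, as the PoNet preprocessor above expects; the path below is a placeholder):

from modelscope.utils.nlp.nlp_utils import import_external_nltk_data

import_external_nltk_data('/path/to/model_dir/nltk_data', 'tokenizers/punkt')

from nltk.tokenize import sent_tokenize
print(sent_tokenize('PoNet handles long inputs. It pools over segments.'))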

+48 -0  tests/pipelines/test_fill_mask_ponet.py

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FillMaskPonetTest(unittest.TestCase):
model_id_ponet = {
'zh': 'damo/nlp_ponet_fill-mask_chinese-base',
'en': 'damo/nlp_ponet_fill-mask_english-base'
}

ori_texts = {
'zh':
'段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。'
'你师父差得动你,你师父可差不动我。',
'en':
'Everything in what you call reality is really just a reflection of your '
'consciousness. Your whole universe is just a mirror reflection of your story.'
}

test_inputs = {
'zh':
'段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你'
'师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。',
'en':
'Everything in [MASK] you call reality is really [MASK] a reflection of your '
'[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_ponet_model(self):
for language in ['zh', 'en']:
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language]

pipeline_ins = pipeline(
task=Tasks.fill_mask, model=self.model_id_ponet[language])

print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')


if __name__ == '__main__':
unittest.main()
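
Since the module ends with unittest.main(), this test can also be run directly with 'python tests/pipelines/test_fill_mask_ponet.py' (assumptions: the repository root is on PYTHONPATH, dependencies are installed, and the ModelScope hub is reachable).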
