Browse Source

[to #42322933] Add nlp-structbert/veco-fill-mask-pipeline to maas lib

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9069107
master
suluyan.sly 3 years ago
parent
commit
0286dd45cc
10 changed files with 342 additions and 5 deletions
  1. +2
    -0
      modelscope/metainfo.py
  2. +1
    -0
      modelscope/models/nlp/__init__.py
  3. +51
    -0
      modelscope/models/nlp/masked_language_model.py
  4. +1
    -3
      modelscope/pipelines/builder.py
  5. +1
    -0
      modelscope/pipelines/nlp/__init__.py
  6. +93
    -0
      modelscope/pipelines/nlp/fill_mask_pipeline.py
  7. +6
    -0
      modelscope/pipelines/outputs.py
  8. +57
    -1
      modelscope/preprocessors/nlp.py
  9. +1
    -1
      requirements/nlp.txt
  10. +129
    -0
      tests/pipelines/test_fill_mask.py

+ 2
- 0
modelscope/metainfo.py View File

@@ -15,6 +15,7 @@ class Models(object):
bert = 'bert'
palm = 'palm-v2'
structbert = 'structbert'
veco = 'veco'

# audio models
sambert_hifi_16k = 'sambert-hifi-16k'
@@ -46,6 +47,7 @@ class Pipelines(object):
word_segmentation = 'word-segmentation'
text_generation = 'text-generation'
sentiment_analysis = 'sentiment-analysis'
fill_mask = 'fill-mask'

# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'


+ 1
- 0
modelscope/models/nlp/__init__.py View File

@@ -1,4 +1,5 @@
from .bert_for_sequence_classification import * # noqa F403
from .masked_language_model import * # noqa F403
from .palm_for_text_generation import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_token_classification import * # noqa F403

+ 51
- 0
modelscope/models/nlp/masked_language_model.py View File

@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional, Union

import numpy as np

from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM']


class AliceMindBaseForMaskedLM(Model):

def __init__(self, model_dir: str, *args, **kwargs):
from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM
self.model_dir = model_dir
super().__init__(model_dir, *args, **kwargs)

self.config = AutoConfig.from_pretrained(model_dir)
self.model = AutoModelForMaskedLM.from_pretrained(
model_dir, config=self.config)

def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
"""
rst = self.model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
token_type_ids=inputs['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}


@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)
class StructBertForMaskedLM(AliceMindBaseForMaskedLM):
# The StructBert for MaskedLM uses the same underlying model structure
# as the base model class.
pass


@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco)
class VecoForMaskedLM(AliceMindBaseForMaskedLM):
# The Veco for MaskedLM uses the same underlying model structure
# as the base model class.
pass

+ 1
- 3
modelscope/pipelines/builder.py View File

@@ -1,10 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from typing import List, Union

from attr import has

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.utils.config import Config, ConfigDict
@@ -37,6 +34,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_unet_person-image-cartoon_compound-models'),
Tasks.ocr_detection: (Pipelines.ocr_detection,
'damo/cv_resnet18_ocr-detection-line-level_damo'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.action_recognition: (Pipelines.action_recognition,
'damo/cv_TAdaConv_action-recognition'),
}


+ 1
- 0
modelscope/pipelines/nlp/__init__.py View File

@@ -1,3 +1,4 @@
from .fill_mask_pipeline import * # noqa F403
from .sentence_similarity_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403


+ 93
- 0
modelscope/pipelines/nlp/fill_mask_pipeline.py View File

@@ -0,0 +1,93 @@
from typing import Dict, Optional, Union

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.nlp.masked_language_model import \
AliceMindBaseForMaskedLM
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['FillMaskPipeline']


@PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask)
class FillMaskPipeline(Pipeline):

def __init__(self,
model: Union[AliceMindBaseForMaskedLM, str],
preprocessor: Optional[FillMaskPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction

Args:
model (AliceMindBaseForMaskedLM): a model instance
preprocessor (FillMaskPreprocessor): a preprocessor instance
"""
fill_mask_model = model if isinstance(
model, AliceMindBaseForMaskedLM) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = FillMaskPreprocessor(
fill_mask_model.model_dir,
first_sequence='sentence',
second_sequence=None)
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.preprocessor = preprocessor
self.tokenizer = preprocessor.tokenizer
self.mask_id = {'veco': 250001, 'sbert': 103}

self.rep_map = {
'sbert': {
'[unused0]': '',
'[PAD]': '',
'[unused1]': '',
r' +': ' ',
'[SEP]': '',
'[unused2]': '',
'[CLS]': '',
'[UNK]': ''
},
'veco': {
r' +': ' ',
'<mask>': '<q>',
'<pad>': '',
'<s>': '',
'</s>': '',
'<unk>': ' '
}
}

def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""
import numpy as np
logits = inputs['logits'].detach().numpy()
input_ids = inputs['input_ids'].detach().numpy()
pred_ids = np.argmax(logits, axis=-1)
model_type = self.model.config.model_type
rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids,
input_ids)

def rep_tokens(string, rep_map):
for k, v in rep_map.items():
string = string.replace(k, v)
return string.strip()

pred_strings = []
for ids in rst_ids: # batch
if self.model.config.vocab_size == 21128: # zh bert
pred_string = self.tokenizer.convert_ids_to_tokens(ids)
pred_string = ''.join(pred_string)
else:
pred_string = self.tokenizer.decode(ids)
pred_string = rep_tokens(pred_string, self.rep_map[model_type])
pred_strings.append(pred_string)

return {'text': pred_strings}

+ 6
- 0
modelscope/pipelines/outputs.py View File

@@ -82,6 +82,12 @@ TASK_OUTPUTS = {
# }
Tasks.text_generation: ['text'],

# fill mask result for single sample
# {
# "text": "this is the text which masks filled by model."
# }
Tasks.fill_mask: ['text'],

# word segmentation result for single sample
# {
# "output": "今天 天气 不错 , 适合 出去 游玩"


+ 57
- 1
modelscope/preprocessors/nlp.py View File

@@ -13,7 +13,8 @@ from .builder import PREPROCESSORS

__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
'FillMaskPreprocessor'
]


@@ -181,6 +182,61 @@ class TextGenerationPreprocessor(Preprocessor):
return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(Fields.nlp)
class FillMaskPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
from sofa.utils.backend import AutoTokenizer
self.model_dir = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
import torch

new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}

max_seq_length = self.sequence_length

text_a = new_data[self.first_sequence]
feature = self.tokenizer(
text_a,
padding='max_length',
truncation=True,
max_length=max_seq_length,
return_token_type_ids=True)

rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])

return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer)
class TokenClassifcationPreprocessor(Preprocessor):


+ 1
- 1
requirements/nlp.txt View File

@@ -1 +1 @@
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl

+ 129
- 0
tests/pipelines/test_fill_mask.py View File

@@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM
from modelscope.pipelines import FillMaskPipeline, pipeline
from modelscope.preprocessors import FillMaskPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FillMaskTest(unittest.TestCase):
model_id_sbert = {
'zh': 'damo/nlp_structbert_fill-mask_chinese-large',
'en': 'damo/nlp_structbert_fill-mask_english-large'
}
model_id_veco = 'damo/nlp_veco_fill-mask-large'

ori_texts = {
'zh':
'段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。'
'你师父差得动你,你师父可差不动我。',
'en':
'Everything in what you call reality is really just a reflection of your '
'consciousness. Your whole universe is just a mirror reflection of your story.'
}

test_inputs = {
'zh':
'段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你'
'师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。',
'en':
'Everything in [MASK] you call reality is really [MASK] a reflection of your '
'[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
# sbert
for language in ['zh', 'en']:
model_dir = snapshot_download(self.model_id_sbert[language])
preprocessor = FillMaskPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = StructBertForMaskedLM(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language]
print(
f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
)

# veco
model_dir = snapshot_download(self.model_id_veco)
preprocessor = FillMaskPreprocessor(
model_dir, first_sequence='sentence', second_sequence=None)
model = VecoForMaskedLM(model_dir)
pipeline1 = FillMaskPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
for language in ['zh', 'en']:
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
print(
f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
# sbert
for language in ['zh', 'en']:
print(self.model_id_sbert[language])
model = Model.from_pretrained(self.model_id_sbert[language])
preprocessor = FillMaskPreprocessor(
model.model_dir,
first_sequence='sentence',
second_sequence=None)
pipeline_ins = pipeline(
task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

# veco
model = Model.from_pretrained(self.model_id_veco)
preprocessor = FillMaskPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
Tasks.fill_mask, model=model, preprocessor=preprocessor)
for language in ['zh', 'en']:
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
# veco
pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_veco)
for language in ['zh', 'en']:
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')

# structBert
language = 'zh'
pipeline_ins = pipeline(
task=Tasks.fill_mask, model=self.model_id_sbert[language])
print(
f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
f'{pipeline_ins(self.test_inputs[language])}\n')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.fill_mask)
language = 'en'
ori_text = self.ori_texts[language]
test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
f'{pipeline_ins(test_input)}\n')


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save