
merge with fill_mask

master
智丞 3 years ago
commit 17e6a3d41d

10 changed files with 343 additions and 3 deletions
  1.  +1    -0    modelscope/models/nlp/__init__.py
  2.  +50   -0    modelscope/models/nlp/masked_language_model.py
  3.  +1    -0    modelscope/pipelines/builder.py
  4.  +1    -0    modelscope/pipelines/nlp/__init__.py
  5.  +93   -0    modelscope/pipelines/nlp/fill_mask_pipeline.py
  6.  +6    -0    modelscope/pipelines/outputs.py
  7.  +56   -1    modelscope/preprocessors/nlp.py
  8.  +1    -1    modelscope/utils/constant.py
  9.  +1    -1    requirements/nlp.txt
  10. +133  -0    tests/pipelines/test_fill_mask.py

+1  -0   modelscope/models/nlp/__init__.py

@@ -1,4 +1,5 @@
 from .bert_for_sequence_classification import * # noqa F403
+from .masked_language_model import * # noqa F403
 from .nli_model import * # noqa F403
 from .palm_for_text_generation import * # noqa F403
 from .sbert_for_sentence_similarity import * # noqa F403


+50  -0   modelscope/models/nlp/masked_language_model.py

@@ -0,0 +1,50 @@
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+
+from ...utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM']
+
+
+class AliceMindBaseForMaskedLM(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM
+        self.model_dir = model_dir
+        super().__init__(model_dir, *args, **kwargs)
+
+        self.config = AutoConfig.from_pretrained(model_dir)
+        self.model = AutoModelForMaskedLM.from_pretrained(
+            model_dir, config=self.config)
+    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
+        """Return the result produced by the model.
+
+        Args:
+            inputs (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: the logits plus the original input_ids
+        """
+        rst = self.model(
+            input_ids=inputs['input_ids'],
+            attention_mask=inputs['attention_mask'],
+            token_type_ids=inputs['token_type_ids'])
+        return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
+
+
+@MODELS.register_module(Tasks.fill_mask, module_name=r'sbert')
+class StructBertForMaskedLM(AliceMindBaseForMaskedLM):
+    # The StructBert for MaskedLM uses the same underlying model structure
+    # as the base model class.
+    pass
+
+
+@MODELS.register_module(Tasks.fill_mask, module_name=r'veco')
+class VecoForMaskedLM(AliceMindBaseForMaskedLM):
+    # The Veco for MaskedLM uses the same underlying model structure
+    # as the base model class.
+    pass
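
Both registered classes are thin wrappers over sofa's AutoModelForMaskedLM. A minimal usage sketch, mirroring the direct-construction path in tests/pipelines/test_fill_mask.py (the local path is a placeholder):

    from modelscope.models.nlp import StructBertForMaskedLM
    from modelscope.preprocessors import FillMaskPreprocessor

    model_dir = '/path/to/nlp_structbert_fill-mask-chinese_large'  # placeholder
    preprocessor = FillMaskPreprocessor(
        model_dir, first_sequence='sentence', second_sequence=None)
    model = StructBertForMaskedLM(model_dir)
    outputs = model.forward(preprocessor('你好,[MASK]界。'))
    # outputs == {'logits': <Tensor>, 'input_ids': <Tensor>}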

+1  -0   modelscope/pipelines/builder.py

@@ -38,6 +38,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                          'damo/cv_unet_person-image-cartoon_compound-models'),
     Tasks.ocr_detection: ('ocr-detection',
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
+    Tasks.fill_mask: ('veco', 'damo/nlp_veco_fill-mask_large')
 }
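
With this entry in place, building the pipeline by task alone falls back to the veco checkpoint, as exercised by test_run_with_default_model in the new test file:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(task=Tasks.fill_mask)  # resolves to damo/nlp_veco_fill-mask_large
    print(pipe('Everything in <mask> you call reality is really just a reflection.'))
    # -> {'text': [...]}, each <mask> replaced with the model's prediction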






+1  -0   modelscope/pipelines/nlp/__init__.py

@@ -1,3 +1,4 @@
+from .fill_mask_pipeline import * # noqa F403
 from .nli_pipeline import * # noqa F403
 from .sentence_similarity_pipeline import * # noqa F403
 from .sentiment_classification_pipeline import * # noqa F403


+93  -0   modelscope/pipelines/nlp/fill_mask_pipeline.py

@@ -0,0 +1,93 @@
+from typing import Dict, Optional, Union
+
+from modelscope.models import Model
+from modelscope.models.nlp.masked_language_model import \
+    AliceMindBaseForMaskedLM
+from modelscope.preprocessors import FillMaskPreprocessor
+from modelscope.utils.constant import Tasks
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+
+__all__ = ['FillMaskPipeline']
+
+
+@PIPELINES.register_module(Tasks.fill_mask, module_name=r'sbert')
+@PIPELINES.register_module(Tasks.fill_mask, module_name=r'veco')
+class FillMaskPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[AliceMindBaseForMaskedLM, str],
+                 preprocessor: Optional[FillMaskPreprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an NLP fill-mask pipeline for prediction.
+
+        Args:
+            model (AliceMindBaseForMaskedLM or str): a model instance or a model id
+            preprocessor (FillMaskPreprocessor): a preprocessor instance
+        """
+        fill_mask_model = model if isinstance(
+            model, AliceMindBaseForMaskedLM) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = FillMaskPreprocessor(
+                fill_mask_model.model_dir,
+                first_sequence='sentence',
+                second_sequence=None)
+        # pass the resolved model instance so that postprocess can rely on
+        # self.model.config even when a model id string was given
+        super().__init__(
+            model=fill_mask_model, preprocessor=preprocessor, **kwargs)
+        self.preprocessor = preprocessor
+        self.tokenizer = preprocessor.tokenizer
+        self.mask_id = {'veco': 250001, 'sbert': 103}

+        self.rep_map = {
+            'sbert': {
+                '[unused0]': '',
+                '[PAD]': '',
+                '[unused1]': '',
+                r' +': ' ',
+                '[SEP]': '',
+                '[unused2]': '',
+                '[CLS]': '',
+                '[UNK]': ''
+            },
+            'veco': {
+                r' +': ' ',
+                '<mask>': '<q>',
+                '<pad>': '',
+                '<s>': '',
+                '</s>': '',
+                '<unk>': ' '
+            }
+        }

+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Tensor]): the model output, containing
+                'logits' and the original 'input_ids'
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        import numpy as np
+        logits = inputs['logits'].detach().numpy()
+        input_ids = inputs['input_ids'].detach().numpy()
+        pred_ids = np.argmax(logits, axis=-1)
+        model_type = self.model.config.model_type
+        # keep the original tokens everywhere except at the mask positions
+        rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids,
+                           input_ids)
+
+        def rep_tokens(string, rep_map):
+            # replacements are literal (str.replace), so the r' +' key
+            # matches the two-character substring ' +', not a regex
+            for k, v in rep_map.items():
+                string = string.replace(k, v)
+            return string.strip()
+
+        pred_strings = []
+        for ids in rst_ids:  # batch
+            if self.model.config.vocab_size == 21128:  # zh bert
+                pred_string = self.tokenizer.convert_ids_to_tokens(ids)
+                pred_string = ''.join(pred_string)
+            else:
+                pred_string = self.tokenizer.decode(ids)
+            pred_string = rep_tokens(pred_string, self.rep_map[model_type])
+            pred_strings.append(pred_string)
+
+        return {'text': pred_strings}
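
The heart of postprocess is the np.where substitution followed by the rep_map cleanup. A toy, self-contained illustration of both steps (the id values are made up except 103, the sbert mask id above; the decoded string is illustrative):

    import numpy as np

    mask_id = 103
    input_ids = np.array([[101, 2023, 103, 102]])  # one masked slot
    pred_ids = np.array([[101, 2023, 3185, 102]])  # argmax over the logits
    rst_ids = np.where(input_ids == mask_id, pred_ids, input_ids)
    print(rst_ids)  # [[ 101 2023 3185  102]], only the masked slot changed

    # cleanup, as rep_tokens does: literal replacements in dict order
    rep_map_veco = {r' +': ' ', '<mask>': '<q>', '<pad>': '', '<s>': '',
                    '</s>': '', '<unk>': ' '}
    decoded = '<s> Everything you call reality.</s><pad><pad>'
    for k, v in rep_map_veco.items():
        decoded = decoded.replace(k, v)
    print(decoded.strip())  # 'Everything you call reality.'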

+6  -0   modelscope/pipelines/outputs.py

@@ -76,6 +76,12 @@ TASK_OUTPUTS = {
     # }
     Tasks.text_generation: ['text'],

+    # fill mask result for single sample
+    # {
+    #     "text": "this is the text whose masks have been filled by the model."
+    # }
+    Tasks.fill_mask: ['text'],
+
     # word segmentation result for single sample
     # {
     #     "output": "今天 天气 不错 , 适合 出去 游玩"


+56  -1   modelscope/preprocessors/nlp.py

@@ -14,7 +14,7 @@ __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor',
     'TokenClassifcationPreprocessor', 'NLIPreprocessor',
-    'SentimentClassificationPreprocessor'
+    'SentimentClassificationPreprocessor', 'FillMaskPreprocessor'
 ]




@@ -311,6 +311,61 @@ class TextGenerationPreprocessor(Preprocessor):
         rst['input_ids'].append(feature['input_ids'])
         rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+        return {k: torch.tensor(v) for k, v in rst.items()}


+@PREPROCESSORS.register_module(Fields.nlp)
+class FillMaskPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data via the vocab.txt from the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from sofa.utils.backend import AutoTokenizer
+        self.model_dir = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_dir, use_fast=False)
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        import torch
+
+        new_data = {self.first_sequence: data}
+        # preprocess the data for the model input
+        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
+
+        max_seq_length = self.sequence_length
+
+        text_a = new_data[self.first_sequence]
+        feature = self.tokenizer(
+            text_a,
+            padding='max_length',
+            truncation=True,
+            max_length=max_seq_length,
+            return_token_type_ids=True)
+
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+
         return {k: torch.tensor(v) for k, v in rst.items()}
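
A sketch of what the preprocessor returns for one sentence; the shapes follow from padding='max_length' with the default sequence_length of 128 plus the batch dimension added by the append calls (model_dir is a placeholder):

    preprocessor = FillMaskPreprocessor(
        model_dir, first_sequence='sentence', second_sequence=None)
    inputs = preprocessor('you are so [MASK].')
    print(inputs['input_ids'].shape)       # torch.Size([1, 128])
    print(inputs['attention_mask'].shape)  # torch.Size([1, 128])
    print(inputs['token_type_ids'].shape)  # torch.Size([1, 128])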




+1  -1   modelscope/utils/constant.py

@@ -47,7 +47,7 @@ class Tasks(object):
     table_question_answering = 'table-question-answering'
     feature_extraction = 'feature-extraction'
     sentence_similarity = 'sentence-similarity'
-    fill_mask = 'fill-mask '
+    fill_mask = 'fill-mask'
     summarization = 'summarization'
     question_answering = 'question-answering'
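
The change is just a trailing space, but task strings key the model and pipeline registries, so the old value could never match a 'fill-mask' lookup; a toy illustration with a plain dict standing in for the registry:

    registry = {'fill-mask ': 'FillMaskPipeline'}  # old constant, trailing space
    print(registry.get('fill-mask'))               # None: the lookup misses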




+1  -1   requirements/nlp.txt

@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl

+133  -0   tests/pipelines/test_fill_mask.py

@@ -0,0 +1,133 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM
+from modelscope.pipelines import FillMaskPipeline, pipeline
+from modelscope.preprocessors import FillMaskPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
+
+
+class FillMaskTest(unittest.TestCase):
+    model_id_sbert = {
+        'zh': 'damo/nlp_structbert_fill-mask-chinese_large',
+        'en': 'damo/nlp_structbert_fill-mask-english_large'
+    }
+    model_id_veco = 'damo/nlp_veco_fill-mask_large'
+
+    ori_texts = {
+        'zh':
+        '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。'
+        '你师父差得动你,你师父可差不动我。',
+        'en':
+        'Everything in what you call reality is really just a reflection of your '
+        'consciousness. Your whole universe is just a mirror reflection of your story.'
+    }
+
+    test_inputs = {
+        'zh':
+        '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你'
+        '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。',
+        'en':
+        'Everything in [MASK] you call reality is really [MASK] a reflection of your '
+        '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
+    }
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_by_direct_model_download(self):
+        # sbert
+        for language in ['zh', 'en']:
+            model_dir = snapshot_download(self.model_id_sbert[language])
+            preprocessor = FillMaskPreprocessor(
+                model_dir, first_sequence='sentence', second_sequence=None)
+            model = StructBertForMaskedLM(model_dir)
+            pipeline1 = FillMaskPipeline(model, preprocessor)
+            pipeline2 = pipeline(
+                Tasks.fill_mask, model=model, preprocessor=preprocessor)
+            ori_text = self.ori_texts[language]
+            test_input = self.test_inputs[language]
+            print(
+                f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
+                f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
+            )
+
+        # veco
+        model_dir = snapshot_download(self.model_id_veco)
+        preprocessor = FillMaskPreprocessor(
+            model_dir, first_sequence='sentence', second_sequence=None)
+        model = VecoForMaskedLM(model_dir)
+        pipeline1 = FillMaskPipeline(model, preprocessor)
+        pipeline2 = pipeline(
+            Tasks.fill_mask, model=model, preprocessor=preprocessor)
+        for language in ['zh', 'en']:
+            ori_text = self.ori_texts[language]
+            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
+            print(
+                f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: '
+                f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n'
+            )

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        # sbert
+        for language in ['zh', 'en']:
+            print(self.model_id_sbert[language])
+            model = Model.from_pretrained(self.model_id_sbert[language])
+            preprocessor = FillMaskPreprocessor(
+                model.model_dir,
+                first_sequence='sentence',
+                second_sequence=None)
+            pipeline_ins = pipeline(
+                task=Tasks.fill_mask, model=model, preprocessor=preprocessor)
+            print(
+                f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
+                f'{pipeline_ins(self.test_inputs[language])}\n')
+
+        # veco
+        model = Model.from_pretrained(self.model_id_veco)
+        preprocessor = FillMaskPreprocessor(
+            model.model_dir, first_sequence='sentence', second_sequence=None)
+        pipeline_ins = pipeline(
+            Tasks.fill_mask, model=model, preprocessor=preprocessor)
+        for language in ['zh', 'en']:
+            ori_text = self.ori_texts[language]
+            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
+            print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
+                  f'{pipeline_ins(test_input)}\n')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        # veco
+        pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_veco)
+        for language in ['zh', 'en']:
+            ori_text = self.ori_texts[language]
+            test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
+            print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
+                  f'{pipeline_ins(test_input)}\n')
+
+        # structBert
+        language = 'zh'
+        pipeline_ins = pipeline(
+            task=Tasks.fill_mask, model=self.model_id_sbert[language])
+        print(
+            f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: '
+            f'{pipeline_ins(self.test_inputs[language])}\n')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.fill_mask)
+        language = 'en'
+        ori_text = self.ori_texts[language]
+        test_input = self.test_inputs[language].replace('[MASK]', '<mask>')
+        print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
+              f'{pipeline_ins(test_input)}\n')
+
+
+if __name__ == '__main__':
+    unittest.main()
