Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9069107master
| @@ -15,6 +15,7 @@ class Models(object): | |||||
| bert = 'bert' | bert = 'bert' | ||||
| palm = 'palm-v2' | palm = 'palm-v2' | ||||
| structbert = 'structbert' | structbert = 'structbert' | ||||
| veco = 'veco' | |||||
| # audio models | # audio models | ||||
| sambert_hifi_16k = 'sambert-hifi-16k' | sambert_hifi_16k = 'sambert-hifi-16k' | ||||
| @@ -46,6 +47,7 @@ class Pipelines(object): | |||||
| word_segmentation = 'word-segmentation' | word_segmentation = 'word-segmentation' | ||||
| text_generation = 'text-generation' | text_generation = 'text-generation' | ||||
| sentiment_analysis = 'sentiment-analysis' | sentiment_analysis = 'sentiment-analysis' | ||||
| fill_mask = 'fill-mask' | |||||
| # audio tasks | # audio tasks | ||||
| sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | ||||
| @@ -1,4 +1,5 @@ | |||||
| from .bert_for_sequence_classification import * # noqa F403 | from .bert_for_sequence_classification import * # noqa F403 | ||||
| from .masked_language_model import * # noqa F403 | |||||
| from .palm_for_text_generation import * # noqa F403 | from .palm_for_text_generation import * # noqa F403 | ||||
| from .sbert_for_sentence_similarity import * # noqa F403 | from .sbert_for_sentence_similarity import * # noqa F403 | ||||
| from .sbert_for_token_classification import * # noqa F403 | from .sbert_for_token_classification import * # noqa F403 | ||||
| @@ -0,0 +1,51 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.utils.constant import Tasks | |||||
| from ..base import Model, Tensor | |||||
| from ..builder import MODELS | |||||
| __all__ = ['StructBertForMaskedLM', 'VecoForMaskedLM'] | |||||
| class AliceMindBaseForMaskedLM(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM | |||||
| self.model_dir = model_dir | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| self.config = AutoConfig.from_pretrained(model_dir) | |||||
| self.model = AutoModelForMaskedLM.from_pretrained( | |||||
| model_dir, config=self.config) | |||||
| def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Any]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, np.ndarray]: results | |||||
| """ | |||||
| rst = self.model( | |||||
| input_ids=inputs['input_ids'], | |||||
| attention_mask=inputs['attention_mask'], | |||||
| token_type_ids=inputs['token_type_ids']) | |||||
| return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} | |||||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) | |||||
| class StructBertForMaskedLM(AliceMindBaseForMaskedLM): | |||||
| # The StructBert for MaskedLM uses the same underlying model structure | |||||
| # as the base model class. | |||||
| pass | |||||
| @MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) | |||||
| class VecoForMaskedLM(AliceMindBaseForMaskedLM): | |||||
| # The Veco for MaskedLM uses the same underlying model structure | |||||
| # as the base model class. | |||||
| pass | |||||
| @@ -1,10 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | |||||
| from typing import List, Union | from typing import List, Union | ||||
| from attr import has | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| @@ -37,6 +34,7 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | 'damo/cv_unet_person-image-cartoon_compound-models'), | ||||
| Tasks.ocr_detection: (Pipelines.ocr_detection, | Tasks.ocr_detection: (Pipelines.ocr_detection, | ||||
| 'damo/cv_resnet18_ocr-detection-line-level_damo'), | 'damo/cv_resnet18_ocr-detection-line-level_damo'), | ||||
| Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), | |||||
| Tasks.action_recognition: (Pipelines.action_recognition, | Tasks.action_recognition: (Pipelines.action_recognition, | ||||
| 'damo/cv_TAdaConv_action-recognition'), | 'damo/cv_TAdaConv_action-recognition'), | ||||
| } | } | ||||
| @@ -1,3 +1,4 @@ | |||||
| from .fill_mask_pipeline import * # noqa F403 | |||||
| from .sentence_similarity_pipeline import * # noqa F403 | from .sentence_similarity_pipeline import * # noqa F403 | ||||
| from .sequence_classification_pipeline import * # noqa F403 | from .sequence_classification_pipeline import * # noqa F403 | ||||
| from .text_generation_pipeline import * # noqa F403 | from .text_generation_pipeline import * # noqa F403 | ||||
| @@ -0,0 +1,93 @@ | |||||
| from typing import Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp.masked_language_model import \ | |||||
| AliceMindBaseForMaskedLM | |||||
| from modelscope.preprocessors import FillMaskPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from ..base import Pipeline, Tensor | |||||
| from ..builder import PIPELINES | |||||
| __all__ = ['FillMaskPipeline'] | |||||
| @PIPELINES.register_module(Tasks.fill_mask, module_name=Pipelines.fill_mask) | |||||
| class FillMaskPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[AliceMindBaseForMaskedLM, str], | |||||
| preprocessor: Optional[FillMaskPreprocessor] = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp fill mask pipeline for prediction | |||||
| Args: | |||||
| model (AliceMindBaseForMaskedLM): a model instance | |||||
| preprocessor (FillMaskPreprocessor): a preprocessor instance | |||||
| """ | |||||
| fill_mask_model = model if isinstance( | |||||
| model, AliceMindBaseForMaskedLM) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = FillMaskPreprocessor( | |||||
| fill_mask_model.model_dir, | |||||
| first_sequence='sentence', | |||||
| second_sequence=None) | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| self.preprocessor = preprocessor | |||||
| self.tokenizer = preprocessor.tokenizer | |||||
| self.mask_id = {'veco': 250001, 'sbert': 103} | |||||
| self.rep_map = { | |||||
| 'sbert': { | |||||
| '[unused0]': '', | |||||
| '[PAD]': '', | |||||
| '[unused1]': '', | |||||
| r' +': ' ', | |||||
| '[SEP]': '', | |||||
| '[unused2]': '', | |||||
| '[CLS]': '', | |||||
| '[UNK]': '' | |||||
| }, | |||||
| 'veco': { | |||||
| r' +': ' ', | |||||
| '<mask>': '<q>', | |||||
| '<pad>': '', | |||||
| '<s>': '', | |||||
| '</s>': '', | |||||
| '<unk>': ' ' | |||||
| } | |||||
| } | |||||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| """process the prediction results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): _description_ | |||||
| Returns: | |||||
| Dict[str, str]: the prediction results | |||||
| """ | |||||
| import numpy as np | |||||
| logits = inputs['logits'].detach().numpy() | |||||
| input_ids = inputs['input_ids'].detach().numpy() | |||||
| pred_ids = np.argmax(logits, axis=-1) | |||||
| model_type = self.model.config.model_type | |||||
| rst_ids = np.where(input_ids == self.mask_id[model_type], pred_ids, | |||||
| input_ids) | |||||
| def rep_tokens(string, rep_map): | |||||
| for k, v in rep_map.items(): | |||||
| string = string.replace(k, v) | |||||
| return string.strip() | |||||
| pred_strings = [] | |||||
| for ids in rst_ids: # batch | |||||
| if self.model.config.vocab_size == 21128: # zh bert | |||||
| pred_string = self.tokenizer.convert_ids_to_tokens(ids) | |||||
| pred_string = ''.join(pred_string) | |||||
| else: | |||||
| pred_string = self.tokenizer.decode(ids) | |||||
| pred_string = rep_tokens(pred_string, self.rep_map[model_type]) | |||||
| pred_strings.append(pred_string) | |||||
| return {'text': pred_strings} | |||||
| @@ -82,6 +82,12 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.text_generation: ['text'], | Tasks.text_generation: ['text'], | ||||
| # fill mask result for single sample | |||||
| # { | |||||
| # "text": "this is the text which masks filled by model." | |||||
| # } | |||||
| Tasks.fill_mask: ['text'], | |||||
| # word segmentation result for single sample | # word segmentation result for single sample | ||||
| # { | # { | ||||
| # "output": "今天 天气 不错 , 适合 出去 游玩" | # "output": "今天 天气 不错 , 适合 出去 游玩" | ||||
| @@ -13,7 +13,8 @@ from .builder import PREPROCESSORS | |||||
| __all__ = [ | __all__ = [ | ||||
| 'Tokenize', 'SequenceClassificationPreprocessor', | 'Tokenize', 'SequenceClassificationPreprocessor', | ||||
| 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor' | |||||
| 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | |||||
| 'FillMaskPreprocessor' | |||||
| ] | ] | ||||
| @@ -181,6 +182,61 @@ class TextGenerationPreprocessor(Preprocessor): | |||||
| return {k: torch.tensor(v) for k, v in rst.items()} | return {k: torch.tensor(v) for k, v in rst.items()} | ||||
| @PREPROCESSORS.register_module(Fields.nlp) | |||||
| class FillMaskPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| from sofa.utils.backend import AutoTokenizer | |||||
| self.model_dir = model_dir | |||||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||||
| 'first_sequence') | |||||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||||
| self.tokenizer = AutoTokenizer.from_pretrained( | |||||
| model_dir, use_fast=False) | |||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| import torch | |||||
| new_data = {self.first_sequence: data} | |||||
| # preprocess the data for the model input | |||||
| rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []} | |||||
| max_seq_length = self.sequence_length | |||||
| text_a = new_data[self.first_sequence] | |||||
| feature = self.tokenizer( | |||||
| text_a, | |||||
| padding='max_length', | |||||
| truncation=True, | |||||
| max_length=max_seq_length, | |||||
| return_token_type_ids=True) | |||||
| rst['input_ids'].append(feature['input_ids']) | |||||
| rst['attention_mask'].append(feature['attention_mask']) | |||||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return {k: torch.tensor(v) for k, v in rst.items()} | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | ||||
| class TokenClassifcationPreprocessor(Preprocessor): | class TokenClassifcationPreprocessor(Preprocessor): | ||||
| @@ -1 +1 @@ | |||||
| https://alinlp.alibaba-inc.com/pypi/sofa-1.0.2-py3-none-any.whl | |||||
| https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl | |||||
| @@ -0,0 +1,129 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models import Model | |||||
| from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM | |||||
| from modelscope.pipelines import FillMaskPipeline, pipeline | |||||
| from modelscope.preprocessors import FillMaskPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class FillMaskTest(unittest.TestCase): | |||||
| model_id_sbert = { | |||||
| 'zh': 'damo/nlp_structbert_fill-mask_chinese-large', | |||||
| 'en': 'damo/nlp_structbert_fill-mask_english-large' | |||||
| } | |||||
| model_id_veco = 'damo/nlp_veco_fill-mask-large' | |||||
| ori_texts = { | |||||
| 'zh': | |||||
| '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。' | |||||
| '你师父差得动你,你师父可差不动我。', | |||||
| 'en': | |||||
| 'Everything in what you call reality is really just a reflection of your ' | |||||
| 'consciousness. Your whole universe is just a mirror reflection of your story.' | |||||
| } | |||||
| test_inputs = { | |||||
| 'zh': | |||||
| '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你' | |||||
| '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。', | |||||
| 'en': | |||||
| 'Everything in [MASK] you call reality is really [MASK] a reflection of your ' | |||||
| '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.' | |||||
| } | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| # sbert | |||||
| for language in ['zh', 'en']: | |||||
| model_dir = snapshot_download(self.model_id_sbert[language]) | |||||
| preprocessor = FillMaskPreprocessor( | |||||
| model_dir, first_sequence='sentence', second_sequence=None) | |||||
| model = StructBertForMaskedLM(model_dir) | |||||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||||
| pipeline2 = pipeline( | |||||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||||
| ori_text = self.ori_texts[language] | |||||
| test_input = self.test_inputs[language] | |||||
| print( | |||||
| f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' | |||||
| f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n' | |||||
| ) | |||||
| # veco | |||||
| model_dir = snapshot_download(self.model_id_veco) | |||||
| preprocessor = FillMaskPreprocessor( | |||||
| model_dir, first_sequence='sentence', second_sequence=None) | |||||
| model = VecoForMaskedLM(model_dir) | |||||
| pipeline1 = FillMaskPipeline(model, preprocessor) | |||||
| pipeline2 = pipeline( | |||||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||||
| for language in ['zh', 'en']: | |||||
| ori_text = self.ori_texts[language] | |||||
| test_input = self.test_inputs[language].replace('[MASK]', '<mask>') | |||||
| print( | |||||
| f'\nori_text: {ori_text}\ninput: {test_input}\npipeline1: ' | |||||
| f'{pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}\n' | |||||
| ) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| # sbert | |||||
| for language in ['zh', 'en']: | |||||
| print(self.model_id_sbert[language]) | |||||
| model = Model.from_pretrained(self.model_id_sbert[language]) | |||||
| preprocessor = FillMaskPreprocessor( | |||||
| model.model_dir, | |||||
| first_sequence='sentence', | |||||
| second_sequence=None) | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||||
| print( | |||||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||||
| # veco | |||||
| model = Model.from_pretrained(self.model_id_veco) | |||||
| preprocessor = FillMaskPreprocessor( | |||||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||||
| pipeline_ins = pipeline( | |||||
| Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||||
| for language in ['zh', 'en']: | |||||
| ori_text = self.ori_texts[language] | |||||
| test_input = self.test_inputs[language].replace('[MASK]', '<mask>') | |||||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||||
| f'{pipeline_ins(test_input)}\n') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_model_name(self): | |||||
| # veco | |||||
| pipeline_ins = pipeline(task=Tasks.fill_mask, model=self.model_id_veco) | |||||
| for language in ['zh', 'en']: | |||||
| ori_text = self.ori_texts[language] | |||||
| test_input = self.test_inputs[language].replace('[MASK]', '<mask>') | |||||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||||
| f'{pipeline_ins(test_input)}\n') | |||||
| # structBert | |||||
| language = 'zh' | |||||
| pipeline_ins = pipeline( | |||||
| task=Tasks.fill_mask, model=self.model_id_sbert[language]) | |||||
| print( | |||||
| f'\nori_text: {self.ori_texts[language]}\ninput: {self.test_inputs[language]}\npipeline: ' | |||||
| f'{pipeline_ins(self.test_inputs[language])}\n') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_with_default_model(self): | |||||
| pipeline_ins = pipeline(task=Tasks.fill_mask) | |||||
| language = 'en' | |||||
| ori_text = self.ori_texts[language] | |||||
| test_input = self.test_inputs[language].replace('[MASK]', '<mask>') | |||||
| print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: ' | |||||
| f'{pipeline_ins(test_input)}\n') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||