| @@ -1,3 +1,4 @@ | |||
| from .sentence_similarity_model import * # noqa F403 | |||
| from .sequence_classification_model import * # noqa F403 | |||
| from .text_generation_model import * # noqa F403 | |||
| from .masked_language_model import * # noqa F403 | |||
| @@ -0,0 +1,43 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import numpy as np | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| from ...utils.constant import Tasks | |||
| __all__ = ['MaskedLanguageModel'] | |||
| @MODELS.register_module(Tasks.fill_mask, module_name=r'sbert') | |||
| class MaskedLanguageModel(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| from sofa.utils.backend import AutoConfig, AutoModelForMaskedLM | |||
| self.model_dir = model_dir | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.config = AutoConfig.from_pretrained(model_dir) | |||
| self.model = AutoModelForMaskedLM.from_pretrained(model_dir, config=self.config) | |||
| def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| rst = self.model( | |||
| input_ids=inputs["input_ids"], | |||
| attention_mask=inputs['attention_mask'], | |||
| token_type_ids=inputs["token_type_ids"] | |||
| ) | |||
| return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} | |||
| @@ -24,6 +24,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.image_generation: | |||
| ('person-image-cartoon', | |||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | |||
| Tasks.fill_mask: | |||
| ('sbert') | |||
| } | |||
| @@ -1,3 +1,4 @@ | |||
| from .sentence_similarity_pipeline import * # noqa F403 | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .fill_mask_pipeline import * # noqa F403 | |||
| @@ -0,0 +1,57 @@ | |||
| from typing import Dict | |||
| from modelscope.models.nlp import MaskedLanguageModel | |||
| from modelscope.preprocessors import FillMaskPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| __all__ = ['FillMaskPipeline'] | |||
| @PIPELINES.register_module(Tasks.fill_mask, module_name=r'sbert') | |||
| class FillMaskPipeline(Pipeline): | |||
| def __init__(self, model: MaskedLanguageModel, | |||
| preprocessor: FillMaskPreprocessor, **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.preprocessor = preprocessor | |||
| self.tokenizer = preprocessor.tokenizer | |||
| self.mask_id = {'veco': 250001, 'sbert': 103} | |||
| def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| import numpy as np | |||
| logits = inputs["logits"].detach().numpy() | |||
| input_ids = inputs["input_ids"].detach().numpy() | |||
| pred_ids = np.argmax(logits, axis=-1) | |||
| rst_ids = np.where(input_ids==self.mask_id[self.model.config.model_type], pred_ids, input_ids) | |||
| pred_strings = [] | |||
| for ids in rst_ids: | |||
| if self.model.config.model_type == 'veco': | |||
| pred_string = self.tokenizer.decode(ids).split('</s>')[0].replace("<s>", "").replace("</s>", "").replace("<pad>", "") | |||
| elif self.model.config.vocab_size == 21128: # zh bert | |||
| pred_string = self.tokenizer.convert_ids_to_tokens(ids) | |||
| pred_string = ''.join(pred_string).replace('##','') | |||
| pred_string = pred_string.split('[SEP]')[0].replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '') | |||
| else: ## en bert | |||
| pred_string = self.tokenizer.decode(ids) | |||
| pred_string = pred_string.split('[SEP]')[0].replace('[CLS]', '').replace('[SEP]', '').replace('[UNK]', '') | |||
| pred_strings.append(pred_string) | |||
| return {'pred_string': pred_strings} | |||
| @@ -12,7 +12,8 @@ from .builder import PREPROCESSORS | |||
| __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor' | |||
| 'TextGenerationPreprocessor', | |||
| 'FillMaskPreprocessor' | |||
| ] | |||
| @@ -166,8 +167,67 @@ class TextGenerationPreprocessor(Preprocessor): | |||
| truncation=True, | |||
| max_length=max_seq_length) | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return {k: torch.tensor(v) for k, v in rst.items()} | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=r'sbert') | |||
| class FillMaskPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa.utils.backend import AutoTokenizer | |||
| self.model_dir = model_dir | |||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||
| 'first_sequence') | |||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||
| self.tokenizer = AutoTokenizer.from_pretrained(model_dir) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| import torch | |||
| new_data = {self.first_sequence: data} | |||
| # preprocess the data for the model input | |||
| rst = { | |||
| 'input_ids': [], | |||
| 'attention_mask': [], | |||
| 'token_type_ids': [] | |||
| } | |||
| max_seq_length = self.sequence_length | |||
| text_a = new_data[self.first_sequence] | |||
| feature = self.tokenizer( | |||
| text_a, | |||
| padding='max_length', | |||
| truncation=True, | |||
| max_length=max_seq_length, | |||
| return_token_type_ids=True) | |||
| rst['input_ids'].append(feature['input_ids']) | |||
| rst['attention_mask'].append(feature['attention_mask']) | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return {k: torch.tensor(v) for k, v in rst.items()} | |||
| @@ -42,7 +42,7 @@ class Tasks(object): | |||
| table_question_answering = 'table-question-answering' | |||
| feature_extraction = 'feature-extraction' | |||
| sentence_similarity = 'sentence-similarity' | |||
| fill_mask = 'fill-mask ' | |||
| fill_mask = 'fill-mask' | |||
| summarization = 'summarization' | |||
| question_answering = 'question-answering' | |||
| @@ -0,0 +1,87 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import unittest | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from modelscope.models.nlp import MaskedLanguageModel | |||
| from modelscope.pipelines import FillMaskPipeline, pipeline | |||
| from modelscope.preprocessors import FillMaskPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.models import Model | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.test_utils import test_level | |||
| class FillMaskTest(unittest.TestCase): | |||
| model_id_sbert = {'zh': 'damo/nlp_structbert_fill-mask-chinese_large', | |||
| 'en': 'damo/nlp_structbert_fill-mask-english_large'} | |||
| model_id_veco = 'damo/nlp_veco_fill-mask_large' | |||
| ori_texts = {"zh": "段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。你师父差得动你,你师父可差不动我。", | |||
| "en": "Everything in what you call reality is really just a reflection of your consciousness. Your whole universe is just a mirror reflection of your story."} | |||
| test_inputs = {"zh": "段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。", | |||
| "en": "Everything in [MASK] you call reality is really [MASK] a reflection of your [MASK]. Your whole universe is just a mirror [MASK] of your story."} | |||
| #def test_run(self): | |||
| # # sbert | |||
| # for language in ["zh", "en"]: | |||
| # model_dir = snapshot_download(self.model_id_sbert[language]) | |||
| # preprocessor = FillMaskPreprocessor( | |||
| # model_dir, first_sequence='sentence', second_sequence=None) | |||
| # model = MaskedLanguageModel(model_dir) | |||
| # pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| # pipeline2 = pipeline( | |||
| # Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| # ori_text = self.ori_texts[language] | |||
| # test_input = self.test_inputs[language] | |||
| # print( | |||
| # f'ori_text: {ori_text}\ninput: {test_input}\npipeline1: {pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}' | |||
| # ) | |||
| ## veco | |||
| #model_dir = snapshot_download(self.model_id_veco) | |||
| #preprocessor = FillMaskPreprocessor( | |||
| # model_dir, first_sequence='sentence', second_sequence=None) | |||
| #model = MaskedLanguageModel(model_dir) | |||
| #pipeline1 = FillMaskPipeline(model, preprocessor) | |||
| #pipeline2 = pipeline( | |||
| # Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| #for language in ["zh", "en"]: | |||
| # ori_text = self.ori_texts[language] | |||
| # test_input = self.test_inputs["zh"].replace("[MASK]", "<mask>") | |||
| # print( | |||
| # f'ori_text: {ori_text}\ninput: {test_input}\npipeline1: {pipeline1(test_input)}\npipeline2: {pipeline2(test_input)}' | |||
| def test_run_with_model_from_modelhub(self): | |||
| for language in ["zh"]: | |||
| print(self.model_id_sbert[language]) | |||
| model = Model.from_pretrained(self.model_id_sbert[language]) | |||
| print("model", model.model_dir) | |||
| preprocessor = FillMaskPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.fill_mask, model=model, preprocessor=preprocessor) | |||
| print(pipeline_ins(self_test_inputs[language])) | |||
| #def test_run_with_model_name(self): | |||
| ## veco | |||
| #pipeline_ins = pipeline( | |||
| # task=Tasks.fill_mask, model=self.model_id_veco) | |||
| #for language in ["zh", "en"]: | |||
| # input_ = self.test_inputs[language].replace("[MASK]", "<mask>") | |||
| # print(pipeline_ins(input_)) | |||
| ## structBert | |||
| #for language in ["zh"]: | |||
| # pipeline_ins = pipeline( | |||
| # task=Tasks.fill_mask, model=self.model_id_sbert[language]) | |||
| # print(pipeline_ins(self_test_inputs[language])) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||