Add the new task sentence_similarity; the model is the SOFA version of StructBERT
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9016402
* sbert-sentence-similarity
* [to #42322933] pep8
* merge with master for file dirs update
* add test cases
* pre-commit lint check
* remove useless file
* download models again~
* skip time consuming test case
* update for pr reviews
* merge with master
* add test level
* reset test level to env level
* [to #42322933] init
* [to #42322933] init
* adding purge logic in test
* merge with head
* change test level
* using sequence classification processor for similarity
@@ -2,4 +2,4 @@
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
@@ -1,2 +1,3 @@
+from .sentence_similarity_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
@@ -0,0 +1,88 @@
+import os
+from typing import Any, Dict
+
+import json
+import numpy as np
+import torch
+from sofa import SbertModel
+from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
+from torch import nn
+
+from modelscope.utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['SbertForSentenceSimilarity']
+
+
+class SbertTextClassifier(SbertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.encoder = SbertModel(config, add_pooling_layer=True)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, input_ids=None, token_type_ids=None):
+        outputs = self.encoder(
+            input_ids,
+            token_type_ids=token_type_ids,
+            return_dict=None,
+        )
+        # classify on the pooled [CLS] representation
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+
+@MODELS.register_module(
+    Tasks.sentence_similarity,
+    module_name=r'sbert-base-chinese-sentence-similarity')
+class SbertForSentenceSimilarity(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the sentence similarity model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.model = SbertTextClassifier.from_pretrained(
+            model_dir, num_labels=2)
+        self.model.eval()
+        self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
+        with open(self.label_path) as f:
+            self.label_mapping = json.load(f)
+        self.id2label = {idx: name for name, idx in self.label_mapping.items()}
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]),  # label 0-negative 1-positive
+                        'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # raw values before softmax
+                    }
+        """
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        with torch.no_grad():
+            logits = self.model(input_ids, token_type_ids)
+            probs = logits.softmax(-1).numpy()
+            pred = logits.argmax(-1).numpy()
+            logits = logits.numpy()
+        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
+        return res
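For reviewers who want to poke at the model outside a pipeline, here is a minimal sketch (not part of the diff) that mirrors what test_run below does; model_dir is a placeholder path to a downloaded checkpoint:

    from modelscope.models.nlp import SbertForSentenceSimilarity
    from modelscope.preprocessors import SequenceClassificationPreprocessor

    # placeholder path; in the tests it comes from maas_hub.snapshot_download
    model_dir = '/path/to/nlp_structbert_sentence-similarity_chinese-base'
    preprocessor = SequenceClassificationPreprocessor(
        model_dir,
        first_sequence='first_sequence',
        second_sequence='second_sequence')
    model = SbertForSentenceSimilarity(model_dir)

    inputs = preprocessor(('今天气温比昨天高么?', '今天湿度比昨天高么?'))
    outputs = model.forward(inputs)
    print(outputs)  # {'predictions': ..., 'probabilities': ..., 'logits': ...}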
@@ -15,7 +15,7 @@ from modelscope.utils.logger import get_logger
 from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
 InputModel = Union[str, Model]

 output_keys = [
@@ -13,6 +13,9 @@ PIPELINES = Registry('pipelines')
 DEFAULT_MODEL_FOR_PIPELINE = {
     # TaskName: (pipeline_module_name, model_repo)
+    Tasks.sentence_similarity:
+    ('sbert-base-chinese-sentence-similarity',
+     'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
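With that default mapping in place, the task name alone is enough to build a pipeline, as test_run_with_default_model exercises; a short sketch (printed values are illustrative):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # resolves to damo/nlp_structbert_sentence-similarity_chinese-base
    # via the DEFAULT_MODEL_FOR_PIPELINE entry above
    pipe = pipeline(task=Tasks.sentence_similarity)
    print(pipe(input=('今天气温比昨天高么?', '今天湿度比昨天高么?')))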
@@ -1,2 +1,3 @@
+from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
@@ -0,0 +1,65 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import SbertForSentenceSimilarity
+from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['SentenceSimilarityPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.sentence_similarity,
+    module_name=r'sbert-base-chinese-sentence-similarity')
+class SentenceSimilarityPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForSentenceSimilarity, str],
+                 preprocessor: SequenceClassificationPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create an NLP sentence similarity pipeline for prediction
+
+        Args:
+            model (SbertForSentenceSimilarity): a model instance
+            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \
+            'model must be a single str or SbertForSentenceSimilarity'
+        sc_model = model if isinstance(
+            model,
+            SbertForSentenceSimilarity) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = SequenceClassificationPreprocessor(
+                sc_model.model_dir,
+                first_sequence='first_sequence',
+                second_sequence='second_sequence')
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+        assert hasattr(self.model, 'id2label'), \
+            'id2label map should be initialized in the init function.'
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the model output, including 'probabilities'
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        probs = inputs['probabilities'][0]
+        num_classes = probs.shape[0]
+        # partitioning over all num_classes entries and then sorting yields
+        # the class ids in descending probability order
+        top_indices = np.argpartition(probs, -num_classes)[-num_classes:]
+        cls_ids = top_indices[np.argsort(-probs[top_indices], axis=-1)]
+        probs = probs[cls_ids].tolist()
+        cls_names = [self.model.id2label[cid] for cid in cls_ids]
+        b = 0  # return only the top-ranked class
+        return {'scores': probs[b], 'labels': cls_names[b]}
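A note on postprocess: since argpartition is asked for all num_classes entries, it degenerates to a full descending sort. An equivalent standalone sketch, reusing the illustrative numbers from the forward() docstring and hypothetical label names (the real ones come from label_mapping.json):

    import numpy as np

    probs = np.array([0.11491239, 0.8850876], dtype=np.float32)
    id2label = {0: 'dissimilar', 1: 'similar'}  # hypothetical names

    cls_ids = np.argsort(-probs)  # descending order: array([1, 0])
    print({'scores': float(probs[cls_ids][0]), 'labels': id2label[cls_ids[0]]})
    # {'scores': 0.8850876..., 'labels': 'similar'}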
@@ -5,4 +5,3 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
@@ -10,7 +10,10 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS

-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = [
+    'Tokenize', 'SequenceClassificationPreprocessor',
+    'TextGenerationPreprocessor'
+]

 @PREPROCESSORS.register_module(Fields.nlp)
@@ -28,7 +31,7 @@ class Tokenize(Preprocessor):

 @PREPROCESSORS.register_module(
-    Fields.nlp, module_name=r'bert-sentiment-analysis')
+    Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -48,21 +51,42 @@ class SequenceClassificationPreprocessor(Preprocessor):
         self.sequence_length = kwargs.pop('sequence_length', 128)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
         print(f'this is the tokenizer {self.tokenizer}')

-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
         """process the raw input data

         Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
+            data (str or tuple): a single sentence, or a
+                (sentence1, sentence2) pair
+                Example:
+                    'you are so handsome.'
+                or
+                    ('you are so handsome.', 'you are so beautiful.')

         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        new_data = {self.first_sequence: data}
+        # normalize a single sentence into a (sentence, None) pair
+        if not isinstance(data, tuple):
+            data = (data, None)
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }

         # preprocess the data for the model input
         rst = {
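To make the dual input handling concrete, both call forms in one sketch (model_dir is a placeholder pointing at a directory containing the tokenizer files):

    from modelscope.preprocessors import SequenceClassificationPreprocessor

    model_dir = '/path/to/model'  # placeholder
    preprocessor = SequenceClassificationPreprocessor(
        model_dir,
        first_sequence='first_sequence',
        second_sequence='second_sequence')

    single = preprocessor('you are so handsome.')  # second_sequence stays None
    pair = preprocessor(('今天气温比昨天高么?', '今天湿度比昨天高么?'))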
@@ -31,6 +31,7 @@ class Tasks(object):

     # nlp tasks
     sentiment_analysis = 'sentiment-analysis'
+    sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
     zero_shot = 'zero-shot'
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import shutil
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForSentenceSimilarity
+from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
+from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.test_utils import test_level
+
+
+class SentenceSimilarityTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+    sentence1 = '今天气温比昨天高么?'
+    sentence2 = '今天湿度比昨天高么?'
+
+    def setUp(self) -> None:
+        # switch to False if downloading every time is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                get_model_cache_dir(self.model_id), ignore_errors=True)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = SequenceClassificationPreprocessor(cache_path)
+        model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
+        pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
+        print('test1')
+        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+              f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}')
+        print()
+        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+              f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_similarity,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.sentence_similarity, model=self.model_id)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.sentence_similarity)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+
+if __name__ == '__main__':
+    unittest.main()
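To run just this suite locally, something like the sketch below should work; note that the TEST_LEVEL environment variable name and the tests/ discovery path are assumptions inferred from the test_level() helper, not confirmed by this diff:

    import os
    import unittest

    # assumption: test_utils.test_level() reads the TEST_LEVEL env var
    os.environ.setdefault('TEST_LEVEL', '0')
    suite = unittest.defaultTestLoader.discover(
        'tests', pattern='test_sentence_similarity*.py')
    unittest.TextTestRunner(verbosity=2).run(suite)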