@@ -5,4 +5,5 @@ from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
 from .multi_model import OfaForImageCaptioning
-from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
+from .nlp import (BertForSequenceClassification, SbertForNLI,
+                  SbertForSentenceSimilarity)
@@ -1,4 +1,5 @@
 from .bert_for_sequence_classification import *  # noqa F403
+from .nli_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
@@ -0,0 +1,84 @@
+import os
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from sofa import SbertConfig, SbertModel
+from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
+from torch import nn
+from transformers.activations import ACT2FN, get_activation
+from transformers.models.bert.modeling_bert import SequenceClassifierOutput
+
+from modelscope.utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['SbertForNLI']
+
+
+class SbertTextClassifier(SbertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.encoder = SbertModel(config, add_pooling_layer=True)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, input_ids=None, token_type_ids=None):
+        outputs = self.encoder(
+            input_ids,
+            token_type_ids=token_type_ids,
+            return_dict=None,
+        )
+        # classify on the pooled [CLS] representation
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+
+@MODELS.register_module(
+    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
+class SbertForNLI(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the NLI text classification model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        self.model = SbertTextClassifier.from_pretrained(
+            model_dir, num_labels=3)
+        self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]),  # id of the best-scoring label
+                        'probabilities': array([[0.048, 0.879, 0.073]], dtype=float32),
+                        'logits': array([[-1.2, 1.7, -0.8]], dtype=float32)  # raw scores
+                    }
+        """
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        with torch.no_grad():
+            logits = self.model(input_ids, token_type_ids)
+        probs = logits.softmax(-1).numpy()
+        pred = logits.argmax(-1).numpy()
+        logits = logits.numpy()
+        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
+        return res
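
For quick reference, a minimal sketch of calling the new model class directly with preprocessor-style features (the checkpoint path and token ids below are placeholders, not real values):

    # drive SbertForNLI.forward with a dict shaped like the NLIPreprocessor
    # output: lists of token-id lists for a batch of one sentence pair
    from modelscope.models.nlp import SbertForNLI

    model = SbertForNLI('/path/to/nlp_structbert_nli_chinese-base')  # placeholder path
    features = {
        'input_ids': [[101, 2769, 812, 102, 800, 812, 102]],  # toy ids
        'token_type_ids': [[0, 0, 0, 0, 1, 1, 1]],
    }
    outputs = model.forward(features)
    print(outputs['predictions'], outputs['probabilities'])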
@@ -20,6 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     ('sbert-base-chinese-sentence-similarity',
      'damo/nlp_structbert_sentence-similarity_chinese-base'),
     Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'),
+    Tasks.nli: ('nlp_structbert_nli_chinese-base',
+                'damo/nlp_structbert_nli_chinese-base'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.text_generation: ('palm2.0',
@@ -1,3 +1,4 @@
+from .nli_pipeline import *  # noqa F403
 from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
@@ -0,0 +1,88 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import SbertForNLI
+from modelscope.preprocessors import NLIPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['NLIPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
+class NLIPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForNLI, str],
+                 preprocessor: NLIPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create an NLI pipeline for prediction
+
+        Args:
+            model (SbertForNLI): a model instance
+            preprocessor (NLIPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, str) or isinstance(model, SbertForNLI), \
+            'model must be a single str or SbertForNLI'
+        sc_model = model if isinstance(
+            model, SbertForNLI) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = NLIPreprocessor(
+                sc_model.model_dir,
+                first_sequence='first_sequence',
+                second_sequence='second_sequence')
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+        # map label ids back to label names via the mapping shipped with the model
+        self.label_path = os.path.join(sc_model.model_dir,
+                                       'label_mapping.json')
+        with open(self.label_path) as f:
+            self.label_mapping = json.load(f)
+        self.label_id_to_name = {
+            idx: name
+            for name, idx in self.label_mapping.items()
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the forward results of the model
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        probs = inputs['probabilities']
+        logits = inputs['logits']
+        # label ids sorted by descending probability
+        predictions = np.argsort(-probs, axis=-1)
+        preds = predictions[0]
+        b = 0  # batch size is 1
+        new_result = list()
+        for pred in preds:
+            new_result.append({
+                'pred': self.label_id_to_name[pred],
+                'prob': float(probs[b][pred]),
+                'logit': float(logits[b][pred])
+            })
+        new_results = list()
+        new_results.append({
+            'id':
+            inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
+            'output':
+            new_result,
+            'predictions':
+            new_result[0]['pred'],
+            'probabilities':
+            ','.join([str(t) for t in inputs['probabilities'][b]]),
+            'logits':
+            ','.join([str(t) for t in inputs['logits'][b]])
+        })
+        return new_results[0]
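
For reference, the dict `postprocess` returns for one premise/hypothesis pair has roughly this shape (label names and numbers are illustrative; the real names come from the model's `label_mapping.json`):

    {
        'id': '<uuid or input id>',
        'output': [
            {'pred': 'entailment', 'prob': 0.87, 'logit': 1.7},
            {'pred': 'neutral', 'prob': 0.08, 'logit': -0.6},
            {'pred': 'contradiction', 'prob': 0.05, 'logit': -1.1},
        ],
        'predictions': 'entailment',
        'probabilities': '0.05,0.87,0.08',
        'logits': '-1.1,1.7,-0.6'
    }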
@@ -7,4 +7,5 @@ from .common import Compose
 from .image import LoadImage, load_image
 from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
+from .nlp import NLIPreprocessor, TextGenerationPreprocessor
 from .text_to_speech import *  # noqa F403
@@ -12,7 +12,8 @@ from .builder import PREPROCESSORS
 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor'
+    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
+    'NLIPreprocessor'
 ]
@@ -30,6 +31,77 @@ class Tokenize(Preprocessor):
         return data
 
 
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base')
+class NLIPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, tuple)
+    def __call__(self, data: tuple) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (tuple): (sentence1, sentence2)
+                sentence1 (str): the first sentence (premise)
+                    Example:
+                        'you are so handsome.'
+                sentence2 (str): the second sentence (hypothesis)
+                    Example:
+                        'you are so beautiful.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }
+
+        # preprocess the data for the model input
+        rst = {
+            'id': [],
+            'input_ids': [],
+            'attention_mask': [],
+            'token_type_ids': []
+        }
+
+        max_seq_length = self.sequence_length
+
+        text_a = new_data[self.first_sequence]
+        text_b = new_data[self.second_sequence]
+        # encode the pair together so token_type_ids mark the two segments
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding=False,
+            truncation=True,
+            max_length=max_seq_length)
+
+        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+
+        return rst
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=r'bert-sequence-classification')
 class SequenceClassificationPreprocessor(Preprocessor):
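
A short sketch of what the new preprocessor produces for a sentence pair (the model directory is a placeholder and the token ids are illustrative):

    from modelscope.preprocessors import NLIPreprocessor

    preprocessor = NLIPreprocessor('/path/to/nlp_structbert_nli_chinese-base')  # placeholder path
    features = preprocessor(('A man is playing guitar.',
                             'A person is playing an instrument.'))
    # a dict of single-element lists, roughly:
    # {'id': ['<uuid>'], 'input_ids': [[101, ...]], 'attention_mask': [[1, ...]],
    #  'token_type_ids': [[0, ..., 1, ...]]}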
@@ -32,6 +32,7 @@ class Tasks(object):
     # nlp tasks
     word_segmentation = 'word-segmentation'
+    nli = 'nli'
     sentiment_analysis = 'sentiment-analysis'
     sentence_similarity = 'sentence-similarity'
     text_classification = 'text-classification'
@@ -0,0 +1,15 @@
+from modelscope.models import SbertForNLI
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors import NLIPreprocessor
+
+model = SbertForNLI('../nlp_structbert_nli_chinese-base')
+print(model)
+tokenizer = NLIPreprocessor(model.model_dir)
+semantic_cls = pipeline('nli', model=model, preprocessor=tokenizer)
+print(type(semantic_cls))
+# premise: "I think one more thing that has hurt teacher recruitment is that
+# they have lost a lot of authority in the classroom"
+# hypothesis: "Teachers losing authority in the classroom has reduced the
+# number of people wanting to enter the profession."
+print(
+    semantic_cls(
+        input=('我想还有一件事也伤害到了老师的招聘，那就是他们在课堂上失去了很多的权威',
+               '教师在课堂上失去权威，导致想要进入这一职业的人减少了。')))
@@ -0,0 +1,49 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForNLI
+from modelscope.pipelines import NLIPipeline, pipeline
+from modelscope.preprocessors import NLIPreprocessor
+from modelscope.utils.constant import Tasks
+
+
+class NLITest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_nli_chinese-base'
+    # sentence1: "Which is better, Sichuan Business Vocational College or
+    # Sichuan Vocational College of Finance and Economics?"
+    # sentence2: "Which campus is the business management major of Sichuan
+    # Business Vocational College located on?"
+    sentence1 = '四川商务职业学院和四川财经职业学院哪个好？'
+    sentence2 = '四川商务职业学院商务管理在哪个校区？'
+
+    @unittest.skip('skip temporarily to save test time')
+    def test_run_from_local(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = NLIPreprocessor(cache_path)
+        model = SbertForNLI(cache_path, tokenizer=tokenizer)
+        pipeline1 = NLIPipeline(model, preprocessor=tokenizer)
+        pipeline2 = pipeline(Tasks.nli, model=model, preprocessor=tokenizer)
+        print(f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+              f'pipeline1: {pipeline1(input=(self.sentence1, self.sentence2))}')
+        print()
+        print(
+            f'sentence1: {self.sentence1}\nsentence2: {self.sentence2}\n'
+            f'pipeline2: {pipeline2(input=(self.sentence1, self.sentence2))}')
+
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = NLIPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.nli, model=model, preprocessor=tokenizer)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(task=Tasks.nli, model=self.model_id)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.nli)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+
+if __name__ == '__main__':
+    unittest.main()