Commit 414c0c1b3c ("init") on branch master, by 思宏, 3 years ago
9 changed files with 260 additions and 3 deletions:

  1. modelscope/models/__init__.py  (+1, -1)
  2. modelscope/models/nlp/__init__.py  (+1, -0)
  3. modelscope/models/nlp/nli_model.py  (+83, -0)
  4. modelscope/pipelines/nlp/__init__.py  (+1, -0)
  5. modelscope/pipelines/nlp/nli_pipeline.py  (+88, -0)
  6. modelscope/preprocessors/__init__.py  (+1, -1)
  7. modelscope/preprocessors/nlp.py  (+72, -1)
  8. modelscope/utils/constant.py  (+1, -0)
  9. test.py  (+12, -0)

modelscope/models/__init__.py  (+1, -1)

@@ -2,4 +2,4 @@
 
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForNLI

modelscope/models/nlp/__init__.py  (+1, -0)

@@ -1,2 +1,3 @@
+from .nli_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403

modelscope/models/nlp/nli_model.py  (+83, -0)

@@ -0,0 +1,83 @@
import os
from typing import Any, Dict

import numpy as np
import torch
from sofa import SbertConfig, SbertModel
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
from torch import nn
from transformers.activations import ACT2FN, get_activation
from transformers.models.bert.modeling_bert import SequenceClassifierOutput

from modelscope.utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['SbertForNLI']


class TextClassifier(SbertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.encoder = SbertModel(config, add_pooling_layer=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, token_type_ids=None):
        outputs = self.encoder(
            input_ids,
            token_type_ids=token_type_ids,
            return_dict=None,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


@MODELS.register_module(
    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
class SbertForNLI(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the NLI model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

        self.model = TextClassifier.from_pretrained(model_dir, num_labels=3)
        self.model.eval()

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'predictions': array([1]),  # label id, mapped to a name via label_mapping.json
                        'probabilities': array([[0.1, 0.8, 0.1]], dtype=float32),
                        'logits': array([[-0.5, 1.5, -0.5]], dtype=float32)  # raw scores
                    }
        """
        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
        token_type_ids = torch.tensor(
            input['token_type_ids'], dtype=torch.long)
        with torch.no_grad():
            logits = self.model(input_ids, token_type_ids)
        probs = logits.softmax(-1).numpy()
        pred = logits.argmax(-1).numpy()
        logits = logits.numpy()
        res = {'predictions': pred, 'probabilities': probs, 'logits': logits}
        return res
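For orientation, a minimal sketch of driving SbertForNLI.forward directly with pre-tokenized inputs. The model path is the one used in test.py below, and the token ids are toy values assumed to be in-vocab, not a real encoding:

from modelscope.models.nlp import SbertForNLI

model = SbertForNLI('../nlp_structbert_nli_chinese-base')

# Shapes mirror what NLIPreprocessor emits: a batch of one sentence pair,
# with token_type_ids marking the boundary between the two sentences.
features = {
    'input_ids': [[101, 2769, 102, 872, 102]],  # toy ids (assumption)
    'token_type_ids': [[0, 0, 0, 1, 1]],
}
out = model.forward(features)
print(out['predictions'])    # one label id out of the 3 NLI classes
print(out['probabilities'])  # softmax over the 3 logits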

modelscope/pipelines/nlp/__init__.py  (+1, -0)

@@ -1,2 +1,3 @@
+from .nli_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403

modelscope/pipelines/nlp/nli_pipeline.py  (+88, -0)

@@ -0,0 +1,88 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np

from modelscope.models.nlp import SbertForNLI
from modelscope.preprocessors import NLIPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['NLIPipeline']


@PIPELINES.register_module(
    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
class NLIPipeline(Pipeline):

    def __init__(self,
                 model: Union[SbertForNLI, str],
                 preprocessor: NLIPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an NLI pipeline for prediction

        Args:
            model (SbertForNLI): a model instance
            preprocessor (NLIPreprocessor): a preprocessor instance
        """
        assert isinstance(model, str) or isinstance(model, SbertForNLI), \
            'model must be a single str or SbertForNLI'
        sc_model = model if isinstance(model,
                                       SbertForNLI) else SbertForNLI(model)
        if preprocessor is None:
            preprocessor = NLIPreprocessor(
                sc_model.model_dir,
                first_sequence='first_sequence',
                second_sequence='second_sequence')
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

        self.label_path = os.path.join(sc_model.model_dir,
                                       'label_mapping.json')
        with open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.label_id_to_name = {
            idx: name
            for name, idx in self.label_mapping.items()
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the forward output, carrying
                'predictions', 'probabilities' and 'logits'

        Returns:
            Dict[str, str]: the prediction results
        """

        probs = inputs['probabilities']
        logits = inputs['logits']
        predictions = np.argsort(-probs, axis=-1)
        preds = predictions[0]
        b = 0
        new_result = list()
        for pred in preds:
            new_result.append({
                'pred': self.label_id_to_name[pred],
                'prob': float(probs[b][pred]),
                'logit': float(logits[b][pred])
            })
        new_results = list()
        new_results.append({
            'id':
            inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
            'output':
            new_result,
            'predictions':
            new_result[0]['pred'],
            'probabilities':
            ','.join([str(t) for t in inputs['probabilities'][b]]),
            'logits':
            ','.join([str(t) for t in inputs['logits'][b]])
        })

        return new_results[0]
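For reference, the shape of the dict postprocess returns for a single pair. The label names come from the model's label_mapping.json, which is not part of this commit, so 'entailment' and friends below are assumptions and all numbers are made up:

# Illustrative postprocess output (label names and values assumed):
# {
#     'id': '6f9619ff-...',               # uuid4 when no id was supplied
#     'output': [                          # all 3 labels, sorted by prob
#         {'pred': 'entailment', 'prob': 0.91, 'logit': 2.37},
#         {'pred': 'neutral', 'prob': 0.07, 'logit': -0.12},
#         {'pred': 'contradiction', 'prob': 0.02, 'logit': -1.45},
#     ],
#     'predictions': 'entailment',         # top-1 label name
#     'probabilities': '0.91,0.07,0.02',   # comma-joined, in label-id order
#     'logits': '2.37,-0.12,-1.45',
# }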

modelscope/preprocessors/__init__.py  (+1, -1)

@@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
+from .nlp import NLIPreprocessor, TextGenerationPreprocessor

modelscope/preprocessors/nlp.py  (+72, -1)

@@ -10,7 +10,7 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 
-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = ['Tokenize', 'SequenceClassificationPreprocessor', 'NLIPreprocessor']
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
@@ -27,6 +27,77 @@ class Tokenize(Preprocessor):
        return data


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base')
class NLIPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
            model_dir (str): model path
        """

        super().__init__(*args, **kwargs)

        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.first_sequence: str = kwargs.pop('first_sequence',
                                              'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.sequence_length = kwargs.pop('sequence_length', 128)

        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, tuple)
    def __call__(self, data: tuple) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (tuple): (sentence1, sentence2)
                sentence1 (str): a sentence
                    Example:
                        'you are so handsome.'
                sentence2 (str): a sentence
                    Example:
                        'you are so beautiful.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        sentence1, sentence2 = data
        new_data = {
            self.first_sequence: sentence1,
            self.second_sequence: sentence2
        }

        # preprocess the data for the model input
        rst = {
            'id': [],
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': []
        }

        max_seq_length = self.sequence_length

        text_a = new_data[self.first_sequence]
        text_b = new_data[self.second_sequence]
        feature = self.tokenizer(
            text_a,
            text_b,
            padding=False,
            truncation=True,
            max_length=max_seq_length)

        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
        rst['input_ids'].append(feature['input_ids'])
        rst['attention_mask'].append(feature['attention_mask'])
        rst['token_type_ids'].append(feature['token_type_ids'])

        return rst


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):

modelscope/utils/constant.py  (+1, -0)

@@ -30,6 +30,7 @@ class Tasks(object):
     image_matting = 'image-matting'
 
     # nlp tasks
+    nli = 'nli'
     sentiment_analysis = 'sentiment-analysis'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'


test.py  (+12, -0)

@@ -0,0 +1,12 @@
from modelscope.models import SbertForNLI
from modelscope.pipelines import pipeline
from modelscope.preprocessors import NLIPreprocessor

model = SbertForNLI('../nlp_structbert_nli_chinese-base')
print(model)
tokenizer = NLIPreprocessor(model.model_dir)

semantic_cls = pipeline('nli', model=model, preprocessor=tokenizer)
print(type(semantic_cls))

print(semantic_cls(input=('相反,这表明克林顿的敌人是疯子。', '四川商务职业学院商务管理在哪个校区?')))
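The two test sentences translate roughly as 'On the contrary, this shows that Clinton's enemies are lunatics.' and 'Which campus is the business management program of Sichuan Business Vocational College on?', so they are an unrelated pair:

# The final call prints the dict built in NLIPipeline.postprocess, with keys
# 'id', 'output', 'predictions', 'probabilities' and 'logits' as sketched
# above (label names depend on the model's label_mapping.json).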
