|
|
|
@@ -1,6 +1,6 @@ |
|
|
|
import os |
|
|
|
import uuid |
|
|
|
from typing import Any, Dict |
|
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
|
|
import json |
|
|
|
import numpy as np |
|
|
|
@@ -8,6 +8,7 @@ import numpy as np |
|
|
|
from maas_lib.models.nlp import SequenceClassificationModel |
|
|
|
from maas_lib.preprocessors import SequenceClassificationPreprocessor |
|
|
|
from maas_lib.utils.constant import Tasks |
|
|
|
from ...models import Model |
|
|
|
from ..base import Input, Pipeline |
|
|
|
from ..builder import PIPELINES |
|
|
|
|
|
|
|
@@ -18,19 +19,29 @@ __all__ = ['SequenceClassificationPipeline'] |
|
|
|
Tasks.text_classification, module_name=r'bert-sentiment-analysis') |
|
|
|
class SequenceClassificationPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, model: SequenceClassificationModel, |
|
|
|
preprocessor: SequenceClassificationPreprocessor, **kwargs): |
|
|
|
def __init__(self, |
|
|
|
model: Union[SequenceClassificationModel, str], |
|
|
|
preprocessor: SequenceClassificationPreprocessor = None, |
|
|
|
**kwargs): |
|
|
|
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction |
|
|
|
|
|
|
|
Args: |
|
|
|
model (SequenceClassificationModel): a model instance |
|
|
|
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance |
|
|
|
""" |
|
|
|
|
|
|
|
super().__init__(model=model, preprocessor=preprocessor, **kwargs) |
|
|
|
sc_model = model if isinstance( |
|
|
|
model, |
|
|
|
SequenceClassificationModel) else Model.from_pretrained(model) |
|
|
|
if preprocessor is None: |
|
|
|
preprocessor = SequenceClassificationPreprocessor( |
|
|
|
sc_model.model_dir, |
|
|
|
first_sequence='sentence', |
|
|
|
second_sequence=None) |
|
|
|
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) |
|
|
|
|
|
|
|
from easynlp.utils import io |
|
|
|
self.label_path = os.path.join(model.model_dir, 'label_mapping.json') |
|
|
|
self.label_path = os.path.join(sc_model.model_dir, |
|
|
|
'label_mapping.json') |
|
|
|
with io.open(self.label_path) as f: |
|
|
|
self.label_mapping = json.load(f) |
|
|
|
self.label_id_to_name = { |
|
|
|
|