diff --git a/maas_lib/models/nlp/sequence_classification_model.py b/maas_lib/models/nlp/sequence_classification_model.py
index d29587a0..f77b0fbc 100644
--- a/maas_lib/models/nlp/sequence_classification_model.py
+++ b/maas_lib/models/nlp/sequence_classification_model.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict
 
 import numpy as np
 
diff --git a/maas_lib/pipelines/nlp/sequence_classification_pipeline.py b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py
index f3b20f95..9300035d 100644
--- a/maas_lib/pipelines/nlp/sequence_classification_pipeline.py
+++ b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py
@@ -1,6 +1,6 @@
 import os
 import uuid
-from typing import Any, Dict
+from typing import Any, Dict, Union
 
 import json
 import numpy as np
@@ -8,6 +8,7 @@ import numpy as np
 from maas_lib.models.nlp import SequenceClassificationModel
 from maas_lib.preprocessors import SequenceClassificationPreprocessor
 from maas_lib.utils.constant import Tasks
+from ...models import Model
 from ..base import Input, Pipeline
 from ..builder import PIPELINES
 
@@ -18,19 +19,29 @@ __all__ = ['SequenceClassificationPipeline']
     Tasks.text_classification, module_name=r'bert-sentiment-analysis')
 class SequenceClassificationPipeline(Pipeline):
 
-    def __init__(self, model: SequenceClassificationModel,
-                 preprocessor: SequenceClassificationPreprocessor, **kwargs):
+    def __init__(self,
+                 model: Union[SequenceClassificationModel, str],
+                 preprocessor: SequenceClassificationPreprocessor = None,
+                 **kwargs):
         """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

         Args:
             model (SequenceClassificationModel): a model instance
             preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
         """
-
-        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        sc_model = model if isinstance(
+            model,
+            SequenceClassificationModel) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = SequenceClassificationPreprocessor(
+                sc_model.model_dir,
+                first_sequence='sentence',
+                second_sequence=None)
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
         from easynlp.utils import io
-        self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
+        self.label_path = os.path.join(sc_model.model_dir,
+                                       'label_mapping.json')
         with io.open(self.label_path) as f:
             self.label_mapping = json.load(f)
         self.label_id_to_name = {
diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py
index 45b584af..080622d3 100644
--- a/tests/pipelines/test_text_classification.py
+++ b/tests/pipelines/test_text_classification.py
@@ -29,6 +29,12 @@ class SequenceClassificationTest(unittest.TestCase):
         print(data)
 
+    def printDataset(self, dataset: PyDataset):
+        for i, r in enumerate(dataset):
+            if i > 10:
+                break
+            print(r)
+
     def test_run(self):
         model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
                     '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
@@ -53,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase):
             Tasks.text_classification, model=model, preprocessor=preprocessor)
         print(pipeline2('Hello world!'))
 
-    def test_run_modelhub(self):
+    def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained('damo/bert-base-sst2')
         preprocessor = SequenceClassificationPreprocessor(
             model.model_dir, first_sequence='sentence', second_sequence=None)
@@ -63,6 +69,13 @@ class SequenceClassificationTest(unittest.TestCase):
             preprocessor=preprocessor)
         self.predict(pipeline_ins)
 
+    def test_run_with_model_name(self):
+        text_classification = pipeline(
+            task=Tasks.text_classification, model='damo/bert-base-sst2')
+        result = text_classification(
+            PyDataset.load('glue', name='sst2', target='sentence'))
+        self.printDataset(result)
+
     def test_run_with_dataset(self):
         model = Model.from_pretrained('damo/bert-base-sst2')
         preprocessor = SequenceClassificationPreprocessor(
@@ -74,10 +87,7 @@ class SequenceClassificationTest(unittest.TestCase):
         # TODO: rename parameter as dataset_name and subset_name
         dataset = PyDataset.load('glue', name='sst2', target='sentence')
         result = text_classification(dataset)
-        for i, r in enumerate(result):
-            if i > 10:
-                break
-            print(r)
+        self.printDataset(result)
 
 
 if __name__ == '__main__':