Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8945177 * support plain pipeline for bertmaster
| @@ -1,4 +1,4 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from typing import Any, Dict | |||||
| import numpy as np | import numpy as np | ||||
| @@ -1,6 +1,6 @@ | |||||
| import os | import os | ||||
| import uuid | import uuid | ||||
| from typing import Any, Dict | |||||
| from typing import Any, Dict, Union | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| @@ -8,6 +8,7 @@ import numpy as np | |||||
| from maas_lib.models.nlp import SequenceClassificationModel | from maas_lib.models.nlp import SequenceClassificationModel | ||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | from maas_lib.preprocessors import SequenceClassificationPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from ...models import Model | |||||
| from ..base import Input, Pipeline | from ..base import Input, Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| @@ -18,19 +19,29 @@ __all__ = ['SequenceClassificationPipeline'] | |||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | Tasks.text_classification, module_name=r'bert-sentiment-analysis') | ||||
| class SequenceClassificationPipeline(Pipeline): | class SequenceClassificationPipeline(Pipeline): | ||||
| def __init__(self, model: SequenceClassificationModel, | |||||
| preprocessor: SequenceClassificationPreprocessor, **kwargs): | |||||
| def __init__(self, | |||||
| model: Union[SequenceClassificationModel, str], | |||||
| preprocessor: SequenceClassificationPreprocessor = None, | |||||
| **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
| Args: | Args: | ||||
| model (SequenceClassificationModel): a model instance | model (SequenceClassificationModel): a model instance | ||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | ||||
| """ | """ | ||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| sc_model = model if isinstance( | |||||
| model, | |||||
| SequenceClassificationModel) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | |||||
| preprocessor = SequenceClassificationPreprocessor( | |||||
| sc_model.model_dir, | |||||
| first_sequence='sentence', | |||||
| second_sequence=None) | |||||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||||
| from easynlp.utils import io | from easynlp.utils import io | ||||
| self.label_path = os.path.join(model.model_dir, 'label_mapping.json') | |||||
| self.label_path = os.path.join(sc_model.model_dir, | |||||
| 'label_mapping.json') | |||||
| with io.open(self.label_path) as f: | with io.open(self.label_path) as f: | ||||
| self.label_mapping = json.load(f) | self.label_mapping = json.load(f) | ||||
| self.label_id_to_name = { | self.label_id_to_name = { | ||||
| @@ -29,6 +29,12 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| print(data) | print(data) | ||||
| def printDataset(self, dataset: PyDataset): | |||||
| for i, r in enumerate(dataset): | |||||
| if i > 10: | |||||
| break | |||||
| print(r) | |||||
| def test_run(self): | def test_run(self): | ||||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | ||||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | ||||
| @@ -53,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| Tasks.text_classification, model=model, preprocessor=preprocessor) | Tasks.text_classification, model=model, preprocessor=preprocessor) | ||||
| print(pipeline2('Hello world!')) | print(pipeline2('Hello world!')) | ||||
| def test_run_modelhub(self): | |||||
| def test_run_with_model_from_modelhub(self): | |||||
| model = Model.from_pretrained('damo/bert-base-sst2') | model = Model.from_pretrained('damo/bert-base-sst2') | ||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| model.model_dir, first_sequence='sentence', second_sequence=None) | model.model_dir, first_sequence='sentence', second_sequence=None) | ||||
| @@ -63,6 +69,13 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| preprocessor=preprocessor) | preprocessor=preprocessor) | ||||
| self.predict(pipeline_ins) | self.predict(pipeline_ins) | ||||
| def test_run_with_model_name(self): | |||||
| text_classification = pipeline( | |||||
| task=Tasks.text_classification, model='damo/bert-base-sst2') | |||||
| result = text_classification( | |||||
| PyDataset.load('glue', name='sst2', target='sentence')) | |||||
| self.printDataset(result) | |||||
| def test_run_with_dataset(self): | def test_run_with_dataset(self): | ||||
| model = Model.from_pretrained('damo/bert-base-sst2') | model = Model.from_pretrained('damo/bert-base-sst2') | ||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| @@ -74,10 +87,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| # TODO: rename parameter as dataset_name and subset_name | # TODO: rename parameter as dataset_name and subset_name | ||||
| dataset = PyDataset.load('glue', name='sst2', target='sentence') | dataset = PyDataset.load('glue', name='sst2', target='sentence') | ||||
| result = text_classification(dataset) | result = text_classification(dataset) | ||||
| for i, r in enumerate(result): | |||||
| if i > 10: | |||||
| break | |||||
| print(r) | |||||
| self.printDataset(result) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||