Browse Source

[to #42322515]support plain pipeline for bert

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8945177

    * support plain pipeline for bert
master
yingda.chen 3 years ago
parent
commit
e075ad2245
3 changed files with 33 additions and 12 deletions
  1. +1
    -1
      maas_lib/models/nlp/sequence_classification_model.py
  2. +17
    -6
      maas_lib/pipelines/nlp/sequence_classification_pipeline.py
  3. +15
    -5
      tests/pipelines/test_text_classification.py

+ 1
- 1
maas_lib/models/nlp/sequence_classification_model.py View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any, Dict


import numpy as np import numpy as np




+ 17
- 6
maas_lib/pipelines/nlp/sequence_classification_pipeline.py View File

@@ -1,6 +1,6 @@
import os import os
import uuid import uuid
from typing import Any, Dict
from typing import Any, Dict, Union


import json import json
import numpy as np import numpy as np
@@ -8,6 +8,7 @@ import numpy as np
from maas_lib.models.nlp import SequenceClassificationModel from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.preprocessors import SequenceClassificationPreprocessor from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.utils.constant import Tasks from maas_lib.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline from ..base import Input, Pipeline
from ..builder import PIPELINES from ..builder import PIPELINES


@@ -18,19 +19,29 @@ __all__ = ['SequenceClassificationPipeline']
Tasks.text_classification, module_name=r'bert-sentiment-analysis') Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPipeline(Pipeline): class SequenceClassificationPipeline(Pipeline):


def __init__(self, model: SequenceClassificationModel,
preprocessor: SequenceClassificationPreprocessor, **kwargs):
def __init__(self,
model: Union[SequenceClassificationModel, str],
preprocessor: SequenceClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction


Args: Args:
model (SequenceClassificationModel): a model instance model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
""" """

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
sc_model = model if isinstance(
model,
SequenceClassificationModel) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = SequenceClassificationPreprocessor(
sc_model.model_dir,
first_sequence='sentence',
second_sequence=None)
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)


from easynlp.utils import io from easynlp.utils import io
self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
self.label_path = os.path.join(sc_model.model_dir,
'label_mapping.json')
with io.open(self.label_path) as f: with io.open(self.label_path) as f:
self.label_mapping = json.load(f) self.label_mapping = json.load(f)
self.label_id_to_name = { self.label_id_to_name = {


+ 15
- 5
tests/pipelines/test_text_classification.py View File

@@ -29,6 +29,12 @@ class SequenceClassificationTest(unittest.TestCase):


print(data) print(data)


def printDataset(self, dataset: PyDataset):
for i, r in enumerate(dataset):
if i > 10:
break
print(r)

def test_run(self): def test_run(self):
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
@@ -53,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase):
Tasks.text_classification, model=model, preprocessor=preprocessor) Tasks.text_classification, model=model, preprocessor=preprocessor)
print(pipeline2('Hello world!')) print(pipeline2('Hello world!'))


def test_run_modelhub(self):
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained('damo/bert-base-sst2') model = Model.from_pretrained('damo/bert-base-sst2')
preprocessor = SequenceClassificationPreprocessor( preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None) model.model_dir, first_sequence='sentence', second_sequence=None)
@@ -63,6 +69,13 @@ class SequenceClassificationTest(unittest.TestCase):
preprocessor=preprocessor) preprocessor=preprocessor)
self.predict(pipeline_ins) self.predict(pipeline_ins)


def test_run_with_model_name(self):
text_classification = pipeline(
task=Tasks.text_classification, model='damo/bert-base-sst2')
result = text_classification(
PyDataset.load('glue', name='sst2', target='sentence'))
self.printDataset(result)

def test_run_with_dataset(self): def test_run_with_dataset(self):
model = Model.from_pretrained('damo/bert-base-sst2') model = Model.from_pretrained('damo/bert-base-sst2')
preprocessor = SequenceClassificationPreprocessor( preprocessor = SequenceClassificationPreprocessor(
@@ -74,10 +87,7 @@ class SequenceClassificationTest(unittest.TestCase):
# TODO: rename parameter as dataset_name and subset_name # TODO: rename parameter as dataset_name and subset_name
dataset = PyDataset.load('glue', name='sst2', target='sentence') dataset = PyDataset.load('glue', name='sst2', target='sentence')
result = text_classification(dataset) result = text_classification(dataset)
for i, r in enumerate(result):
if i > 10:
break
print(r)
self.printDataset(result)




if __name__ == '__main__': if __name__ == '__main__':


Loading…
Cancel
Save