Browse Source

[to #42322515] Support plain pipeline for BERT

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8945177

    * support plain pipeline for bert
master
yingda.chen 3 years ago
parent
commit
e075ad2245
3 changed files with 33 additions and 12 deletions
  1. +1
    -1
      maas_lib/models/nlp/sequence_classification_model.py
  2. +17
    -6
      maas_lib/pipelines/nlp/sequence_classification_pipeline.py
  3. +15
    -5
      tests/pipelines/test_text_classification.py

+ 1
- 1
maas_lib/models/nlp/sequence_classification_model.py View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Union
from typing import Any, Dict

import numpy as np



+ 17
- 6
maas_lib/pipelines/nlp/sequence_classification_pipeline.py View File

@@ -1,6 +1,6 @@
import os
import uuid
from typing import Any, Dict
from typing import Any, Dict, Union

import json
import numpy as np
@@ -8,6 +8,7 @@ import numpy as np
from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

@@ -18,19 +19,29 @@ __all__ = ['SequenceClassificationPipeline']
Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPipeline(Pipeline):

def __init__(self, model: SequenceClassificationModel,
preprocessor: SequenceClassificationPreprocessor, **kwargs):
def __init__(self,
model: Union[SequenceClassificationModel, str],
preprocessor: SequenceClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

Args:
model (SequenceClassificationModel): a model instance
preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
"""

super().__init__(model=model, preprocessor=preprocessor, **kwargs)
sc_model = model if isinstance(
model,
SequenceClassificationModel) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = SequenceClassificationPreprocessor(
sc_model.model_dir,
first_sequence='sentence',
second_sequence=None)
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

from easynlp.utils import io
self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
self.label_path = os.path.join(sc_model.model_dir,
'label_mapping.json')
with io.open(self.label_path) as f:
self.label_mapping = json.load(f)
self.label_id_to_name = {


+ 15
- 5
tests/pipelines/test_text_classification.py View File

@@ -29,6 +29,12 @@ class SequenceClassificationTest(unittest.TestCase):

print(data)

def printDataset(self, dataset: PyDataset):
    """Print a short preview of *dataset*: at most the first 11 records.

    Args:
        dataset (PyDataset): an iterable dataset of prediction records.
    """
    shown = 0
    for record in dataset:
        # Stop after records 0..10 have been printed, matching the
        # original `enumerate` + `i > 10` guard.
        if shown > 10:
            break
        print(record)
        shown += 1

def test_run(self):
model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
'/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
@@ -53,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase):
Tasks.text_classification, model=model, preprocessor=preprocessor)
print(pipeline2('Hello world!'))

def test_run_modelhub(self):
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained('damo/bert-base-sst2')
preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
@@ -63,6 +69,13 @@ class SequenceClassificationTest(unittest.TestCase):
preprocessor=preprocessor)
self.predict(pipeline_ins)

def test_run_with_model_name(self):
    """End-to-end check: build the pipeline from a model-hub name only
    (no explicit model/preprocessor objects) and run it on the sst2 set.
    """
    text_classification = pipeline(
        task=Tasks.text_classification, model='damo/bert-base-sst2')
    sst2_data = PyDataset.load('glue', name='sst2', target='sentence')
    result = text_classification(sst2_data)
    # Preview the first few predictions instead of dumping everything.
    self.printDataset(result)

def test_run_with_dataset(self):
model = Model.from_pretrained('damo/bert-base-sst2')
preprocessor = SequenceClassificationPreprocessor(
@@ -74,10 +87,7 @@ class SequenceClassificationTest(unittest.TestCase):
# TODO: rename parameter as dataset_name and subset_name
dataset = PyDataset.load('glue', name='sst2', target='sentence')
result = text_classification(dataset)
for i, r in enumerate(result):
if i > 10:
break
print(r)
self.printDataset(result)


if __name__ == '__main__':


Loading…
Cancel
Save