* add preprocessor module * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline * fix UT failed * add test for Compose Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8769235 * add preprocessor module * add test for Compose * fix citest error * fix abs class error * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * refine models and pipeline interface * add pipeline folder structure * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline 1.add preprossor model pipeline for nlp text classification 2. add corresponding test Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8757371 * new nlp pipeline * format pre-commit code * update easynlp pipeline * update model_name for easynlp pipeline; add test for maas_lib/utils/typeassert.py * update test_typeassert.py * refactor code 1. rename typeassert to type_assert 2. use lazy import to make easynlp dependency optional 3. refine image matting UT * fix linter test failed * update requirements.txt * fix UT failed * fix citest script to update requirementsmaster
| @@ -1,4 +1,4 @@ | |||
| pip install -r requirements/runtime.txt | |||
| pip install -r requirements.txt | |||
| pip install -r requirements/tests.txt | |||
| @@ -1 +1,2 @@ | |||
| from .file import File | |||
| from .io import dump, dumps, load | |||
| @@ -123,6 +123,7 @@ class HTTPStorage(Storage): | |||
| """HTTP and HTTPS storage.""" | |||
| def read(self, url): | |||
| # TODO @wenmeng.zwm add progress bar if file is too large | |||
| r = requests.get(url) | |||
| r.raise_for_status() | |||
| return r.content | |||
| @@ -0,0 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .base import Model | |||
| from .builder import MODELS | |||
| @@ -0,0 +1,29 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict, List, Tuple, Union | |||
# Framework tensors are referenced by name only, so neither framework has to
# be importable for this module to load.
Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):
    """Abstract base class for models.

    Subclasses must implement :meth:`forward`. Calling a model instance runs
    ``forward`` and feeds its result through :meth:`post_process`.
    """

    def __init__(self, *args, **kwargs):
        # The base class keeps no state; arguments are accepted and ignored.
        pass

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run ``forward`` on ``input`` and post-process the result."""
        return self.post_process(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Compute the model outputs for ``input``; must be overridden."""
        pass

    def post_process(self, input: Dict[str, Tensor],
                     **kwargs) -> Dict[str, Tensor]:
        """Model specific post-processing hook.

        Implementation is optional; the default returns ``input`` unchanged.
        Will be called in Pipeline and evaluation loop (in the future).
        """
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """Build a model from a name or path.

        Raises:
            NotImplementedError: always, the base class provides no loader.
        """
        # The message previously misspelled the method name
        # ('from_preatrained'), which made the error hard to search for.
        raise NotImplementedError('from_pretrained has not been implemented')
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from maas_lib.utils.config import ConfigDict | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
# Registry holding every model class.
MODELS = Registry('models')


def build_model(cfg: ConfigDict,
                task_name: str = None,
                default_args: dict = None):
    """Instantiate a model from its config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for model object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details. It is passed to the registry
            lookup as the group key.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        Whatever ``build_from_cfg`` builds for ``cfg`` from ``MODELS``.
    """
    model = build_from_cfg(
        cfg, MODELS, group_key=task_name, default_args=default_args)
    return model
| @@ -0,0 +1 @@ | |||
| from .sequence_classification_model import * # noqa F403 | |||
| @@ -0,0 +1,62 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import numpy as np | |||
| import torch | |||
| from maas_lib.utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| __all__ = ['SequenceClassificationModel'] | |||
@MODELS.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationModel(Model):
    """Sequence classification model served through an EasyNLP predictor."""

    def __init__(self,
                 model_dir: str,
                 model_cls: Optional[Any] = None,
                 *args,
                 **kwargs):
        """Initialize the sequence classification model from `model_dir`.

        Args:
            model_dir (str): the model path
            model_cls (Optional[Any], optional): model loader, if None, use
                the default loader to load model weights, by default None
        """
        super().__init__(model_dir, model_cls, *args, **kwargs)
        # easynlp is imported here rather than at module level so that it
        # stays an optional dependency.
        from easynlp.appzoo import SequenceClassification
        from easynlp.core.predictor import get_model_predictor

        self.model_dir = model_dir
        loader_cls = model_cls or SequenceClassification
        tensor_inputs = [('input_ids', torch.LongTensor),
                         ('attention_mask', torch.LongTensor),
                         ('token_type_ids', torch.LongTensor)]
        result_keys = ['predictions', 'probabilities', 'logits']
        self.model = get_model_predictor(
            model_dir=model_dir,
            model_cls=loader_cls,
            input_keys=tensor_inputs,
            output_keys=result_keys)

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """Return the predictor's result for the preprocessed `input`.

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results

        Example:
            {
                'predictions': array([1]),  # label 0-negative 1-positive
                'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                'logits': array([[-0.53860897,  1.5029076 ]], dtype=float32)  # true value
            }
        """
        prediction = self.model.predict(input)
        return prediction
| @@ -0,0 +1,6 @@ | |||
| from .audio import * # noqa F403 | |||
| from .base import Pipeline | |||
| from .builder import pipeline | |||
| from .cv import * # noqa F403 | |||
| from .multi_modal import * # noqa F403 | |||
| from .nlp import * # noqa F403 | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| from maas_lib.models import Model | |||
| from maas_lib.preprocessors import Preprocessor | |||
# Type aliases; the framework and image types are referenced by name only.
Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']

# For the pipelines of different tasks, define the standardized output keys:
# they are used to interface with postprocess and also to standardize the
# keys that postprocess emits.
output_keys = [
]
class Pipeline(ABC):
    """Abstract base class chaining preprocess -> forward -> postprocess.

    Subclasses must implement :meth:`postprocess`. :meth:`preprocess` and
    :meth:`forward` have default implementations that delegate to the
    ``preprocessor`` and ``model`` given at construction time.
    """

    def __init__(self,
                 config_file: str = None,
                 model: Model = None,
                 preprocessor: Preprocessor = None,
                 **kwargs):
        """Store the model and preprocessor used by the default steps.

        Args:
            config_file (str, optional): accepted but not used by the base
                class.
            model (Model, optional): callable used by :meth:`forward`.
            preprocessor (Preprocessor, optional): callable used by
                :meth:`preprocess`.
            **kwargs: accepted and ignored by the base class.
        """
        self.model = model
        self.preprocessor = preprocessor

    def __call__(
        self, input: Union[Input, List[Input]], *args, **post_kwargs
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Run the pipeline on one input or on each element of a list.

        Args:
            input: a single input, or a list of inputs processed one by one.
            **post_kwargs: forwarded to :meth:`postprocess`.

        Returns:
            One result dict for a single input, or a list of result dicts
            (one per element, in order) when ``input`` is a list.
        """
        # model provider should leave it as it is
        # maas library developer will handle this function

        # simple show case, need to support iterator type for both tensorflow
        # and pytorch
        if isinstance(input, list):
            return [
                self._process_single(ele, *args, **post_kwargs)
                for ele in input
            ]
        return self._process_single(input, *args, **post_kwargs)

    def _process_single(self, input: Input, *args,
                        **post_kwargs) -> Dict[str, Any]:
        """Run preprocess, forward and postprocess on a single input."""
        out = self.preprocess(input)
        out = self.forward(out)
        out = self.postprocess(out, **post_kwargs)
        return out

    def preprocess(self, inputs: Input) -> Dict[str, Any]:
        """Default implementation based on ``self.preprocessor``; subclasses
        may reimplement it.
        """
        assert self.preprocessor is not None, 'preprocess method should be implemented'
        return self.preprocessor(inputs)

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Default implementation using ``self.model``; subclasses may
        reimplement it.
        """
        assert self.model is not None, 'forward method should be implemented'
        return self.model(inputs)

    @abstractmethod
    def postprocess(self, inputs: Dict[str, Any],
                    **post_kwargs) -> Dict[str, Any]:
        """Turn the forward output into the final result.

        ``**post_kwargs`` is declared here because ``_process_single``
        forwards the keyword arguments of ``__call__`` to this method.
        """
        raise NotImplementedError('postprocess')
| @@ -0,0 +1,65 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Union | |||
| from maas_lib.models.base import Model | |||
| from maas_lib.utils.config import ConfigDict | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
| from .base import Pipeline | |||
# Registry holding every pipeline class.
PIPELINES = Registry('pipelines')


def build_pipeline(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """Instantiate a pipeline from its config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for the pipeline object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details. It is passed to the registry
            lookup as the group key.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        Whatever ``build_from_cfg`` builds for ``cfg`` from ``PIPELINES``.
    """
    built = build_from_cfg(
        cfg, PIPELINES, group_key=task_name, default_args=default_args)
    return built
def pipeline(task: str = None,
             model: Union[str, Model] = None,
             config_file: str = None,
             pipeline_name: str = None,
             framework: str = None,
             device: int = -1,
             **kwargs) -> Pipeline:
    """ Factory method to build a obj:`Pipeline`.

    Args:
        task (str): Task name defining which pipeline will be returned.
        model (str or obj:`Model`): model name or model object.
        config_file (str, optional): path to config file.
        pipeline_name (str, optional): pipeline class name or alias name.
        framework (str, optional): framework type.
        device (int, optional): which device is used to do inference.
        **kwargs: forwarded to the pipeline constructor.

    Return:
        pipeline (obj:`Pipeline`): pipeline object for certain task.

    Raises:
        ValueError: if no pipeline name was given and none could be derived
            from ``task``.

    Examples:
    ```python
    >>> p = pipeline('image-classification')
    >>> p = pipeline('text-classification', model='distilbert-base-uncased')
    >>> # Using model object
    >>> resnet = Model.from_pretrained('Resnet')
    >>> p = pipeline('image-classification', model=resnet)
    ```
    """
    # NOTE(review): `model`, `config_file`, `framework` and `device` are
    # accepted but not used when building the pipeline - confirm whether
    # `model` should be forwarded to the pipeline constructor.
    if task is not None and model is None and pipeline_name is None:
        # get default pipeline for this task
        assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}'
        pipeline_name = list(PIPELINES.modules[task].keys())[0]

    if pipeline_name is None:
        # Previously this case fell through without ever building a config,
        # so callers got an obscure failure instead of a clear error.
        raise ValueError(
            'Unable to determine which pipeline to build: pass '
            '`pipeline_name`, or pass `task` alone to use the default '
            'pipeline registered for that task.')

    cfg = dict(type=pipeline_name, **kwargs)
    return build_pipeline(cfg, task_name=task)
| @@ -0,0 +1 @@ | |||
| from .image_matting import ImageMatting | |||
| @@ -0,0 +1,67 @@ | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import tensorflow as tf | |||
| from cv2 import COLOR_GRAY2RGB | |||
| from maas_lib.pipelines.base import Input | |||
| from maas_lib.preprocessors import load_image | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
# Use the TF1 compatibility API on TensorFlow 2.x and later. The major
# version is compared as an integer because comparing version strings is
# lexicographic (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1

logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_matting, module_name=Tasks.image_matting)
class ImageMatting(Pipeline):
    """Image matting pipeline running a serialized TensorFlow graph."""

    def __init__(self, model_path: str):
        """Load the graph stored at ``model_path`` into a new session.

        Args:
            model_path (str): path of the serialized ``GraphDef`` file.
        """
        super().__init__()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            logger.info(f'loading model from {model_path}')
            with tf.gfile.FastGFile(model_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                # Tensors are looked up by the names used inside the graph.
                self.output = self._session.graph.get_tensor_by_name(
                    'output_png:0')
                self.input_name = 'input_image:0'
            logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Convert ``input`` to a float image array.

        Args:
            input: an image path or url, a ``PIL.Image.Image`` or a
                ``np.ndarray`` (2-D grayscale or 3-D).

        Returns:
            Dict[str, Any]: ``{'img': array}`` with a float64 array.

        Raises:
            TypeError: if ``input`` is none of the supported types.
        """
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            array = input
            if len(array.shape) == 2:
                # Expand grayscale to 3 channels. The previous code passed
                # the not-yet-assigned `img` here and then sliced the 2-D
                # input with three indices, so grayscale arrays crashed.
                array = cv2.cvtColor(array, cv2.COLOR_GRAY2BGR)
            # presumably ndarray inputs are in BGR order (OpenCV
            # convention); reverse the channel axis to get RGB - confirm
            img = array[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.array, but got {type(input)}')
        # `np.float` was an alias of the builtin float (float64) and has been
        # removed from NumPy; name the dtype explicitly.
        img = img.astype(np.float64)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the graph on ``input['img']``.

        Returns:
            Dict[str, Any]: ``{'output_png': array}`` where the graph output
            went through an RGBA -> BGRA channel conversion.
        """
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_png = self._session.run(self.output, feed_dict=feed_dict)
            output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA)
            return {'output_png': output_png}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the forward output unchanged."""
        return inputs
| @@ -0,0 +1 @@ | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| @@ -0,0 +1,77 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| from maas_lib.utils.constant import Tasks | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| __all__ = ['SequenceClassificationPipeline'] | |||
@PIPELINES.register_module(
    Tasks.text_classification, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPipeline(Pipeline):
    """Text classification pipeline built from a model and a preprocessor."""

    def __init__(self, model: SequenceClassificationModel,
                 preprocessor: SequenceClassificationPreprocessor, **kwargs):
        """Create the pipeline and load the label mapping of the model.

        The mapping is read from ``label_mapping.json`` inside
        ``model.model_dir`` and inverted so that label ids map to names.

        Args:
            model (SequenceClassificationModel): a model instance
            preprocessor (SequenceClassificationPreprocessor): a preprocessor
                instance
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)

        # lazy import, as for the other easynlp usages
        from easynlp.utils import io

        self.label_path = os.path.join(model.model_dir, 'label_mapping.json')
        with io.open(self.label_path) as mapping_file:
            self.label_mapping = json.load(mapping_file)
        self.label_id_to_name = {}
        for label_name, label_id in self.label_mapping.items():
            self.label_id_to_name[label_id] = label_name

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Format the prediction of the first sample in ``inputs``.

        Args:
            inputs (Dict[str, Any]): model output holding 'probabilities'
                and 'logits', and optionally 'id'.

        Returns:
            Dict[str, str]: the predict results with keys 'id', 'output'
            (labels ranked by decreasing probability), 'predictions' (the
            top label), 'probabilities' and 'logits' (comma separated).
        """
        probs = inputs['probabilities']
        logits = inputs['logits']
        row = 0  # only the first sample is reported

        # label ids of the first sample, most probable first
        ranked_ids = np.argsort(-probs, axis=-1)[row]
        ranked = [{
            'pred': self.label_id_to_name[label_id],
            'prob': float(probs[row][label_id]),
            'logit': float(logits[row][label_id])
        } for label_id in ranked_ids]

        if 'id' in inputs:
            sample_id = inputs['id'][row]
        else:
            sample_id = str(uuid.uuid4())

        return {
            'id': sample_id,
            'output': ranked,
            'predictions': ranked[0]['pred'],
            'probabilities': ','.join(str(t) for t in probs[row]),
            'logits': ','.join(str(t) for t in logits[row])
        }
| @@ -0,0 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS, build_preprocessor | |||
| from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .nlp import * # noqa F403 | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Dict | |||
class Preprocessor(ABC):
    """Abstract base class of all preprocessors; instances are callables."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the base class keeps no state."""

    @abstractmethod
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process ``data``; must be overridden by subclasses."""
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from maas_lib.utils.config import ConfigDict | |||
| from maas_lib.utils.constant import Fields | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
# Registry holding every preprocessor class.
PREPROCESSORS = Registry('preprocessors')


def build_preprocessor(cfg: ConfigDict,
                       field_name: str = None,
                       default_args: dict = None):
    """Instantiate a preprocessor from its config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for the preprocessor object.
        field_name (str, optional): application field name, refer to
            :obj:`Fields` for more details. It is passed to the registry
            lookup as the group key.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        Whatever ``build_from_cfg`` builds for ``cfg`` from
        ``PREPROCESSORS``.
    """
    preprocessor = build_from_cfg(
        cfg, PREPROCESSORS, group_key=field_name, default_args=default_args)
    return preprocessor
| @@ -0,0 +1,54 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import time | |||
| from collections.abc import Sequence | |||
| from .builder import PREPROCESSORS, build_preprocessor | |||
@PREPROCESSORS.register_module()
class Compose(object):
    """Compose a data pipeline with a sequence of transforms.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
        field_name (str, optional): field used as registry group when
            building transforms given as config dicts.
        profiling (bool, optional): If set True, will profile and
            print preprocess time for each step.

    Raises:
        TypeError: if a transform is neither a dict nor a callable.
    """

    def __init__(self, transforms, field_name=None, profiling=False):
        assert isinstance(transforms, Sequence)
        self.profiling = profiling
        self.transforms = []
        self.field_name = field_name
        for transform in transforms:
            if isinstance(transform, dict):
                # Always build config dicts. They used to be built only when
                # `field_name` was None, so with a field name the raw dict
                # was kept and later failed when called as a transform.
                transform = build_preprocessor(transform, field_name)
            elif not callable(transform):
                raise TypeError('transform must be callable or a dict, but got'
                                f' {type(transform)}')
            self.transforms.append(transform)

    def __call__(self, data):
        """Apply the transforms in order.

        Returns:
            The output of the last transform, or None as soon as one
            transform returns None (remaining transforms are skipped).
        """
        for t in self.transforms:
            if self.profiling:
                start = time.time()

            data = t(data)

            if self.profiling:
                print(f'{t} time {time.time()-start}')

            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += f'\n {t}'
        format_string += '\n)'
        return format_string
| @@ -0,0 +1,70 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import io | |||
| from typing import Dict, Union | |||
| from PIL import Image, ImageOps | |||
| from maas_lib.fileio import File | |||
| from maas_lib.utils.constant import Fields | |||
| from .builder import PREPROCESSORS | |||
@PREPROCESSORS.register_module(Fields.image)
class LoadImage:
    """Load an image from file or url.

    The returned dict holds the keys "filename", "img", "img_shape" and
    "img_field".

    Args:
        mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`.
            It is upper-cased before use.
    """

    def __init__(self, mode='rgb'):
        self.mode = mode.upper()

    def __call__(self, input: Union[str, Dict[str, str]]):
        """Call functions to load image and get image meta information.

        Args:
            input (str or dict): input image path or input dict with
                a key `filename`.

        Returns:
            dict: The dict contains loaded image: 'filename' (the path or
            url), 'img' (a PIL image converted to ``self.mode``),
            'img_shape' (height, width, channels) and 'img_field'.
        """
        if isinstance(input, dict):
            image_path_or_url = input['filename']
        else:
            image_path_or_url = input

        # named `raw` so that the builtin `bytes` is not shadowed
        raw = File.read(image_path_or_url)
        # TODO @wenmeng.zwm add opencv decode as optional
        # we should also look at the input format which is the most commonly
        # used in Mind' image related models
        with io.BytesIO(raw) as infile:
            img = Image.open(infile)
            # apply the EXIF orientation, if any, before converting
            img = ImageOps.exif_transpose(img)
            img = img.convert(self.mode)

        # The channel count follows the actual mode (e.g. 1 for 'L', 4 for
        # 'RGBA') instead of being hard-coded to 3.
        num_channels = len(img.getbands())
        results = {
            'filename': image_path_or_url,
            'img': img,
            'img_shape': (img.size[1], img.size[0], num_channels),
            'img_field': 'img',
        }
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}(' f'mode={self.mode})')
        return repr_str
def load_image(image_path_or_url: str) -> Image.Image:
    """ simple interface to load an image from file or url

    Args:
        image_path_or_url (str): image file path or http url

    Returns:
        Image.Image: the value stored under 'img' by a default
        ``LoadImage()`` loader.
    """
    # The return annotation used to name the `PIL.Image` module rather than
    # the `Image.Image` class.
    return LoadImage()(image_path_or_url)['img']
| @@ -0,0 +1,91 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| from transformers import AutoTokenizer | |||
| from maas_lib.utils.constant import Fields, InputFields | |||
| from maas_lib.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| __all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] | |||
@PREPROCESSORS.register_module(Fields.nlp)
class Tokenize(Preprocessor):
    """Tokenize the text field of a sample with a pretrained tokenizer."""

    def __init__(self, tokenizer_name) -> None:
        """Load the tokenizer identified by ``tokenizer_name``."""
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        """Add the tokenizer output to the sample.

        Args:
            data: a raw string, or a dict holding the text under
                ``InputFields.text``. A dict is updated in place.

        Returns:
            Dict[str, Any]: the sample dict updated with the tokenizer
            output.
        """
        sample = {InputFields.text: data} if isinstance(data, str) else data
        sample.update(self._tokenizer(sample[InputFields.text]))
        return sample
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):
    """Turn a sentence into tokenizer features for sequence classification."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Load the tokenizer found under `model_dir`.

        Args:
            model_dir (str): model path
            **kwargs: may hold 'first_sequence', 'second_sequence' (key
                names, defaulting to themselves) and 'sequence_length'
                (defaults to 128).
        """
        super().__init__(*args, **kwargs)

        # lazy import keeps easynlp an optional dependency
        from easynlp.modelzoo import AutoTokenizer

        self.model_dir: str = model_dir
        self.first_sequence: str = kwargs.pop('first_sequence',
                                              'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.sequence_length = kwargs.pop('sequence_length', 128)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Tokenize one raw sentence.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data, with one-element lists
            under 'id', 'input_ids', 'attention_mask' and 'token_type_ids'.
        """
        sample = {self.first_sequence: data}

        # tokenize, padding/truncating to the configured length
        encoded = self.tokenizer(
            sample[self.first_sequence],
            sample.get(self.second_sequence, None),
            padding='max_length',
            truncation=True,
            max_length=self.sequence_length)

        # every field is wrapped in a list holding this single sample
        return {
            'id': [sample.get('id', str(uuid.uuid4()))],
            'input_ids': [encoded['input_ids']],
            'attention_mask': [encoded['attention_mask']],
            'token_type_ids': [encoded['token_type_ids']]
        }
| @@ -6,6 +6,7 @@ class Fields(object): | |||
| """ | |||
| image = 'image' | |||
| video = 'video' | |||
| cv = 'cv' | |||
| nlp = 'nlp' | |||
| audio = 'audio' | |||
| multi_modal = 'multi_modal' | |||
| @@ -18,12 +19,41 @@ class Tasks(object): | |||
| This should be used to register models, pipelines, trainers. | |||
| """ | |||
| # vision tasks | |||
| image_to_text = 'image-to-text' | |||
| pose_estimation = 'pose-estimation' | |||
| image_classfication = 'image-classification' | |||
| image_tagging = 'image-tagging' | |||
| object_detection = 'object-detection' | |||
| image_segmentation = 'image-segmentation' | |||
| image_editing = 'image-editing' | |||
| image_generation = 'image-generation' | |||
| image_matting = 'image-matting' | |||
| # nlp tasks | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| fill_mask = 'fill-mask' | |||
| text_classification = 'text-classification' | |||
| relation_extraction = 'relation-extraction' | |||
| zero_shot = 'zero-shot' | |||
| translation = 'translation' | |||
| token_classificatio = 'token-classification' | |||
| conversational = 'conversational' | |||
| text_generation = 'text-generation' | |||
| table_question_answ = 'table-question-answering' | |||
| feature_extraction = 'feature-extraction' | |||
| sentence_similarity = 'sentence-similarity' | |||
| fill_mask = 'fill-mask ' | |||
| summarization = 'summarization' | |||
| question_answering = 'question-answering' | |||
| # audio tasks | |||
| auto_speech_recognition = 'auto-speech-recognition' | |||
| text_to_speech = 'text-to-speech' | |||
| speech_signal_process = 'speech-signal-process' | |||
| # multi-media | |||
| image_captioning = 'image-captioning' | |||
| visual_grounding = 'visual-grounding' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| class InputFields(object): | |||
| @@ -1,5 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import inspect | |||
| from email.policy import default | |||
| from maas_lib.utils.logger import get_logger | |||
| @@ -15,10 +17,10 @@ class Registry(object): | |||
| def __init__(self, name: str): | |||
| self._name = name | |||
| self._modules = dict() | |||
| self._modules = {default_group: {}} | |||
| def __repr__(self): | |||
| format_str = self.__class__.__name__ + f'({self._name})\n' | |||
| format_str = self.__class__.__name__ + f' ({self._name})\n' | |||
| for group_name, group in self._modules.items(): | |||
| format_str += f'group_name={group_name}, '\ | |||
| f'modules={list(group.keys())}\n' | |||
| @@ -64,11 +66,24 @@ class Registry(object): | |||
| module_name = module_cls.__name__ | |||
| if module_name in self._modules[group_key]: | |||
| raise KeyError(f'{module_name} is already registered in' | |||
| raise KeyError(f'{module_name} is already registered in ' | |||
| f'{self._name}[{group_key}]') | |||
| self._modules[group_key][module_name] = module_cls | |||
| if module_name in self._modules[default_group]: | |||
| if id(self._modules[default_group][module_name]) == id(module_cls): | |||
| return | |||
| else: | |||
| logger.warning(f'{module_name} is already registered in ' | |||
| f'{self._name}[{default_group}] and will ' | |||
| 'be overwritten') | |||
| logger.warning(f'{self._modules[default_group][module_name]}' | |||
| 'to {module_cls}') | |||
| # also register module in the default group for faster access | |||
| # only by module name | |||
| self._modules[default_group][module_name] = module_cls | |||
| def register_module(self, | |||
| group_key: str = default_group, | |||
| module_name: str = None, | |||
| @@ -165,12 +180,15 @@ def build_from_cfg(cfg, | |||
| for name, value in default_args.items(): | |||
| args.setdefault(name, value) | |||
| if group_key is None: | |||
| group_key = default_group | |||
| obj_type = args.pop('type') | |||
| if isinstance(obj_type, str): | |||
| obj_cls = registry.get(obj_type, group_key=group_key) | |||
| if obj_cls is None: | |||
| raise KeyError(f'{obj_type} is not in the {registry.name}' | |||
| f'registry group {group_key}') | |||
| f' registry group {group_key}') | |||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | |||
| obj_cls = obj_type | |||
| else: | |||
| @@ -0,0 +1,50 @@ | |||
| from functools import wraps | |||
| from inspect import signature | |||
def type_assert(*ty_args, **ty_kwargs):
    """Decorator checking argument types of a callable at call time.

    The expected types are matched to the parameters of the decorated
    callable by position or by name; parameters without an expected type are
    not checked. Each expected type is handed to ``isinstance``, so a tuple
    of types is allowed. Under ``python -O`` the callable is returned
    undecorated.

    Raises (when the decorated callable is invoked):
        TypeError: if a supplied argument is not an instance of its expected
            type, or if the call does not match the signature.

    Examples:

        >>> @type_assert(str)
        ... def main(a: str, b: list):
        ...     print(a, b)
        >>> main(1, [])
        Traceback (most recent call last):
            ...
        TypeError: Argument a must be <class 'str'>

        >>> @type_assert(str, (int, str))
        ... def main(a: str, b):
        ...     print(a, b)
        >>> main('1', [1])
        Traceback (most recent call last):
            ...
        TypeError: Argument b must be (<class 'int'>, <class 'str'>)
    """

    def decorate(func):
        # Type checking is disabled in optimized mode.
        if not __debug__:
            return func

        # Parameter name -> expected type, resolved once at decoration time.
        sig = signature(func)
        expected = sig.bind_partial(*ty_args, **ty_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            supplied = sig.bind(*args, **kwargs).arguments
            for name, value in supplied.items():
                if name not in expected:
                    continue
                if isinstance(value, expected[name]):
                    continue
                raise TypeError('Argument {} must be {}'.format(
                    name, expected[name]))
            return func(*args, **kwargs)

        return wrapper

    return decorate
| @@ -1 +1,2 @@ | |||
| -r requirements/runtime.txt | |||
| -r requirements/pipeline.txt | |||
| @@ -0,0 +1,5 @@ | |||
| http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl | |||
| tensorflow | |||
| torch==1.9.1 | |||
| torchaudio==0.9.1 | |||
| torchvision==0.10.1 | |||
| @@ -1,5 +1,8 @@ | |||
| addict | |||
| numpy | |||
| opencv-python-headless | |||
| Pillow | |||
| pyyaml | |||
| requests | |||
| transformers | |||
| yapf | |||
| @@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||
| [flake8] | |||
| select = B,C,E,F,P,T4,W,B9 | |||
| max-line-length = 120 | |||
| ignore = F401 | |||
| ignore = F401,F821 | |||
| exclude = docs/src,*.pyi,.git | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import numpy as np | |||
| import PIL | |||
| from maas_lib.pipelines import Pipeline, pipeline | |||
| from maas_lib.pipelines.builder import PIPELINES | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.logger import get_logger | |||
| from maas_lib.utils.registry import default_group | |||
| logger = get_logger() | |||
| Input = Union[str, 'PIL.Image', 'numpy.ndarray'] | |||
class CustomPipelineTest(unittest.TestCase):
    """Checks that user-defined pipelines can be registered and built."""

    def test_abstract(self):
        """Instantiating a pipeline that only defines ``__init__`` must fail.

        The TypeError is presumably raised because ``Pipeline`` declares
        abstract methods that ``CustomPipeline1`` does not implement.
        """

        @PIPELINES.register_module()
        class CustomPipeline1(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

        with self.assertRaises(TypeError):
            CustomPipeline1()

    def test_custom(self):
        """Register a custom image pipeline and run it on one and on many inputs."""

        @PIPELINES.register_module(
            group_key=Tasks.image_tagging, module_name='custom-image')
        class CustomImagePipeline(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

            def preprocess(self, input: Union[str,
                                              'PIL.Image']) -> Dict[str, Any]:
                """Wrap the input in a dict, loading it first if it is not a PIL image.

                Args:
                    input: an already opened PIL image, or anything accepted
                        by ``load_image`` (the test passes a URL string).

                Returns:
                    ``{'img': image}``, plus ``'url'`` holding the original
                    input when the image had to be loaded.
                """
                # Import the submodule explicitly: a bare ``import PIL`` does
                # not guarantee that ``PIL.Image`` is reachable as an attribute.
                from PIL import Image

                if isinstance(input, Image.Image):
                    return {'img': input}

                from maas_lib.preprocessors import load_image
                return {'img': load_image(input), 'url': input}

            def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Produce a half-size copy of the image plus a dummy marker.

                Args:
                    inputs: the dict built by ``preprocess``.

                Returns:
                    A dict with ``'resize_image'`` (numpy array of the image
                    resized to half width and height), ``'dummy_result'`` and,
                    when the input carried a ``'url'``, ``'filename'``.
                """
                outputs = {}
                if 'url' in inputs:
                    outputs['filename'] = inputs['url']
                img = inputs['img']
                new_image = img.resize((img.width // 2, img.height // 2))
                outputs['resize_image'] = np.array(new_image)
                outputs['dummy_result'] = 'dummy_result'
                return outputs

            def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Return the forward outputs unchanged."""
                return inputs

        # assertIn / assertIs report the offending values on failure, unlike
        # assertTrue on a pre-evaluated boolean.
        self.assertIn('custom-image', PIPELINES.modules[default_group])
        pipe = pipeline(pipeline_name='custom-image')
        pipe2 = pipeline(Tasks.image_tagging)
        self.assertIs(type(pipe), type(pipe2))

        img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
                  'aliyuncs.com/data/test/images/image1.jpg'

        # Single input -> single result dict.
        output = pipe(img_url)
        self.assertEqual(output['filename'], img_url)
        self.assertEqual(output['resize_image'].shape, (318, 512, 3))
        self.assertEqual(output['dummy_result'], 'dummy_result')

        # List input -> one result dict per element.
        outputs = pipe([img_url] * 4)
        self.assertEqual(len(outputs), 4)
        for out in outputs:
            self.assertEqual(out['filename'], img_url)
            self.assertEqual(out['resize_image'].shape, (318, 512, 3))
            self.assertEqual(out['dummy_result'], 'dummy_result')
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import tempfile | |||
| import unittest | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from maas_lib.fileio import File | |||
| from maas_lib.pipelines import pipeline | |||
| from maas_lib.utils.constant import Tasks | |||
class ImageMattingTest(unittest.TestCase):
    """Smoke test for the image matting pipeline (downloads model and image)."""

    def test_run(self):
        """Fetch the matting model, run it on a sample image and save the result.

        The output is written to ``result.png`` in the current working
        directory; nothing is asserted about its content.
        """
        model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                     '.com/data/test/maas/image_matting/matting_person.pb'
        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
            ofile.write(File.read(model_path))
            # The pipeline only receives the file name, so the buffered bytes
            # must reach the disk before it opens the file; without the flush
            # it may read a truncated model.
            ofile.flush()
            # Build and run the pipeline while the temp file still exists
            # (it is deleted as soon as the ``with`` block exits).
            img_matting = pipeline(Tasks.image_matting, model_path=ofile.name)

            result = img_matting(
                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
            )
            cv2.imwrite('result.png', result['output_png'])
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from maas_lib.fileio import File | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.pipelines import SequenceClassificationPipeline | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
class SequenceClassificationTest(unittest.TestCase):
    """Smoke test for the sequence classification pipeline."""

    def predict(self, pipeline: SequenceClassificationPipeline):
        """Run ``pipeline`` on the first two SST-2 test sentences and print the results."""
        from easynlp.appzoo import load_dataset

        sst2 = load_dataset('glue', 'sst2')
        sentences = sst2['test']['sentence'][:3]
        for idx in (0, 1):
            results = pipeline(sentences[idx])
            print(results)
        print(sentences)

    def test_run(self):
        """Download and unpack the model archive, then run prediction with it."""
        model_url = ('https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com'
                     '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip')
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Fetch the zipped model into the scratch directory.
            archive = osp.join(tmp_dir, 'bert-base-sst2.zip')
            with open(archive, 'wb') as ofile:
                ofile.write(File.read(model_url))

            # Unpack it next to the archive.
            with zipfile.ZipFile(archive, 'r') as zipf:
                zipf.extractall(tmp_dir)

            model_dir = osp.join(tmp_dir, 'bert-base-sst2')
            print(model_dir)

            # Everything below must run before the scratch directory is removed.
            model = SequenceClassificationModel(model_dir)
            preprocessor = SequenceClassificationPreprocessor(
                model_dir, first_sequence='sentence', second_sequence=None)
            self.predict(SequenceClassificationPipeline(model, preprocessor))
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor | |||
class ComposeTest(unittest.TestCase):
    """Exercises ``Compose`` with two preprocessors registered on the fly."""

    def test_compose(self):
        """Both composed steps should leave their marker in the output."""

        @PREPROCESSORS.register_module()
        class Tmp1(Preprocessor):
            """Adds the ``'tmp1'`` marker to the data it receives."""

            def __call__(self, input):
                input['tmp1'] = 'tmp1'
                return input

        @PREPROCESSORS.register_module()
        class Tmp2(Preprocessor):
            """Adds the ``'tmp2'`` marker to the data it receives."""

            def __call__(self, input):
                input['tmp2'] = 'tmp2'
                return input

        # Steps are referenced by their registered class names.
        steps = [{'type': 'Tmp1'}, {'type': 'Tmp2'}]
        composed = Compose(steps)

        output = composed({})

        for marker in ('tmp1', 'tmp2'):
            self.assertEqual(output[marker], marker)
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from maas_lib.preprocessors import build_preprocessor | |||
| from maas_lib.utils.constant import Fields, InputFields | |||
| from maas_lib.utils.logger import get_logger | |||
| logger = get_logger() | |||
class NLPPreprocessorTest(unittest.TestCase):
    """Checks the ``Tokenize`` preprocessor built from a config dict."""

    def test_tokenize(self):
        """Tokenizing a fixed sentence yields the expected ids and masks."""
        preprocessor = build_preprocessor(
            {'type': 'Tokenize', 'tokenizer_name': 'bert-base-cased'},
            Fields.nlp)
        sample = {
            InputFields.text:
            'Do not meddle in the affairs of wizards, '
            'for they are subtle and quick to anger.'
        }
        expected_ids = [
            101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116,
            117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102
        ]
        token_count = len(expected_ids)

        output = preprocessor(sample)

        # The original text must still be present in the output.
        self.assertTrue(InputFields.text in output)
        self.assertEqual(output['input_ids'], expected_ids)
        # Single segment, no padding: all type ids are 0 and the mask is all 1.
        self.assertEqual(output['token_type_ids'], [0] * token_count)
        self.assertEqual(output['attention_mask'], [1] * token_count)
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase): | |||
| def test_register_class_no_task(self): | |||
| MODELS = Registry('models') | |||
| self.assertTrue(MODELS.name == 'models') | |||
| self.assertTrue(MODELS.modules == {}) | |||
| self.assertEqual(len(MODELS.modules), 0) | |||
| self.assertTrue(default_group in MODELS.modules) | |||
| self.assertTrue(MODELS.modules[default_group] == {}) | |||
| self.assertEqual(len(MODELS.modules), 1) | |||
| @MODELS.register_module(module_name='cls-resnet') | |||
| class ResNetForCls(object): | |||
| @@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase): | |||
| self.assertTrue(Tasks.object_detection in MODELS.modules) | |||
| self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | |||
| self.assertEqual(len(MODELS.modules), 3) | |||
| self.assertEqual(len(MODELS.modules), 4) | |||
| def test_list(self): | |||
| MODELS = Registry('models') | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from typing import List, Union | |||
| from maas_lib.utils.type_assert import type_assert | |||
class type_assertTest(unittest.TestCase):
    """Tests for the ``type_assert`` decorator."""

    @type_assert(object, list, (int, str))
    def a(self, a: List[int], b: Union[int, str]):
        """Decorated helper: ``a`` must be a list, ``b`` an int or a str."""
        print(a, b)

    def test_type_assert(self):
        """Valid arguments pass through; each mistyped argument raises TypeError."""
        # Valid calls stay outside assertRaises: if they raised TypeError
        # inside the context manager, the test would wrongly pass.
        self.a([1], 2)
        self.a([1], '2')

        # Both arguments have the wrong type.
        with self.assertRaises(TypeError):
            self.a(1, [123])
        # Only the first argument is wrong.
        with self.assertRaises(TypeError):
            self.a(1, 2)
        # Only the second argument is wrong.
        with self.assertRaises(TypeError):
            self.a([1], [123])
| if __name__ == '__main__': | |||
| unittest.main() | |||