* add preprocessor module * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline * fix UT failed * add test for Compose Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8769235 * add preprocessor module * add test for Compose * fix citest error * fix abs class error * add model base and builder * update task constant * add load image preprocessor and its dependency * add pipeline interface and UT covered * support default pipeline for task * refine models and pipeline interface * add pipeline folder structure * add image matting pipeline * refine nlp tokenize interface * add nlp pipeline 1.add preprossor model pipeline for nlp text classification 2. add corresponding test Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8757371 * new nlp pipeline * format pre-commit code * update easynlp pipeline * update model_name for easynlp pipeline; add test for maas_lib/utils/typeassert.py * update test_typeassert.py * refactor code 1. rename typeassert to type_assert 2. use lazy import to make easynlp dependency optional 3. refine image matting UT * fix linter test failed * update requirements.txt * fix UT failed * fix citest script to update requirementsmaster
| @@ -1,4 +1,4 @@ | |||||
| pip install -r requirements/runtime.txt | |||||
| pip install -r requirements.txt | |||||
| pip install -r requirements/tests.txt | pip install -r requirements/tests.txt | ||||
| @@ -1 +1,2 @@ | |||||
| from .file import File | |||||
| from .io import dump, dumps, load | from .io import dump, dumps, load | ||||
| @@ -123,6 +123,7 @@ class HTTPStorage(Storage): | |||||
| """HTTP and HTTPS storage.""" | """HTTP and HTTPS storage.""" | ||||
| def read(self, url): | def read(self, url): | ||||
| # TODO @wenmeng.zwm add progress bar if file is too large | |||||
| r = requests.get(url) | r = requests.get(url) | ||||
| r.raise_for_status() | r.raise_for_status() | ||||
| return r.content | return r.content | ||||
| @@ -0,0 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .base import Model | |||||
| from .builder import MODELS | |||||
| @@ -0,0 +1,29 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Dict, List, Tuple, Union | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||||
| class Model(ABC): | |||||
| def __init__(self, *args, **kwargs): | |||||
| pass | |||||
| def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| return self.post_process(self.forward(input)) | |||||
| @abstractmethod | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| pass | |||||
| def post_process(self, input: Dict[str, Tensor], | |||||
| **kwargs) -> Dict[str, Tensor]: | |||||
| # model specific postprocess, implementation is optional | |||||
| # will be called in Pipeline and evaluation loop(in the future) | |||||
| return input | |||||
| @classmethod | |||||
| def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): | |||||
| raise NotImplementedError('from_preatrained has not been implemented') | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from maas_lib.utils.config import ConfigDict | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||||
| MODELS = Registry('models') | |||||
| def build_model(cfg: ConfigDict, | |||||
| task_name: str = None, | |||||
| default_args: dict = None): | |||||
| """ build model given model config dict | |||||
| Args: | |||||
| cfg (:obj:`ConfigDict`): config dict for model object. | |||||
| task_name (str, optional): task name, refer to | |||||
| :obj:`Tasks` for more details | |||||
| default_args (dict, optional): Default initialization arguments. | |||||
| """ | |||||
| return build_from_cfg( | |||||
| cfg, MODELS, group_key=task_name, default_args=default_args) | |||||
| @@ -0,0 +1 @@ | |||||
| from .sequence_classification_model import * # noqa F403 | |||||
| @@ -0,0 +1,62 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import numpy as np | |||||
| import torch | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from ..base import Model | |||||
| from ..builder import MODELS | |||||
| __all__ = ['SequenceClassificationModel'] | |||||
| @MODELS.register_module( | |||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||||
| class SequenceClassificationModel(Model): | |||||
| def __init__(self, | |||||
| model_dir: str, | |||||
| model_cls: Optional[Any] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | |||||
| # Predictor.__init__(self, *args, **kwargs) | |||||
| """initilize the sequence classification model from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): the model path | |||||
| model_cls (Optional[Any], optional): model loader, if None, use the | |||||
| default loader to load model weights, by default None | |||||
| """ | |||||
| super().__init__(model_dir, model_cls, *args, **kwargs) | |||||
| from easynlp.appzoo import SequenceClassification | |||||
| from easynlp.core.predictor import get_model_predictor | |||||
| self.model_dir = model_dir | |||||
| model_cls = SequenceClassification if not model_cls else model_cls | |||||
| self.model = get_model_predictor( | |||||
| model_dir=model_dir, | |||||
| model_cls=model_cls, | |||||
| input_keys=[('input_ids', torch.LongTensor), | |||||
| ('attention_mask', torch.LongTensor), | |||||
| ('token_type_ids', torch.LongTensor)], | |||||
| output_keys=['predictions', 'probabilities', 'logits']) | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Any]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, np.ndarray]: results | |||||
| Example: | |||||
| { | |||||
| 'predictions': array([1]), # lable 0-negative 1-positive | |||||
| 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||||
| } | |||||
| """ | |||||
| return self.model.predict(input) | |||||
| ... | |||||
| @@ -0,0 +1,6 @@ | |||||
| from .audio import * # noqa F403 | |||||
| from .base import Pipeline | |||||
| from .builder import pipeline | |||||
| from .cv import * # noqa F403 | |||||
| from .multi_modal import * # noqa F403 | |||||
| from .nlp import * # noqa F403 | |||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| from maas_lib.models import Model | |||||
| from maas_lib.preprocessors import Preprocessor | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||||
| Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray'] | |||||
| output_keys = [ | |||||
| ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | |||||
| class Pipeline(ABC): | |||||
| def __init__(self, | |||||
| config_file: str = None, | |||||
| model: Model = None, | |||||
| preprocessor: Preprocessor = None, | |||||
| **kwargs): | |||||
| self.model = model | |||||
| self.preprocessor = preprocessor | |||||
| def __call__(self, input: Union[Input, List[Input]], *args, | |||||
| **post_kwargs) -> Dict[str, Any]: | |||||
| # moodel provider should leave it as it is | |||||
| # maas library developer will handle this function | |||||
| # simple show case, need to support iterator type for both tensorflow and pytorch | |||||
| # input_dict = self._handle_input(input) | |||||
| if isinstance(input, list): | |||||
| output = [] | |||||
| for ele in input: | |||||
| output.append(self._process_single(ele, *args, **post_kwargs)) | |||||
| else: | |||||
| output = self._process_single(input, *args, **post_kwargs) | |||||
| return output | |||||
| def _process_single(self, input: Input, *args, | |||||
| **post_kwargs) -> Dict[str, Any]: | |||||
| out = self.preprocess(input) | |||||
| out = self.forward(out) | |||||
| out = self.postprocess(out, **post_kwargs) | |||||
| return out | |||||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | |||||
| """ | |||||
| assert self.preprocessor is not None, 'preprocess method should be implemented' | |||||
| return self.preprocessor(inputs) | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """ Provide default implementation using self.model and user can reimplement it | |||||
| """ | |||||
| assert self.model is not None, 'forward method should be implemented' | |||||
| return self.model(inputs) | |||||
| @abstractmethod | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| raise NotImplementedError('postprocess') | |||||
| @@ -0,0 +1,65 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Union | |||||
| from maas_lib.models.base import Model | |||||
| from maas_lib.utils.config import ConfigDict | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||||
| from .base import Pipeline | |||||
| PIPELINES = Registry('pipelines') | |||||
| def build_pipeline(cfg: ConfigDict, | |||||
| task_name: str = None, | |||||
| default_args: dict = None): | |||||
| """ build pipeline given model config dict | |||||
| Args: | |||||
| cfg (:obj:`ConfigDict`): config dict for model object. | |||||
| task_name (str, optional): task name, refer to | |||||
| :obj:`Tasks` for more details | |||||
| default_args (dict, optional): Default initialization arguments. | |||||
| """ | |||||
| return build_from_cfg( | |||||
| cfg, PIPELINES, group_key=task_name, default_args=default_args) | |||||
| def pipeline(task: str = None, | |||||
| model: Union[str, Model] = None, | |||||
| config_file: str = None, | |||||
| pipeline_name: str = None, | |||||
| framework: str = None, | |||||
| device: int = -1, | |||||
| **kwargs) -> Pipeline: | |||||
| """ Factory method to build a obj:`Pipeline`. | |||||
| Args: | |||||
| task (str): Task name defining which pipeline will be returned. | |||||
| model (str or obj:`Model`): model name or model object. | |||||
| config_file (str, optional): path to config file. | |||||
| pipeline_name (str, optional): pipeline class name or alias name. | |||||
| framework (str, optional): framework type. | |||||
| device (int, optional): which device is used to do inference. | |||||
| Return: | |||||
| pipeline (obj:`Pipeline`): pipeline object for certain task. | |||||
| Examples: | |||||
| ```python | |||||
| >>> p = pipeline('image-classification') | |||||
| >>> p = pipeline('text-classification', model='distilbert-base-uncased') | |||||
| >>> # Using model object | |||||
| >>> resnet = Model.from_pretrained('Resnet') | |||||
| >>> p = pipeline('image-classification', model=resnet) | |||||
| """ | |||||
| if task is not None and model is None and pipeline_name is None: | |||||
| # get default pipeline for this task | |||||
| assert task in PIPELINES.modules, f'No pipeline is registerd for Task {task}' | |||||
| pipeline_name = list(PIPELINES.modules[task].keys())[0] | |||||
| if pipeline_name is not None: | |||||
| cfg = dict(type=pipeline_name, **kwargs) | |||||
| return build_pipeline(cfg, task_name=task) | |||||
| @@ -0,0 +1 @@ | |||||
| from .image_matting import ImageMatting | |||||
| @@ -0,0 +1,67 @@ | |||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import tensorflow as tf | |||||
| from cv2 import COLOR_GRAY2RGB | |||||
| from maas_lib.pipelines.base import Input | |||||
| from maas_lib.preprocessors import load_image | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from maas_lib.utils.logger import get_logger | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | |||||
| if tf.__version__ >= '2.0': | |||||
| tf = tf.compat.v1 | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_matting, module_name=Tasks.image_matting) | |||||
| class ImageMatting(Pipeline): | |||||
| def __init__(self, model_path: str): | |||||
| super().__init__() | |||||
| config = tf.ConfigProto(allow_soft_placement=True) | |||||
| config.gpu_options.allow_growth = True | |||||
| self._session = tf.Session(config=config) | |||||
| with self._session.as_default(): | |||||
| logger.info(f'loading model from {model_path}') | |||||
| with tf.gfile.FastGFile(model_path, 'rb') as f: | |||||
| graph_def = tf.GraphDef() | |||||
| graph_def.ParseFromString(f.read()) | |||||
| tf.import_graph_def(graph_def, name='') | |||||
| self.output = self._session.graph.get_tensor_by_name( | |||||
| 'output_png:0') | |||||
| self.input_name = 'input_image:0' | |||||
| logger.info('load model done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| if isinstance(input, str): | |||||
| img = np.array(load_image(input)) | |||||
| elif isinstance(input, PIL.Image.Image): | |||||
| img = np.array(input.convert('RGB')) | |||||
| elif isinstance(input, np.ndarray): | |||||
| if len(input.shape) == 2: | |||||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||||
| img = input[:, :, ::-1] # in rgb order | |||||
| else: | |||||
| raise TypeError(f'input should be either str, PIL.Image,' | |||||
| f' np.array, but got {type(input)}') | |||||
| img = img.astype(np.float) | |||||
| result = {'img': img} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| with self._session.as_default(): | |||||
| feed_dict = {self.input_name: input['img']} | |||||
| output_png = self._session.run(self.output, feed_dict=feed_dict) | |||||
| output_png = cv2.cvtColor(output_png, cv2.COLOR_RGBA2BGRA) | |||||
| return {'output_png': output_png} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -0,0 +1 @@ | |||||
| from .sequence_classification_pipeline import * # noqa F403 | |||||
| @@ -0,0 +1,77 @@ | |||||
| import os | |||||
| import uuid | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| from maas_lib.models.nlp import SequenceClassificationModel | |||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from ..base import Input, Pipeline | |||||
| from ..builder import PIPELINES | |||||
| __all__ = ['SequenceClassificationPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||||
| class SequenceClassificationPipeline(Pipeline): | |||||
| def __init__(self, model: SequenceClassificationModel, | |||||
| preprocessor: SequenceClassificationPreprocessor, **kwargs): | |||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||||
| Args: | |||||
| model (SequenceClassificationModel): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||||
| """ | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| from easynlp.utils import io | |||||
| self.label_path = os.path.join(model.model_dir, 'label_mapping.json') | |||||
| with io.open(self.label_path) as f: | |||||
| self.label_mapping = json.load(f) | |||||
| self.label_id_to_name = { | |||||
| idx: name | |||||
| for name, idx in self.label_mapping.items() | |||||
| } | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||||
| """process the predict results | |||||
| Args: | |||||
| inputs (Dict[str, Any]): _description_ | |||||
| Returns: | |||||
| Dict[str, str]: the predict results | |||||
| """ | |||||
| probs = inputs['probabilities'] | |||||
| logits = inputs['logits'] | |||||
| predictions = np.argsort(-probs, axis=-1) | |||||
| preds = predictions[0] | |||||
| b = 0 | |||||
| new_result = list() | |||||
| for pred in preds: | |||||
| new_result.append({ | |||||
| 'pred': self.label_id_to_name[pred], | |||||
| 'prob': float(probs[b][pred]), | |||||
| 'logit': float(logits[b][pred]) | |||||
| }) | |||||
| new_results = list() | |||||
| new_results.append({ | |||||
| 'id': | |||||
| inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), | |||||
| 'output': | |||||
| new_result, | |||||
| 'predictions': | |||||
| new_result[0]['pred'], | |||||
| 'probabilities': | |||||
| ','.join([str(t) for t in inputs['probabilities'][b]]), | |||||
| 'logits': | |||||
| ','.join([str(t) for t in inputs['logits'][b]]) | |||||
| }) | |||||
| return new_results[0] | |||||
| @@ -0,0 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS, build_preprocessor | |||||
| from .common import Compose | |||||
| from .image import LoadImage, load_image | |||||
| from .nlp import * # noqa F403 | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Any, Dict | |||||
| class Preprocessor(ABC): | |||||
| def __init__(self, *args, **kwargs): | |||||
| pass | |||||
| @abstractmethod | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| pass | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from maas_lib.utils.config import ConfigDict | |||||
| from maas_lib.utils.constant import Fields | |||||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||||
| PREPROCESSORS = Registry('preprocessors') | |||||
| def build_preprocessor(cfg: ConfigDict, | |||||
| field_name: str = None, | |||||
| default_args: dict = None): | |||||
| """ build preprocesor given model config dict | |||||
| Args: | |||||
| cfg (:obj:`ConfigDict`): config dict for model object. | |||||
| field_name (str, optional): application field name, refer to | |||||
| :obj:`Fields` for more details | |||||
| default_args (dict, optional): Default initialization arguments. | |||||
| """ | |||||
| return build_from_cfg( | |||||
| cfg, PREPROCESSORS, group_key=field_name, default_args=default_args) | |||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import time | |||||
| from collections.abc import Sequence | |||||
| from .builder import PREPROCESSORS, build_preprocessor | |||||
| @PREPROCESSORS.register_module() | |||||
| class Compose(object): | |||||
| """Compose a data pipeline with a sequence of transforms. | |||||
| Args: | |||||
| transforms (list[dict | callable]): | |||||
| Either config dicts of transforms or transform objects. | |||||
| profiling (bool, optional): If set True, will profile and | |||||
| print preprocess time for each step. | |||||
| """ | |||||
| def __init__(self, transforms, field_name=None, profiling=False): | |||||
| assert isinstance(transforms, Sequence) | |||||
| self.profiling = profiling | |||||
| self.transforms = [] | |||||
| self.field_name = field_name | |||||
| for transform in transforms: | |||||
| if isinstance(transform, dict): | |||||
| if self.field_name is None: | |||||
| transform = build_preprocessor(transform, field_name) | |||||
| self.transforms.append(transform) | |||||
| elif callable(transform): | |||||
| self.transforms.append(transform) | |||||
| else: | |||||
| raise TypeError('transform must be callable or a dict, but got' | |||||
| f' {type(transform)}') | |||||
| def __call__(self, data): | |||||
| for t in self.transforms: | |||||
| if self.profiling: | |||||
| start = time.time() | |||||
| data = t(data) | |||||
| if self.profiling: | |||||
| print(f'{t} time {time.time()-start}') | |||||
| if data is None: | |||||
| return None | |||||
| return data | |||||
| def __repr__(self): | |||||
| format_string = self.__class__.__name__ + '(' | |||||
| for t in self.transforms: | |||||
| format_string += f'\n {t}' | |||||
| format_string += '\n)' | |||||
| return format_string | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import io | |||||
| from typing import Dict, Union | |||||
| from PIL import Image, ImageOps | |||||
| from maas_lib.fileio import File | |||||
| from maas_lib.utils.constant import Fields | |||||
| from .builder import PREPROCESSORS | |||||
| @PREPROCESSORS.register_module(Fields.image) | |||||
| class LoadImage: | |||||
| """Load an image from file or url. | |||||
| Added or updated keys are "filename", "img", "img_shape", | |||||
| "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), | |||||
| "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). | |||||
| Args: | |||||
| mode (str): See :ref:`PIL.Mode<https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes>`. | |||||
| to_float32 (bool): Whether to convert the loaded image to a float32 | |||||
| numpy array. If set to False, the loaded image is an uint8 array. | |||||
| Defaults to False. | |||||
| """ | |||||
| def __init__(self, mode='rgb'): | |||||
| self.mode = mode.upper() | |||||
| def __call__(self, input: Union[str, Dict[str, str]]): | |||||
| """Call functions to load image and get image meta information. | |||||
| Args: | |||||
| input (str or dict): input image path or input dict with | |||||
| a key `filename`. | |||||
| Returns: | |||||
| dict: The dict contains loaded image. | |||||
| """ | |||||
| if isinstance(input, dict): | |||||
| image_path_or_url = input['filename'] | |||||
| else: | |||||
| image_path_or_url = input | |||||
| bytes = File.read(image_path_or_url) | |||||
| # TODO @wenmeng.zwm add opencv decode as optional | |||||
| # we should also look at the input format which is the most commonly | |||||
| # used in Mind' image related models | |||||
| with io.BytesIO(bytes) as infile: | |||||
| img = Image.open(infile) | |||||
| img = ImageOps.exif_transpose(img) | |||||
| img = img.convert(self.mode) | |||||
| results = { | |||||
| 'filename': image_path_or_url, | |||||
| 'img': img, | |||||
| 'img_shape': (img.size[1], img.size[0], 3), | |||||
| 'img_field': 'img', | |||||
| } | |||||
| return results | |||||
| def __repr__(self): | |||||
| repr_str = (f'{self.__class__.__name__}(' f'mode={self.mode})') | |||||
| return repr_str | |||||
| def load_image(image_path_or_url: str) -> Image: | |||||
| """ simple interface to load an image from file or url | |||||
| Args: | |||||
| image_path_or_url (str): image file path or http url | |||||
| """ | |||||
| loader = LoadImage() | |||||
| return loader(image_path_or_url)['img'] | |||||
| @@ -0,0 +1,91 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import uuid | |||||
| from typing import Any, Dict, Union | |||||
| from transformers import AutoTokenizer | |||||
| from maas_lib.utils.constant import Fields, InputFields | |||||
| from maas_lib.utils.type_assert import type_assert | |||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | |||||
| __all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] | |||||
| @PREPROCESSORS.register_module(Fields.nlp) | |||||
| class Tokenize(Preprocessor): | |||||
| def __init__(self, tokenizer_name) -> None: | |||||
| self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) | |||||
| def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]: | |||||
| if isinstance(data, str): | |||||
| data = {InputFields.text: data} | |||||
| token_dict = self._tokenizer(data[InputFields.text]) | |||||
| data.update(token_dict) | |||||
| return data | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=r'bert-sentiment-analysis') | |||||
| class SequenceClassificationPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| from easynlp.modelzoo import AutoTokenizer | |||||
| self.model_dir: str = model_dir | |||||
| self.first_sequence: str = kwargs.pop('first_sequence', | |||||
| 'first_sequence') | |||||
| self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') | |||||
| self.sequence_length = kwargs.pop('sequence_length', 128) | |||||
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) | |||||
| @type_assert(object, str) | |||||
| def __call__(self, data: str) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (str): a sentence | |||||
| Example: | |||||
| 'you are so handsome.' | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| new_data = {self.first_sequence: data} | |||||
| # preprocess the data for the model input | |||||
| rst = { | |||||
| 'id': [], | |||||
| 'input_ids': [], | |||||
| 'attention_mask': [], | |||||
| 'token_type_ids': [] | |||||
| } | |||||
| max_seq_length = self.sequence_length | |||||
| text_a = new_data[self.first_sequence] | |||||
| text_b = new_data.get(self.second_sequence, None) | |||||
| feature = self.tokenizer( | |||||
| text_a, | |||||
| text_b, | |||||
| padding='max_length', | |||||
| truncation=True, | |||||
| max_length=max_seq_length) | |||||
| rst['id'].append(new_data.get('id', str(uuid.uuid4()))) | |||||
| rst['input_ids'].append(feature['input_ids']) | |||||
| rst['attention_mask'].append(feature['attention_mask']) | |||||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||||
| return rst | |||||
| @@ -6,6 +6,7 @@ class Fields(object): | |||||
| """ | """ | ||||
| image = 'image' | image = 'image' | ||||
| video = 'video' | video = 'video' | ||||
| cv = 'cv' | |||||
| nlp = 'nlp' | nlp = 'nlp' | ||||
| audio = 'audio' | audio = 'audio' | ||||
| multi_modal = 'multi_modal' | multi_modal = 'multi_modal' | ||||
| @@ -18,12 +19,41 @@ class Tasks(object): | |||||
| This should be used to register models, pipelines, trainers. | This should be used to register models, pipelines, trainers. | ||||
| """ | """ | ||||
| # vision tasks | # vision tasks | ||||
| image_to_text = 'image-to-text' | |||||
| pose_estimation = 'pose-estimation' | |||||
| image_classfication = 'image-classification' | image_classfication = 'image-classification' | ||||
| image_tagging = 'image-tagging' | |||||
| object_detection = 'object-detection' | object_detection = 'object-detection' | ||||
| image_segmentation = 'image-segmentation' | |||||
| image_editing = 'image-editing' | |||||
| image_generation = 'image-generation' | |||||
| image_matting = 'image-matting' | |||||
| # nlp tasks | # nlp tasks | ||||
| sentiment_analysis = 'sentiment-analysis' | sentiment_analysis = 'sentiment-analysis' | ||||
| fill_mask = 'fill-mask' | |||||
| text_classification = 'text-classification' | |||||
| relation_extraction = 'relation-extraction' | |||||
| zero_shot = 'zero-shot' | |||||
| translation = 'translation' | |||||
| token_classificatio = 'token-classification' | |||||
| conversational = 'conversational' | |||||
| text_generation = 'text-generation' | |||||
| table_question_answ = 'table-question-answering' | |||||
| feature_extraction = 'feature-extraction' | |||||
| sentence_similarity = 'sentence-similarity' | |||||
| fill_mask = 'fill-mask ' | |||||
| summarization = 'summarization' | |||||
| question_answering = 'question-answering' | |||||
| # audio tasks | |||||
| auto_speech_recognition = 'auto-speech-recognition' | |||||
| text_to_speech = 'text-to-speech' | |||||
| speech_signal_process = 'speech-signal-process' | |||||
| # multi-media | |||||
| image_captioning = 'image-captioning' | |||||
| visual_grounding = 'visual-grounding' | |||||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||||
| class InputFields(object): | class InputFields(object): | ||||
| @@ -1,5 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import inspect | import inspect | ||||
| from email.policy import default | |||||
| from maas_lib.utils.logger import get_logger | from maas_lib.utils.logger import get_logger | ||||
| @@ -15,10 +17,10 @@ class Registry(object): | |||||
| def __init__(self, name: str): | def __init__(self, name: str): | ||||
| self._name = name | self._name = name | ||||
| self._modules = dict() | |||||
| self._modules = {default_group: {}} | |||||
| def __repr__(self): | def __repr__(self): | ||||
| format_str = self.__class__.__name__ + f'({self._name})\n' | |||||
| format_str = self.__class__.__name__ + f' ({self._name})\n' | |||||
| for group_name, group in self._modules.items(): | for group_name, group in self._modules.items(): | ||||
| format_str += f'group_name={group_name}, '\ | format_str += f'group_name={group_name}, '\ | ||||
| f'modules={list(group.keys())}\n' | f'modules={list(group.keys())}\n' | ||||
| @@ -64,11 +66,24 @@ class Registry(object): | |||||
| module_name = module_cls.__name__ | module_name = module_cls.__name__ | ||||
| if module_name in self._modules[group_key]: | if module_name in self._modules[group_key]: | ||||
| raise KeyError(f'{module_name} is already registered in' | |||||
| raise KeyError(f'{module_name} is already registered in ' | |||||
| f'{self._name}[{group_key}]') | f'{self._name}[{group_key}]') | ||||
| self._modules[group_key][module_name] = module_cls | self._modules[group_key][module_name] = module_cls | ||||
| if module_name in self._modules[default_group]: | |||||
| if id(self._modules[default_group][module_name]) == id(module_cls): | |||||
| return | |||||
| else: | |||||
| logger.warning(f'{module_name} is already registered in ' | |||||
| f'{self._name}[{default_group}] and will ' | |||||
| 'be overwritten') | |||||
| logger.warning(f'{self._modules[default_group][module_name]}' | |||||
| 'to {module_cls}') | |||||
| # also register module in the default group for faster access | |||||
| # only by module name | |||||
| self._modules[default_group][module_name] = module_cls | |||||
| def register_module(self, | def register_module(self, | ||||
| group_key: str = default_group, | group_key: str = default_group, | ||||
| module_name: str = None, | module_name: str = None, | ||||
| @@ -165,12 +180,15 @@ def build_from_cfg(cfg, | |||||
| for name, value in default_args.items(): | for name, value in default_args.items(): | ||||
| args.setdefault(name, value) | args.setdefault(name, value) | ||||
| if group_key is None: | |||||
| group_key = default_group | |||||
| obj_type = args.pop('type') | obj_type = args.pop('type') | ||||
| if isinstance(obj_type, str): | if isinstance(obj_type, str): | ||||
| obj_cls = registry.get(obj_type, group_key=group_key) | obj_cls = registry.get(obj_type, group_key=group_key) | ||||
| if obj_cls is None: | if obj_cls is None: | ||||
| raise KeyError(f'{obj_type} is not in the {registry.name}' | raise KeyError(f'{obj_type} is not in the {registry.name}' | ||||
| f'registry group {group_key}') | |||||
| f' registry group {group_key}') | |||||
| elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): | ||||
| obj_cls = obj_type | obj_cls = obj_type | ||||
| else: | else: | ||||
| @@ -0,0 +1,50 @@ | |||||
| from functools import wraps | |||||
| from inspect import signature | |||||
| def type_assert(*ty_args, **ty_kwargs): | |||||
| """a decorator which is used to check the types of arguments in a function or class | |||||
| Examples: | |||||
| >>> @type_assert(str) | |||||
| ... def main(a: str, b: list): | |||||
| ... print(a, b) | |||||
| >>> main(1) | |||||
| Argument a must be a str | |||||
| >>> @type_assert(str, (int, str)) | |||||
| ... def main(a: str, b: int | str): | |||||
| ... print(a, b) | |||||
| >>> main('1', [1]) | |||||
| Argument b must be (<class 'int'>, <class 'str'>) | |||||
| >>> @type_assert(str, (int, str)) | |||||
| ... class A: | |||||
| ... def __init__(self, a: str, b: int | str) | |||||
| ... print(a, b) | |||||
| >>> a = A('1', [1]) | |||||
| Argument b must be (<class 'int'>, <class 'str'>) | |||||
| """ | |||||
| def decorate(func): | |||||
| # If in optimized mode, disable type checking | |||||
| if not __debug__: | |||||
| return func | |||||
| # Map function argument names to supplied types | |||||
| sig = signature(func) | |||||
| bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| bound_values = sig.bind(*args, **kwargs) | |||||
| # Enforce type assertions across supplied arguments | |||||
| for name, value in bound_values.arguments.items(): | |||||
| if name in bound_types: | |||||
| if not isinstance(value, bound_types[name]): | |||||
| raise TypeError('Argument {} must be {}'.format( | |||||
| name, bound_types[name])) | |||||
| return func(*args, **kwargs) | |||||
| return wrapper | |||||
| return decorate | |||||
| @@ -1 +1,2 @@ | |||||
| -r requirements/runtime.txt | -r requirements/runtime.txt | ||||
| -r requirements/pipeline.txt | |||||
| @@ -0,0 +1,5 @@ | |||||
| http://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/package/whl/easynlp-0.0.3-py2.py3-none-any.whl | |||||
| tensorflow | |||||
| torch==1.9.1 | |||||
| torchaudio==0.9.1 | |||||
| torchvision==0.10.1 | |||||
| @@ -1,5 +1,8 @@ | |||||
| addict | addict | ||||
| numpy | numpy | ||||
| opencv-python-headless | |||||
| Pillow | |||||
| pyyaml | pyyaml | ||||
| requests | requests | ||||
| transformers | |||||
| yapf | yapf | ||||
| @@ -20,5 +20,5 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids | |||||
| [flake8] | [flake8] | ||||
| select = B,C,E,F,P,T4,W,B9 | select = B,C,E,F,P,T4,W,B9 | ||||
| max-line-length = 120 | max-line-length = 120 | ||||
| ignore = F401 | |||||
| ignore = F401,F821 | |||||
| exclude = docs/src,*.pyi,.git | exclude = docs/src,*.pyi,.git | ||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| import numpy as np | |||||
| import PIL | |||||
| from maas_lib.pipelines import Pipeline, pipeline | |||||
| from maas_lib.pipelines.builder import PIPELINES | |||||
| from maas_lib.utils.constant import Tasks | |||||
| from maas_lib.utils.logger import get_logger | |||||
| from maas_lib.utils.registry import default_group | |||||
| logger = get_logger() | |||||
| Input = Union[str, 'PIL.Image', 'numpy.ndarray'] | |||||
| class CustomPipelineTest(unittest.TestCase): | |||||
| def test_abstract(self): | |||||
| @PIPELINES.register_module() | |||||
| class CustomPipeline1(Pipeline): | |||||
| def __init__(self, | |||||
| config_file: str = None, | |||||
| model=None, | |||||
| preprocessor=None, | |||||
| **kwargs): | |||||
| super().__init__(config_file, model, preprocessor, **kwargs) | |||||
| with self.assertRaises(TypeError): | |||||
| CustomPipeline1() | |||||
| def test_custom(self): | |||||
| @PIPELINES.register_module( | |||||
| group_key=Tasks.image_tagging, module_name='custom-image') | |||||
| class CustomImagePipeline(Pipeline): | |||||
| def __init__(self, | |||||
| config_file: str = None, | |||||
| model=None, | |||||
| preprocessor=None, | |||||
| **kwargs): | |||||
| super().__init__(config_file, model, preprocessor, **kwargs) | |||||
| def preprocess(self, input: Union[str, | |||||
| 'PIL.Image']) -> Dict[str, Any]: | |||||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | |||||
| """ | |||||
| if not isinstance(input, PIL.Image.Image): | |||||
| from maas_lib.preprocessors import load_image | |||||
| data_dict = {'img': load_image(input), 'url': input} | |||||
| else: | |||||
| data_dict = {'img': input} | |||||
| return data_dict | |||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """ Provide default implementation using self.model and user can reimplement it | |||||
| """ | |||||
| outputs = {} | |||||
| if 'url' in inputs: | |||||
| outputs['filename'] = inputs['url'] | |||||
| img = inputs['img'] | |||||
| new_image = img.resize((img.width // 2, img.height // 2)) | |||||
| outputs['resize_image'] = np.array(new_image) | |||||
| outputs['dummy_result'] = 'dummy_result' | |||||
| return outputs | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||||
| pipe = pipeline(pipeline_name='custom-image') | |||||
| pipe2 = pipeline(Tasks.image_tagging) | |||||
| self.assertTrue(type(pipe) is type(pipe2)) | |||||
| img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \ | |||||
| 'aliyuncs.com/data/test/images/image1.jpg' | |||||
| output = pipe(img_url) | |||||
| self.assertEqual(output['filename'], img_url) | |||||
| self.assertEqual(output['resize_image'].shape, (318, 512, 3)) | |||||
| self.assertEqual(output['dummy_result'], 'dummy_result') | |||||
| outputs = pipe([img_url for i in range(4)]) | |||||
| self.assertEqual(len(outputs), 4) | |||||
| for out in outputs: | |||||
| self.assertEqual(out['filename'], img_url) | |||||
| self.assertEqual(out['resize_image'].shape, (318, 512, 3)) | |||||
| self.assertEqual(out['dummy_result'], 'dummy_result') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import tempfile | |||||
| import unittest | |||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| from maas_lib.fileio import File | |||||
| from maas_lib.pipelines import pipeline | |||||
| from maas_lib.utils.constant import Tasks | |||||
| class ImageMattingTest(unittest.TestCase): | |||||
| def test_run(self): | |||||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||||
| '.com/data/test/maas/image_matting/matting_person.pb' | |||||
| with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile: | |||||
| ofile.write(File.read(model_path)) | |||||
| img_matting = pipeline(Tasks.image_matting, model_path=ofile.name) | |||||
| result = img_matting( | |||||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||||
| ) | |||||
| cv2.imwrite('result.png', result['output_png']) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import tempfile | |||||
| import unittest | |||||
| import zipfile | |||||
| from maas_lib.fileio import File | |||||
| from maas_lib.models.nlp import SequenceClassificationModel | |||||
| from maas_lib.pipelines import SequenceClassificationPipeline | |||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||||
| class SequenceClassificationTest(unittest.TestCase): | |||||
| def predict(self, pipeline: SequenceClassificationPipeline): | |||||
| from easynlp.appzoo import load_dataset | |||||
| set = load_dataset('glue', 'sst2') | |||||
| data = set['test']['sentence'][:3] | |||||
| results = pipeline(data[0]) | |||||
| print(results) | |||||
| results = pipeline(data[1]) | |||||
| print(results) | |||||
| print(data) | |||||
| def test_run(self): | |||||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||||
| with tempfile.TemporaryDirectory() as tmp_dir: | |||||
| tmp_file = osp.join(tmp_dir, 'bert-base-sst2.zip') | |||||
| with open(tmp_file, 'wb') as ofile: | |||||
| ofile.write(File.read(model_url)) | |||||
| with zipfile.ZipFile(tmp_file, 'r') as zipf: | |||||
| zipf.extractall(tmp_dir) | |||||
| path = osp.join(tmp_dir, 'bert-base-sst2') | |||||
| print(path) | |||||
| model = SequenceClassificationModel(path) | |||||
| preprocessor = SequenceClassificationPreprocessor( | |||||
| path, first_sequence='sentence', second_sequence=None) | |||||
| pipeline = SequenceClassificationPipeline(model, preprocessor) | |||||
| self.predict(pipeline) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from maas_lib.preprocessors import PREPROCESSORS, Compose, Preprocessor | |||||
| class ComposeTest(unittest.TestCase): | |||||
| def test_compose(self): | |||||
| @PREPROCESSORS.register_module() | |||||
| class Tmp1(Preprocessor): | |||||
| def __call__(self, input): | |||||
| input['tmp1'] = 'tmp1' | |||||
| return input | |||||
| @PREPROCESSORS.register_module() | |||||
| class Tmp2(Preprocessor): | |||||
| def __call__(self, input): | |||||
| input['tmp2'] = 'tmp2' | |||||
| return input | |||||
| pipeline = [ | |||||
| dict(type='Tmp1'), | |||||
| dict(type='Tmp2'), | |||||
| ] | |||||
| trans = Compose(pipeline) | |||||
| input = {} | |||||
| output = trans(input) | |||||
| self.assertEqual(output['tmp1'], 'tmp1') | |||||
| self.assertEqual(output['tmp2'], 'tmp2') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from maas_lib.preprocessors import build_preprocessor | |||||
| from maas_lib.utils.constant import Fields, InputFields | |||||
| from maas_lib.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| class NLPPreprocessorTest(unittest.TestCase): | |||||
| def test_tokenize(self): | |||||
| cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased') | |||||
| preprocessor = build_preprocessor(cfg, Fields.nlp) | |||||
| input = { | |||||
| InputFields.text: | |||||
| 'Do not meddle in the affairs of wizards, ' | |||||
| 'for they are subtle and quick to anger.' | |||||
| } | |||||
| output = preprocessor(input) | |||||
| self.assertTrue(InputFields.text in output) | |||||
| self.assertEqual(output['input_ids'], [ | |||||
| 101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116, | |||||
| 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102 | |||||
| ]) | |||||
| self.assertEqual( | |||||
| output['token_type_ids'], | |||||
| [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) | |||||
| self.assertEqual( | |||||
| output['attention_mask'], | |||||
| [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -10,8 +10,10 @@ class RegistryTest(unittest.TestCase): | |||||
| def test_register_class_no_task(self): | def test_register_class_no_task(self): | ||||
| MODELS = Registry('models') | MODELS = Registry('models') | ||||
| self.assertTrue(MODELS.name == 'models') | self.assertTrue(MODELS.name == 'models') | ||||
| self.assertTrue(MODELS.modules == {}) | |||||
| self.assertEqual(len(MODELS.modules), 0) | |||||
| self.assertTrue(default_group in MODELS.modules) | |||||
| self.assertTrue(MODELS.modules[default_group] == {}) | |||||
| self.assertEqual(len(MODELS.modules), 1) | |||||
| @MODELS.register_module(module_name='cls-resnet') | @MODELS.register_module(module_name='cls-resnet') | ||||
| class ResNetForCls(object): | class ResNetForCls(object): | ||||
| @@ -47,7 +49,7 @@ class RegistryTest(unittest.TestCase): | |||||
| self.assertTrue(Tasks.object_detection in MODELS.modules) | self.assertTrue(Tasks.object_detection in MODELS.modules) | ||||
| self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR) | ||||
| self.assertEqual(len(MODELS.modules), 3) | |||||
| self.assertEqual(len(MODELS.modules), 4) | |||||
| def test_list(self): | def test_list(self): | ||||
| MODELS = Registry('models') | MODELS = Registry('models') | ||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from typing import List, Union | |||||
| from maas_lib.utils.type_assert import type_assert | |||||
| class type_assertTest(unittest.TestCase): | |||||
| @type_assert(object, list, (int, str)) | |||||
| def a(self, a: List[int], b: Union[int, str]): | |||||
| print(a, b) | |||||
| def test_type_assert(self): | |||||
| with self.assertRaises(TypeError): | |||||
| self.a([1], 2) | |||||
| self.a(1, [123]) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||