diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py
index 9bef4af2..240dc140 100644
--- a/maas_lib/pipelines/base.py
+++ b/maas_lib/pipelines/base.py
@@ -3,8 +3,9 @@
 import os.path as osp
 from abc import ABC, abstractmethod
 from multiprocessing.sharedctypes import Value
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, Generator, List, Tuple, Union
 
+from ali_maas_datasets import PyDataset
 from maas_hub.snapshot_download import snapshot_download
 
 from maas_lib.models import Model
@@ -14,7 +15,7 @@ from maas_lib.utils.constant import CONFIGFILE
 from .util import is_model_name
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']
 
 output_keys = [
 ]  # standardized output keys for each task's pipeline, used to interface with postprocess and to standardize the keys of the postprocessed output
@@ -59,8 +60,8 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor
 
     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Dict[str, Any]:
+                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
         # model provider should leave it as it is
         # maas library developer will handle this function
 
         # simple showcase, need to support iterator type for both tensorflow and pytorch
@@ -69,10 +70,19 @@
             output = []
             for ele in input:
                 output.append(self._process_single(ele, *args, **post_kwargs))
+
+        elif isinstance(input, PyDataset):
+            return self._process_iterator(input, *args, **post_kwargs)
+
         else:
             output = self._process_single(input, *args, **post_kwargs)
         return output
 
+    def _process_iterator(self, input: Input, *args, **post_kwargs):
+        # lazily yield one processed result per dataset element
+        for ele in input:
+            yield self._process_single(ele, *args, **post_kwargs)
+
     def _process_single(self, input: Input, *args,
                         **post_kwargs) -> Dict[str, Any]:
         out = self.preprocess(input)
diff --git a/requirements/maas.txt b/requirements/maas.txt
index 3b64c375..66b9aeca 100644
--- a/requirements/maas.txt
+++ b/requirements/maas.txt
@@ -1,2 +1,3 @@
 http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/maas_lib-0.1.1-py3-none-any.whl
 https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
+https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 8f74e780..5d24e660 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,5 +1,6 @@
 addict
 https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
+https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
 numpy
 opencv-python-headless
 Pillow
diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py
index 88360994..8b8672ae 100644
--- a/tests/pipelines/test_image_matting.py
+++ b/tests/pipelines/test_image_matting.py
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Tuple, Union
 import cv2
 import numpy as np
 import PIL
+from ali_maas_datasets import PyDataset
 
 from maas_lib.fileio import File
 from maas_lib.pipelines import pipeline
@@ -30,6 +31,25 @@ class ImageMattingTest(unittest.TestCase):
         )
         cv2.imwrite('result.png', result['output_png'])
 
+    def test_dataset(self):
+        model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
+                     '.com/data/test/maas/image_matting/matting_person.pb'
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_file = osp.join(tmp_dir, 'matting_person.pb')
+            with open(model_file, 'wb') as ofile:
+                ofile.write(File.read(model_path))
+            img_matting = pipeline(Tasks.image_matting, model=tmp_dir)
+            # dataset = PyDataset.load('/dir/to/images', target='image')
+            # yapf: disable
+            dataset = PyDataset.load([
+                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
+            ],
+                target='image')
+            result = img_matting(dataset)
+            for i, r in enumerate(result):
+                cv2.imwrite(f'/path/to/result/{i}.png', r['output_png'])
+            print('end')
+
     def test_run_modelhub(self):
         img_matting = pipeline(
             Tasks.image_matting, model='damo/image-matting-person')
diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py
index e49c480d..39390a88 100644
--- a/tests/pipelines/test_text_classification.py
+++ b/tests/pipelines/test_text_classification.py
@@ -4,6 +4,8 @@ import unittest
 import zipfile
 from pathlib import Path
 
+from ali_maas_datasets import PyDataset
+
 from maas_lib.fileio import File
 from maas_lib.models import Model
 from maas_lib.models.nlp import SequenceClassificationModel
@@ -58,6 +60,33 @@ class SequenceClassificationTest(unittest.TestCase):
             task='text-classification', model=model, preprocessor=preprocessor)
         self.predict(pipeline_ins)
 
+    def test_dataset(self):
+        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
+                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
+        cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
+        cache_path = Path(cache_path_str)
+
+        if not cache_path.exists():
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+            cache_path.touch(exist_ok=True)
+            with cache_path.open('wb') as ofile:
+                ofile.write(File.read(model_url))
+
+        with zipfile.ZipFile(cache_path_str, 'r') as zipf:
+            zipf.extractall(cache_path.parent)
+        path = r'.cache/easynlp/bert-base-sst2'
+        model = SequenceClassificationModel(path)
+        preprocessor = SequenceClassificationPreprocessor(
+            path, first_sequence='sentence', second_sequence=None)
+        text_classification = pipeline(
+            'text-classification', model=model, preprocessor=preprocessor)
+        dataset = PyDataset.load('glue', name='sst2', target='sentence')
+        result = text_classification(dataset)
+        for i, r in enumerate(result):
+            if i > 10:
+                break
+            print(r)
+
 
 if __name__ == '__main__':
     unittest.main()