Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8868644master
@@ -3,8 +3,9 @@
 import os.path as osp
 from abc import ABC, abstractmethod
 from multiprocessing.sharedctypes import Value
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, Generator, List, Tuple, Union

+from ali_maas_datasets import PyDataset
 from maas_hub.snapshot_download import snapshot_download

 from maas_lib.models import Model
@@ -14,7 +15,7 @@ from maas_lib.utils.constant import CONFIGFILE
 from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
-Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
+Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray']

 output_keys = [
 ]  # For pipelines of different tasks, define standardized output keys, used both to interface with postprocess and to standardize the keys of the postprocessed output
@@ -59,8 +60,8 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor

     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Dict[str, Any]:
-        # model provider should leave it as it is
+                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
+        # model provider should leave it as it is
         # maas library developer will handle this function

         # simple showcase, need to support iterator type for both tensorflow and pytorch
@@ -69,10 +70,18 @@ class Pipeline(ABC):
             output = []
             for ele in input:
                 output.append(self._process_single(ele, *args, **post_kwargs))
+        elif isinstance(input, PyDataset):
+            return self._process_iterator(input, *args, **post_kwargs)
         else:
             output = self._process_single(input, *args, **post_kwargs)
         return output

+    def _process_iterator(self, input: Input, *args, **post_kwargs):
+        for ele in input:
+            yield self._process_single(ele, *args, **post_kwargs)
+
     def _process_single(self, input: Input, *args,
                         **post_kwargs) -> Dict[str, Any]:
         out = self.preprocess(input)
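To make the new control flow concrete, here is a minimal, self-contained sketch (not part of the diff) of how the three input branches behave after this change; ToyDataset, process_single and call are hypothetical stand-ins for PyDataset, Pipeline._process_single and Pipeline.__call__:

    from typing import Any, Dict, Generator, List, Union

    class ToyDataset:
        """Hypothetical stand-in for PyDataset: an iterable over samples."""

        def __init__(self, items):
            self.items = items

        def __iter__(self):
            return iter(self.items)

    def process_single(x) -> Dict[str, Any]:
        """Stand-in for Pipeline._process_single."""
        return {'output': x}

    def call(input) -> Union[Dict[str, Any], List[Dict[str, Any]], Generator]:
        """Mirrors the dispatch in Pipeline.__call__ after this change."""
        if isinstance(input, list):
            return [process_single(e) for e in input]
        elif isinstance(input, ToyDataset):
            # Dataset branch: a generator is returned, so each sample is
            # processed lazily and results never accumulate in memory.
            return (process_single(e) for e in input)
        else:
            return process_single(input)

    print(call('a'))                        # {'output': 'a'}
    print(call(['a', 'b']))                 # [{'output': 'a'}, {'output': 'b'}]
    for r in call(ToyDataset(['a', 'b'])):  # lazily yielded, one at a time
        print(r)

Note that because the dataset branch returns early, only single and list inputs fall through to the common `return output` at the end of `__call__`.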
@@ -1,2 +1,3 @@
 http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/maas/maas_lib-0.1.1-py3-none-any.whl
 https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
+https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
@@ -1,5 +1,6 @@
 addict
 https://maashub.oss-cn-hangzhou.aliyuncs.com/releases/maas_hub-0.1.0.dev0-py2.py3-none-any.whl
+https://mit-dataset.oss-cn-beijing.aliyuncs.com/release/ali_maas_datasets-0.0.1.dev0-py3-none-any.whl
 numpy
 opencv-python-headless
 Pillow
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Tuple, Union
 import cv2
 import numpy as np
 import PIL
+from ali_maas_datasets import PyDataset

 from maas_lib.fileio import File
 from maas_lib.pipelines import pipeline
@@ -30,6 +31,25 @@ class ImageMattingTest(unittest.TestCase):
             )
             cv2.imwrite('result.png', result['output_png'])

+    def test_dataset(self):
+        model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
+                     '.com/data/test/maas/image_matting/matting_person.pb'
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_file = osp.join(tmp_dir, 'matting_person.pb')
+            with open(model_file, 'wb') as ofile:
+                ofile.write(File.read(model_path))
+            img_matting = pipeline(Tasks.image_matting, model=tmp_dir)
+
+            # dataset = PyDataset.load('/dir/to/images', target='image')
+            # yapf: disable
+            dataset = PyDataset.load([
+                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
+            ],
+                target='image')
+            result = img_matting(dataset)
+            for i, r in enumerate(result):
+                cv2.imwrite(f'/path/to/result/{i}.png', r['output_png'])
+            print('end')
+
     def test_run_modelhub(self):
         img_matting = pipeline(
             Tasks.image_matting, model='damo/image-matting-person')
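As a usage note, the commented-out line in test_dataset above hints at the intended local-directory workflow. A hedged sketch of that path, assuming a prepared model directory and a writable output location (the import location of Tasks is also an assumption, mirroring the test module above):

    import cv2
    from ali_maas_datasets import PyDataset
    from maas_lib.pipelines import pipeline
    from maas_lib.utils.constant import Tasks  # import path assumed

    # Placeholder paths; substitute a real model directory and image folder.
    img_matting = pipeline(Tasks.image_matting, model='/path/to/model_dir')
    dataset = PyDataset.load('/dir/to/images', target='image')

    # The PyDataset branch returns a generator, so images are matted one
    # at a time instead of being accumulated in memory.
    for i, r in enumerate(img_matting(dataset)):
        cv2.imwrite(f'/path/to/result/{i}.png', r['output_png'])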
@@ -4,6 +4,8 @@ import unittest
 import zipfile
 from pathlib import Path

+from ali_maas_datasets import PyDataset
+
 from maas_lib.fileio import File
 from maas_lib.models import Model
 from maas_lib.models.nlp import SequenceClassificationModel
@@ -58,6 +60,33 @@ class SequenceClassificationTest(unittest.TestCase):
             task='text-classification', model=model, preprocessor=preprocessor)
         self.predict(pipeline_ins)

+    def test_dataset(self):
+        model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \
+                    '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip'
+        cache_path_str = r'.cache/easynlp/bert-base-sst2.zip'
+        cache_path = Path(cache_path_str)
+        if not cache_path.exists():
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+            cache_path.touch(exist_ok=True)
+            with cache_path.open('wb') as ofile:
+                ofile.write(File.read(model_url))
+            with zipfile.ZipFile(cache_path_str, 'r') as zipf:
+                zipf.extractall(cache_path.parent)
+        path = r'.cache/easynlp/bert-base-sst2'
+        model = SequenceClassificationModel(path)
+        preprocessor = SequenceClassificationPreprocessor(
+            path, first_sequence='sentence', second_sequence=None)
+        text_classification = pipeline(
+            'text-classification', model=model, preprocessor=preprocessor)
+        dataset = PyDataset.load('glue', name='sst2', target='sentence')
+        result = text_classification(dataset)
+        for i, r in enumerate(result):
+            if i > 10:
+                break
+            print(r)
+

 if __name__ == '__main__':
     unittest.main()
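A small aside on the truncation loop above: because text_classification(dataset) returns a generator, the counter that breaks at i > 10 (i.e. prints the first 11 results) can equivalently be written with itertools.islice; this sketch reuses the dataset and text_classification names defined in the test:

    from itertools import islice

    # Take the first 11 lazily produced results, then stop; the rest of
    # the dataset is never processed.
    for r in islice(text_classification(dataset), 11):
        print(r)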