Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687 (master)
@@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download
 from maas_hub.snapshot_download import snapshot_download
 from maas_lib.models.builder import build_model
+from maas_lib.pipelines import util
 from maas_lib.utils.config import Config
 from maas_lib.utils.constant import CONFIGFILE
@@ -39,8 +40,9 @@ class Model(ABC):
         if osp.exists(model_name_or_path):
             local_model_dir = model_name_or_path
         else:
-            local_model_dir = snapshot_download(model_name_or_path)
+            cache_path = util.get_model_cache_dir(model_name_or_path)
+            local_model_dir = cache_path if osp.exists(
+                cache_path) else snapshot_download(model_name_or_path)
         # else:
         #     raise ValueError(
         #         'Remote model repo {model_name_or_path} does not exists')
@@ -2,16 +2,15 @@
 import os.path as osp
 from abc import ABC, abstractmethod
-from multiprocessing.sharedctypes import Value
 from typing import Any, Dict, Generator, List, Tuple, Union
 from ali_maas_datasets import PyDataset
 from maas_hub.snapshot_download import snapshot_download
 from maas_lib.models import Model
+from maas_lib.pipelines import util
 from maas_lib.preprocessors import Preprocessor
 from maas_lib.utils.config import Config
-from maas_lib.utils.constant import CONFIGFILE
 from .util import is_model_name
 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -31,7 +30,7 @@ class Pipeline(ABC):
         """ Base class for pipeline.
         If config_file is provided, model and preprocessor will be
-        instantiated from corresponding config. Otherwise model
+        instantiated from corresponding config. Otherwise, model
         and preprocessor will be constructed separately.
         Args:
@@ -44,7 +43,11 @@ class Pipeline(ABC):
         if isinstance(model, str):
             if not osp.exists(model):
-                model = snapshot_download(model)
+                cache_path = util.get_model_cache_dir(model)
+                if osp.exists(cache_path):
+                    model = cache_path
+                else:
+                    model = snapshot_download(model)
             if is_model_name(model):
                 self.model = Model.from_pretrained(model)
@@ -61,7 +64,7 @@ class Pipeline(ABC):
     def __call__(self, input: Union[Input, List[Input]], *args,
                  **post_kwargs) -> Union[Dict[str, Any], Generator]:
-        # moodel provider should leave it as it is
+        # model provider should leave it as it is
         # maas library developer will handle this function
         # simple showcase, need to support iterator type for both tensorflow and pytorch
@@ -1,12 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp
 import json
+from maas_hub.constants import MODEL_ID_SEPARATOR
 from maas_hub.file_download import model_file_download
 from maas_lib.utils.constant import CONFIGFILE
+
+# temp solution before the hub-cache is in place
+def get_model_cache_dir(model_id: str, branch: str = 'master'):
+    model_id_expanded = model_id.replace('/',
+                                         MODEL_ID_SEPARATOR) + '.' + branch
+    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
+    return os.getenv('MAAS_CACHE',
+                     os.path.join(default_cache_dir, 'hub', model_id_expanded))
+
 def is_model_name(model):
     if osp.exists(model):
         if osp.exists(osp.join(model, CONFIGFILE)):
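One behaviour of get_model_cache_dir worth flagging: when the MAAS_CACHE environment variable is set, os.getenv returns it verbatim and model_id_expanded is never appended, so every model resolves to the same directory, and the shutil.rmtree calls in the tests below would then wipe that whole directory rather than a single model's cache. If per-model subdirectories under MAAS_CACHE are the intent, a variant along these lines may be closer to it (this is an assumption about the intended behaviour, not a confirmed requirement):

    import os

    from maas_hub.constants import MODEL_ID_SEPARATOR


    def get_model_cache_dir(model_id: str, branch: str = 'master'):
        # treat MAAS_CACHE as the cache root and always key by model id + branch
        model_id_expanded = model_id.replace('/', MODEL_ID_SEPARATOR) + '.' + branch
        cache_root = os.getenv(
            'MAAS_CACHE',
            os.path.expanduser(os.path.join('~/.cache', 'maas', 'hub')))
        return os.path.join(cache_root, model_id_expanded)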
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp
+import shutil
 import tempfile
 import unittest
@@ -8,12 +9,20 @@ import cv2
 from ali_maas_datasets import PyDataset
 from maas_lib.fileio import File
-from maas_lib.pipelines import pipeline
+from maas_lib.pipelines import pipeline, util
 from maas_lib.utils.constant import Tasks

 class ImageMattingTest(unittest.TestCase):

+    def setUp(self) -> None:
+        self.model_id = 'damo/image-matting-person'
+        # switch to False if downloading every time is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+
     def test_run(self):
         model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                      '.com/data/test/maas/image_matting/matting_person.pb'
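Since purge_cache is hard-coded to True, every test run re-downloads the model; the inline comment already suggests flipping it to False locally. If switching without editing the file is preferred, the flag could be read from an environment variable instead (the variable name PURGE_MODEL_CACHE is hypothetical):

    # purge by default; set PURGE_MODEL_CACHE=0 to keep the cached model between runs
    purge_cache = os.getenv('PURGE_MODEL_CACHE', '1') != '0'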
@@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase):
         # input_location = '/dir/to/images'
         dataset = PyDataset.load(input_location, target='image')
-        img_matting = pipeline(
-            Tasks.image_matting, model='damo/image-matting-person')
+        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
         # note that for dataset output, the inference-output is a Generator that can be iterated.
         result = img_matting(dataset)
         cv2.imwrite('result.png', next(result)['output_png'])
         print(f'Output written to {osp.abspath("result.png")}')

     def test_run_modelhub(self):
-        img_matting = pipeline(
-            Tasks.image_matting, model='damo/image-matting-person')
+        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
         result = img_matting(
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import tempfile
+import os
+import shutil
 import unittest
 import zipfile
 from pathlib import Path
@@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset
 from maas_lib.fileio import File
 from maas_lib.models import Model
 from maas_lib.models.nlp import SequenceClassificationModel
-from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
+from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
 from maas_lib.preprocessors import SequenceClassificationPreprocessor
 from maas_lib.utils.constant import Tasks

 class SequenceClassificationTest(unittest.TestCase):

+    def setUp(self) -> None:
+        self.model_id = 'damo/bert-base-sst2'
+        # switch to False if downloading every time is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+
     def predict(self, pipeline_ins: SequenceClassificationPipeline):
         from easynlp.appzoo import load_dataset
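The setUp added here repeats the purge block from ImageMattingTest verbatim apart from the model id. If more test classes adopt the pattern, a small shared helper (hypothetical, not part of this change) would keep them in sync; both setUp methods would then reduce to a single purge_model_cache(self.model_id) call:

    import shutil

    from maas_lib.pipelines import util


    def purge_model_cache(model_id: str) -> None:
        # remove the locally cached snapshot so the next run downloads it fresh
        shutil.rmtree(util.get_model_cache_dir(model_id), ignore_errors=True)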
@@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase):
         print(pipeline2('Hello world!'))

     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained('damo/bert-base-sst2')
+        model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
             model.model_dir, first_sequence='sentence', second_sequence=None)
         pipeline_ins = pipeline(
@@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase):
     def test_run_with_model_name(self):
         text_classification = pipeline(
-            task=Tasks.text_classification, model='damo/bert-base-sst2')
+            task=Tasks.text_classification, model=self.model_id)
         result = text_classification(
             PyDataset.load('glue', name='sst2', target='sentence'))
         self.printDataset(result)

     def test_run_with_dataset(self):
-        model = Model.from_pretrained('damo/bert-base-sst2')
+        model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
             model.model_dir, first_sequence='sentence', second_sequence=None)
         text_classification = pipeline(