Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687master
| @@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models.builder import build_model | |||
| from maas_lib.pipelines import util | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| @@ -39,8 +40,9 @@ class Model(ABC): | |||
| if osp.exists(model_name_or_path): | |||
| local_model_dir = model_name_or_path | |||
| else: | |||
| local_model_dir = snapshot_download(model_name_or_path) | |||
| cache_path = util.get_model_cache_dir(model_name_or_path) | |||
| local_model_dir = cache_path if osp.exists( | |||
| cache_path) else snapshot_download(model_name_or_path) | |||
| # else: | |||
| # raise ValueError( | |||
| # 'Remote model repo {model_name_or_path} does not exist') | |||
| @@ -2,16 +2,15 @@ | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from multiprocessing.sharedctypes import Value | |||
| from typing import Any, Dict, Generator, List, Tuple, Union | |||
| from ali_maas_datasets import PyDataset | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models import Model | |||
| from maas_lib.pipelines import util | |||
| from maas_lib.preprocessors import Preprocessor | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -31,7 +30,7 @@ class Pipeline(ABC): | |||
| """ Base class for pipeline. | |||
| If config_file is provided, model and preprocessor will be | |||
| instantiated from corresponding config. Otherwise model | |||
| instantiated from corresponding config. Otherwise, model | |||
| and preprocessor will be constructed separately. | |||
| Args: | |||
| @@ -44,7 +43,11 @@ class Pipeline(ABC): | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| model = snapshot_download(model) | |||
| cache_path = util.get_model_cache_dir(model) | |||
| if osp.exists(cache_path): | |||
| model = cache_path | |||
| else: | |||
| model = snapshot_download(model) | |||
| if is_model_name(model): | |||
| self.model = Model.from_pretrained(model) | |||
| @@ -61,7 +64,7 @@ class Pipeline(ABC): | |||
| def __call__(self, input: Union[Input, List[Input]], *args, | |||
| **post_kwargs) -> Union[Dict[str, Any], Generator]: | |||
| # moodel provider should leave it as it is | |||
| # model provider should leave it as it is | |||
| # maas library developer will handle this function | |||
| # simple showcase, need to support iterator type for both tensorflow and pytorch | |||
| @@ -1,12 +1,23 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import json | |||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| # temp solution before the hub-cache is in place | |||
def get_model_cache_dir(model_id: str, branch: str = 'master') -> str:
    """Return the local directory where a hub model snapshot is looked up.

    The directory name is built from ``model_id`` (with every ``/`` replaced
    by ``MODEL_ID_SEPARATOR``) followed by ``.`` and ``branch``.  It is placed
    under the cache root, which is the value of the ``MAAS_CACHE`` environment
    variable when set, and ``~/.cache/maas/hub`` otherwise.

    The per-model component is appended in both cases.  Returning the bare
    ``MAAS_CACHE`` value would make every model resolve to the same directory,
    so an existence check would succeed for models that were never downloaded
    and removing "one model's cache" would remove the whole cache root.

    NOTE(review): this assumes ``MAAS_CACHE`` names the hub cache root that
    contains one sub-directory per model -- confirm against the layout used
    by ``maas_hub.snapshot_download``.

    Args:
        model_id: Hub model identifier, e.g. ``'namespace/name'``.
        branch: Branch (revision) of the model repository.

    Returns:
        Path of the model-specific cache directory. The path is only
        computed here; it is neither created nor checked for existence.
    """
    model_id_expanded = model_id.replace('/', MODEL_ID_SEPARATOR) + '.' + branch
    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
    # The environment variable overrides only the root; the model-specific
    # sub-directory is always appended below.
    cache_root = os.getenv('MAAS_CACHE',
                           os.path.join(default_cache_dir, 'hub'))
    return os.path.join(cache_root, model_id_expanded)
| def is_model_name(model): | |||
| if osp.exists(model): | |||
| if osp.exists(osp.join(model, CONFIGFILE)): | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| @@ -8,12 +9,20 @@ import cv2 | |||
| from ali_maas_datasets import PyDataset | |||
| from maas_lib.fileio import File | |||
| from maas_lib.pipelines import pipeline | |||
| from maas_lib.pipelines import pipeline, util | |||
| from maas_lib.utils.constant import Tasks | |||
| class ImageMattingTest(unittest.TestCase): | |||
def setUp(self) -> None:
    """Select the hub model under test and optionally drop its local cache."""
    self.model_id = 'damo/image-matting-person'
    # Set to False to keep an existing cached copy instead of downloading
    # the model again on every run.
    purge_cache = True
    if not purge_cache:
        return
    cached_model_dir = util.get_model_cache_dir(self.model_id)
    # ignore_errors=True: a missing cache directory is not a failure.
    shutil.rmtree(cached_model_dir, ignore_errors=True)
| def test_run(self): | |||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
| '.com/data/test/maas/image_matting/matting_person.pb' | |||
| @@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase): | |||
| # input_location = '/dir/to/images' | |||
| dataset = PyDataset.load(input_location, target='image') | |||
| img_matting = pipeline( | |||
| Tasks.image_matting, model='damo/image-matting-person') | |||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
| # note that for dataset output, the inference-output is a Generator that can be iterated. | |||
| result = img_matting(dataset) | |||
| cv2.imwrite('result.png', next(result)['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| def test_run_modelhub(self): | |||
| img_matting = pipeline( | |||
| Tasks.image_matting, model='damo/image-matting-person') | |||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
| result = img_matting( | |||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import tempfile | |||
| import os | |||
| import shutil | |||
| import unittest | |||
| import zipfile | |||
| from pathlib import Path | |||
| @@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset | |||
| from maas_lib.fileio import File | |||
| from maas_lib.models import Model | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| from maas_lib.utils.constant import Tasks | |||
| class SequenceClassificationTest(unittest.TestCase): | |||
def setUp(self) -> None:
    """Select the hub model under test and optionally drop its local cache."""
    self.model_id = 'damo/bert-base-sst2'
    # Set to False to keep an existing cached copy instead of downloading
    # the model again on every run.
    purge_cache = True
    if not purge_cache:
        return
    cached_model_dir = util.get_model_cache_dir(self.model_id)
    # ignore_errors=True: a missing cache directory is not a failure.
    shutil.rmtree(cached_model_dir, ignore_errors=True)
| def predict(self, pipeline_ins: SequenceClassificationPipeline): | |||
| from easynlp.appzoo import load_dataset | |||
| @@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| print(pipeline2('Hello world!')) | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained('damo/bert-base-sst2') | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| pipeline_ins = pipeline( | |||
| @@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| def test_run_with_model_name(self): | |||
| text_classification = pipeline( | |||
| task=Tasks.text_classification, model='damo/bert-base-sst2') | |||
| task=Tasks.text_classification, model=self.model_id) | |||
| result = text_classification( | |||
| PyDataset.load('glue', name='sst2', target='sentence')) | |||
| self.printDataset(result) | |||
| def test_run_with_dataset(self): | |||
| model = Model.from_pretrained('damo/bert-base-sst2') | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| model.model_dir, first_sequence='sentence', second_sequence=None) | |||
| text_classification = pipeline( | |||