diff --git a/maas_lib/models/base.py b/maas_lib/models/base.py index cc6c4ec8..677a136a 100644 --- a/maas_lib/models/base.py +++ b/maas_lib/models/base.py @@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download from maas_hub.snapshot_download import snapshot_download from maas_lib.models.builder import build_model +from maas_lib.pipelines import util from maas_lib.utils.config import Config from maas_lib.utils.constant import CONFIGFILE @@ -39,8 +40,9 @@ class Model(ABC): if osp.exists(model_name_or_path): local_model_dir = model_name_or_path else: - - local_model_dir = snapshot_download(model_name_or_path) + cache_path = util.get_model_cache_dir(model_name_or_path) + local_model_dir = cache_path if osp.exists( + cache_path) else snapshot_download(model_name_or_path) # else: # raise ValueError( # 'Remote model repo {model_name_or_path} does not exists') diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py index 240dc140..3b1103f6 100644 --- a/maas_lib/pipelines/base.py +++ b/maas_lib/pipelines/base.py @@ -2,16 +2,15 @@ import os.path as osp from abc import ABC, abstractmethod -from multiprocessing.sharedctypes import Value from typing import Any, Dict, Generator, List, Tuple, Union from ali_maas_datasets import PyDataset from maas_hub.snapshot_download import snapshot_download from maas_lib.models import Model +from maas_lib.pipelines import util from maas_lib.preprocessors import Preprocessor from maas_lib.utils.config import Config -from maas_lib.utils.constant import CONFIGFILE from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] @@ -31,7 +30,7 @@ class Pipeline(ABC): """ Base class for pipeline. If config_file is provided, model and preprocessor will be - instantiated from corresponding config. Otherwise model + instantiated from corresponding config. Otherwise, model and preprocessor will be constructed separately. Args: @@ -44,7 +43,11 @@ class Pipeline(ABC): if isinstance(model, str): if not osp.exists(model): - model = snapshot_download(model) + cache_path = util.get_model_cache_dir(model) + if osp.exists(cache_path): + model = cache_path + else: + model = snapshot_download(model) if is_model_name(model): self.model = Model.from_pretrained(model) @@ -61,7 +64,7 @@ class Pipeline(ABC): def __call__(self, input: Union[Input, List[Input]], *args, **post_kwargs) -> Union[Dict[str, Any], Generator]: - # moodel provider should leave it as it is + # model provider should leave it as it is # maas library developer will handle this function # simple showcase, need to support iterator type for both tensorflow and pytorch diff --git a/maas_lib/pipelines/util.py b/maas_lib/pipelines/util.py index 3e907359..4a0a28ec 100644 --- a/maas_lib/pipelines/util.py +++ b/maas_lib/pipelines/util.py @@ -1,12 +1,23 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os import os.path as osp import json +from maas_hub.constants import MODEL_ID_SEPARATOR from maas_hub.file_download import model_file_download from maas_lib.utils.constant import CONFIGFILE +# temp solution before the hub-cache is in place +def get_model_cache_dir(model_id: str, branch: str = 'master'): + model_id_expanded = model_id.replace('/', + MODEL_ID_SEPARATOR) + '.' + branch + default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) + return os.getenv('MAAS_CACHE', + os.path.join(default_cache_dir, 'hub', model_id_expanded)) + + def is_model_name(model): if osp.exists(model): if osp.exists(osp.join(model, CONFIGFILE)): diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 26847389..1713b34e 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - +import os import os.path as osp +import shutil import tempfile import unittest @@ -8,12 +9,20 @@ import cv2 from ali_maas_datasets import PyDataset from maas_lib.fileio import File -from maas_lib.pipelines import pipeline +from maas_lib.pipelines import pipeline, util from maas_lib.utils.constant import Tasks class ImageMattingTest(unittest.TestCase): + def setUp(self) -> None: + self.model_id = 'damo/image-matting-person' + # switch to False if downloading everytime is not desired + purge_cache = True + if purge_cache: + shutil.rmtree( + util.get_model_cache_dir(self.model_id), ignore_errors=True) + def test_run(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' @@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase): # input_location = '/dir/to/images' dataset = PyDataset.load(input_location, target='image') - img_matting = pipeline( - Tasks.image_matting, model='damo/image-matting-person') + img_matting = pipeline(Tasks.image_matting, model=self.model_id) # note that for dataset output, the inference-output is a Generator that can be iterated. result = img_matting(dataset) cv2.imwrite('result.png', next(result)['output_png']) print(f'Output written to {osp.abspath("result.png")}') def test_run_modelhub(self): - img_matting = pipeline( - Tasks.image_matting, model='damo/image-matting-person') + img_matting = pipeline(Tasks.image_matting, model=self.model_id) result = img_matting( 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 080622d3..cbdd8964 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import tempfile +import os +import shutil import unittest import zipfile from pathlib import Path @@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset from maas_lib.fileio import File from maas_lib.models import Model from maas_lib.models.nlp import SequenceClassificationModel -from maas_lib.pipelines import SequenceClassificationPipeline, pipeline +from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util from maas_lib.preprocessors import SequenceClassificationPreprocessor from maas_lib.utils.constant import Tasks class SequenceClassificationTest(unittest.TestCase): + def setUp(self) -> None: + self.model_id = 'damo/bert-base-sst2' + # switch to False if downloading everytime is not desired + purge_cache = True + if purge_cache: + shutil.rmtree( + util.get_model_cache_dir(self.model_id), ignore_errors=True) + def predict(self, pipeline_ins: SequenceClassificationPipeline): from easynlp.appzoo import load_dataset @@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase): print(pipeline2('Hello world!')) def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained('damo/bert-base-sst2') + model = Model.from_pretrained(self.model_id) preprocessor = SequenceClassificationPreprocessor( model.model_dir, first_sequence='sentence', second_sequence=None) pipeline_ins = pipeline( @@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase): def test_run_with_model_name(self): text_classification = pipeline( - task=Tasks.text_classification, model='damo/bert-base-sst2') + task=Tasks.text_classification, model=self.model_id) result = text_classification( PyDataset.load('glue', name='sst2', target='sentence')) self.printDataset(result) def test_run_with_dataset(self): - model = Model.from_pretrained('damo/bert-base-sst2') + model = Model.from_pretrained(self.model_id) preprocessor = SequenceClassificationPreprocessor( model.model_dir, first_sequence='sentence', second_sequence=None) text_classification = pipeline(