diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 2ed6bb3d..9eada04b 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -141,7 +141,9 @@ def model_file_download( cached_file_path = cache.get_file_by_path_and_commit_id( file_path, revision) if cached_file_path is not None: - logger.info('The specified file is in cache, skip downloading!') + file_name = os.path.basename(cached_file_path) + logger.info( + f'File {file_name} already in cache, skip downloading!') return cached_file_path # the file is in cache. is_commit_id = True # we need to download again diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 38514197..b52cb42d 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -1,13 +1,11 @@ import os import tempfile -from glob import glob from pathlib import Path from typing import Dict, Optional, Union from modelscope.utils.logger import get_logger from .api import HubApi, ModelScopeConfig -from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR -from .errors import NotExistError, RequestError, raise_on_error +from .errors import NotExistError from .file_download import (get_file_download_url, http_get_file, http_user_agent) from .utils.caching import ModelFileSystemCache @@ -98,8 +96,9 @@ def snapshot_download(model_id: str, continue # check model_file is exist in cache, if exist, skip download, otherwise download if cache.exists(model_file): + file_name = os.path.basename(model_file['Name']) logger.info( - 'The specified file is in cache, skip downloading!') + f'File {file_name} already in cache, skip downloading!') continue # get download url diff --git a/modelscope/models/base.py b/modelscope/models/base.py index bb8cd1cd..03cc2d4d 100644 --- a/modelscope/models/base.py +++ b/modelscope/models/base.py @@ -33,7 +33,7 @@ class Model(ABC): standard model outputs. Args: - inputs: input data + input: input data Return: dict of results: a dict containing outputs of model, each @@ -50,9 +50,17 @@ class Model(ABC): """ Instantiate a model from local directory or remote model repo. Note that when loading from remote, the model revision can be specified. """ + prefetched = kwargs.get('model_prefetched') + if prefetched is not None: + kwargs.pop('model_prefetched') + if osp.exists(model_name_or_path): local_model_dir = model_name_or_path else: + if prefetched is True: + raise RuntimeError( + 'Expecting model is pre-fetched locally, but is not found.' + ) local_model_dir = snapshot_download(model_name_or_path, revision) logger.info(f'initialize model from {local_model_dir}') cfg = Config.from_file( diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index 536a536a..30d129ca 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -37,7 +37,8 @@ class ANSPipeline(Pipeline): SAMPLE_RATE = 16000 def __init__(self, model): - r""" + """ + use `model` and `preprocessor` to create a kws pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py index 27b31386..390df485 100644 --- a/modelscope/pipelines/audio/kws_kwsbp_pipeline.py +++ b/modelscope/pipelines/audio/kws_kwsbp_pipeline.py @@ -1,7 +1,4 @@ -import io import os -import shutil -import stat import subprocess from typing import Any, Dict, List, Union @@ -28,7 +25,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): model: Union[Model, str] = None, preprocessor: WavToLists = None, **kwargs): - """use `model` and `preprocessor` to create a kws pipeline for prediction + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. """ model = model if isinstance(model, @@ -39,6 +39,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): model=model, preprocessor=preprocessor, **kwargs) + assert model is not None, 'kws model should be provided' self._preprocessor = preprocessor diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py index 3539fe81..ea07565e 100644 --- a/modelscope/pipelines/audio/linear_aec_pipeline.py +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -63,7 +63,8 @@ class LinearAECPipeline(Pipeline): """ def __init__(self, model): - r""" + """ + use `model` and `preprocessor` to create a kws pipeline for prediction Args: model: model id on modelscope hub. """ diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py index 8ac92118..142d697d 100644 --- a/modelscope/pipelines/audio/text_to_speech_pipeline.py +++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py @@ -1,17 +1,12 @@ -import time from typing import Any, Dict, List import numpy as np from modelscope.metainfo import Pipelines -from modelscope.models import Model -from modelscope.models.audio.tts.am import SambertNetHifi16k -from modelscope.models.audio.tts.vocoder import Hifigan16k from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import (Preprocessor, TextToTacotronSymbols, - build_preprocessor) -from modelscope.utils.constant import Fields, Tasks +from modelscope.preprocessors import Preprocessor, TextToTacotronSymbols +from modelscope.utils.constant import Tasks __all__ = ['TextToSpeechSambertHifigan16kPipeline'] @@ -24,6 +19,11 @@ class TextToSpeechSambertHifigan16kPipeline(Pipeline): model: List[str] = None, preprocessor: Preprocessor = None, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ assert len(model) == 3, 'model number should be 3' if preprocessor is None: lang_type = 'pinyin' diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 4052d35a..b2d17777 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -17,9 +17,6 @@ Tensor = Union['torch.Tensor', 'tf.Tensor'] Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray'] InputModel = Union[str, Model] -output_keys = [ -] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key - logger = get_logger() @@ -28,9 +25,9 @@ class Pipeline(ABC): def initiate_single_model(self, model): logger.info(f'initiate model from {model}') if isinstance(model, str) and is_official_hub_path(model): - model = snapshot_download( - model) if not osp.exists(model) else model - return Model.from_pretrained(model) if is_model(model) else model + # expecting model has been prefetched to local cache beforehand + return Model.from_pretrained( + model, model_prefetched=True) if is_model(model) else model elif isinstance(model, Model): return model else: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index e2257ff4..ead7c521 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -1,7 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import List, Union +from typing import List, Optional, Union +from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Pipelines from modelscope.models.base import Model from modelscope.utils.config import Config, ConfigDict @@ -67,6 +68,21 @@ DEFAULT_MODEL_FOR_PIPELINE = { } +def normalize_model_input(model, model_revision): + """ normalize the input model, to ensure that a model str is a valid local path: in other words, + for model represented by a model id, the model shall be downloaded locally + """ + if isinstance(model, str) and is_official_hub_path(model, model_revision): + # note that if there is already a local copy, snapshot_download will check and skip downloading + model = snapshot_download(model, revision=model_revision) + elif isinstance(model, list) and isinstance(model[0], str): + for idx in range(len(model)): + if is_official_hub_path(model[idx], model_revision): + model[idx] = snapshot_download( + model[idx], revision=model_revision) + return model + + def build_pipeline(cfg: ConfigDict, task_name: str = None, default_args: dict = None): @@ -89,8 +105,9 @@ def pipeline(task: str = None, pipeline_name: str = None, framework: str = None, device: int = -1, + model_revision: Optional[str] = 'master', **kwargs) -> Pipeline: - """ Factory method to build a obj:`Pipeline`. + """ Factory method to build an obj:`Pipeline`. Args: @@ -100,6 +117,8 @@ def pipeline(task: str = None, config_file (str, optional): path to config file. pipeline_name (str, optional): pipeline class name or alias name. framework (str, optional): framework type. + model_revision: revision of model(s) if getting from model hub, for multiple models, expecting + all models to have the same revision device (int, optional): which device is used to do inference. Return: @@ -123,14 +142,18 @@ def pipeline(task: str = None, assert isinstance(model, (type(None), str, Model, list)), \ f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' + model = normalize_model_input(model, model_revision) + if pipeline_name is None: # get default pipeline for this task if isinstance(model, str) \ or (isinstance(model, list) and isinstance(model[0], str)): - if is_official_hub_path(model): + if is_official_hub_path(model, revision=model_revision): # read config file from hub and parse - cfg = read_config(model) if isinstance( - model, str) else read_config(model[0]) + cfg = read_config( + model, revision=model_revision) if isinstance( + model, str) else read_config( + model[0], revision=model_revision) assert hasattr( cfg, 'pipeline'), 'pipeline config is missing from config file.' @@ -152,7 +175,7 @@ def pipeline(task: str = None, pipeline_name = first_model.pipeline.type else: pipeline_name, default_model_repo = get_default_pipeline_info(task) - model = default_model_repo + model = normalize_model_input(default_model_repo, model_revision) cfg = ConfigDict(type=pipeline_name, model=model) diff --git a/modelscope/pipelines/cv/action_recognition_pipeline.py b/modelscope/pipelines/cv/action_recognition_pipeline.py index 40cd0b50..ad453fd8 100644 --- a/modelscope/pipelines/cv/action_recognition_pipeline.py +++ b/modelscope/pipelines/cv/action_recognition_pipeline.py @@ -23,6 +23,11 @@ logger = get_logger() class ActionRecognitionPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) logger.info(f'loading model from {model_path}') diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recog_pipeline.py index dd68dab6..2a5f3094 100644 --- a/modelscope/pipelines/cv/animal_recog_pipeline.py +++ b/modelscope/pipelines/cv/animal_recog_pipeline.py @@ -1,5 +1,4 @@ import os.path as osp -import tempfile from typing import Any, Dict import cv2 @@ -8,13 +7,12 @@ import torch from PIL import Image from torchvision import transforms -from modelscope.fileio import File from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Pipelines from modelscope.models.cv.animal_recognition import resnet from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image -from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES @@ -28,6 +26,11 @@ logger = get_logger() class AnimalRecogPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) import torch diff --git a/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py b/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py index c3a73bc6..850f2914 100644 --- a/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py +++ b/modelscope/pipelines/cv/cmdssl_video_embedding_pipleline.py @@ -1,11 +1,8 @@ -import math import os.path as osp from typing import Any, Dict -import cv2 import decord import numpy as np -import PIL import torch import torchvision.transforms.functional as TF from decord import VideoReader, cpu @@ -30,6 +27,11 @@ logger = get_logger() class CMDSSLVideoEmbeddingPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) logger.info(f'loading model from {model_path}') diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index f6fd3ee2..1b064e66 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -31,6 +31,11 @@ logger = get_logger() class ImageCartoonPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) self.facer = FaceAna(self.model) self.sess_anime_head = self.load_sess( diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 140d28d7..c645daa2 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -22,6 +22,11 @@ logger = get_logger() class ImageMattingPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) import tensorflow as tf if tf.__version__ >= '2.0': diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 6b259eaf..0333400d 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -1,8 +1,5 @@ -import math -import os import os.path as osp -import sys -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict import cv2 import numpy as np @@ -48,6 +45,11 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, class OCRDetectionPipeline(Pipeline): def __init__(self, model: str): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ super().__init__(model=model) tf.reset_default_graph() model_path = osp.join( diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py index 9f32caf4..039f61dd 100644 --- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py +++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py @@ -18,7 +18,12 @@ class ImageCaptionPipeline(Pipeline): model: Union[Model, str], preprocessor: [Preprocessor] = None, **kwargs): - super().__init__() + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model) assert isinstance(model, str) or isinstance(model, Model), \ 'model must be a single str or OfaForImageCaptioning' if isinstance(model, str): diff --git a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py index a21ecc79..ae00e275 100644 --- a/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py +++ b/modelscope/pipelines/multi_modal/multi_modal_embedding_pipeline.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Input @@ -15,6 +15,11 @@ logger = get_logger() class MultiModalEmbeddingPipeline(Pipeline): def __init__(self, model: str, device_id: int = -1): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ if isinstance(model, str): pipe_model = Model.from_pretrained(model) elif isinstance(model, Model): diff --git a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py index 603a86fd..5b8d43d1 100644 --- a/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py +++ b/modelscope/pipelines/multi_modal/text_to_image_synthesis_pipeline.py @@ -19,6 +19,11 @@ logger = get_logger() class TextToImageSynthesisPipeline(Pipeline): def __init__(self, model: str, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ device_id = 0 if torch.cuda.is_available() else -1 if isinstance(model, str): pipe_model = Model.from_pretrained(model, device_id=device_id) diff --git a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py index 80a0f783..ed7a826b 100644 --- a/modelscope/pipelines/nlp/dialog_modeling_pipeline.py +++ b/modelscope/pipelines/nlp/dialog_modeling_pipeline.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict, Union +from typing import Dict, Union from ...metainfo import Pipelines from ...models import Model @@ -22,7 +22,7 @@ class DialogModelingPipeline(Pipeline): model: Union[SpaceForDialogModeling, str], preprocessor: DialogModelingPreprocessor = None, **kwargs): - """use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation + """use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation Args: model (SpaceForDialogModeling): a model instance diff --git a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py index 2afe64d9..7bd75218 100644 --- a/modelscope/pipelines/nlp/sentiment_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sentiment_classification_pipeline.py @@ -1,8 +1,5 @@ -import os -import uuid from typing import Any, Dict, Union -import json import numpy as np import torch @@ -11,7 +8,7 @@ from ...models import Model from ...models.nlp import SbertForSentimentClassification from ...preprocessors import SentimentClassificationPreprocessor from ...utils.constant import Tasks -from ..base import Input, Pipeline +from ..base import Pipeline from ..builder import PIPELINES from ..outputs import OutputKeys diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index a0784afa..339428ff 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -1,12 +1,11 @@ import os.path as osp -from typing import Any, Dict, Optional, Union +from typing import Any, Dict import numpy as np import tensorflow as tf from ...hub.snapshot_download import snapshot_download from ...metainfo import Pipelines -from ...models import Model from ...models.nlp import CsanmtForTranslation from ...utils.constant import ModelFile, Tasks from ...utils.logger import get_logger diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py index a7ea1e9a..818a968f 100644 --- a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py +++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py @@ -1,9 +1,5 @@ -import os -import uuid from typing import Any, Dict, Union -import json -import numpy as np import torch from scipy.special import softmax @@ -12,7 +8,7 @@ from ...models import Model from ...models.nlp import SbertForZeroShotClassification from ...preprocessors import ZeroShotClassificationPreprocessor from ...utils.constant import Tasks -from ..base import Input, Pipeline +from ..base import Pipeline from ..builder import PIPELINES from ..outputs import OutputKeys @@ -30,7 +26,7 @@ class ZeroShotClassificationPipeline(Pipeline): **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction Args: - model (SbertForSentimentClassification): a model instance + model (SbertForZeroShotClassification): a model instance preprocessor (SentimentClassificationPreprocessor): a preprocessor instance """ assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \ diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py index d034a7d4..ceee782f 100644 --- a/modelscope/pipelines/util.py +++ b/modelscope/pipelines/util.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from typing import List, Union +from typing import List, Optional, Union from modelscope.hub.api import HubApi from modelscope.hub.file_download import model_file_download @@ -20,8 +20,9 @@ def is_config_has_model(cfg_file): return False -def is_official_hub_path(path: Union[str, List]): - """ Whether path is a official hub name or a valid local +def is_official_hub_path(path: Union[str, List], + revision: Optional[str] = 'master'): + """ Whether path is an official hub name or a valid local path to official hub directory. """ @@ -31,7 +32,7 @@ def is_official_hub_path(path: Union[str, List]): return osp.exists(cfg_file) else: try: - _ = HubApi().get_model(path) + _ = HubApi().get_model(path, revision=revision) return True except Exception: return False diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index f2a3c120..db224fb9 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -42,7 +42,7 @@ def create_model_if_not_exist( return True -def read_config(model_id_or_path: str): +def read_config(model_id_or_path: str, revision: Optional[str] = 'master'): """ Read config from hub or local path Args: @@ -52,8 +52,8 @@ def read_config(model_id_or_path: str): config (:obj:`Config`): config object """ if not os.path.exists(model_id_or_path): - local_path = model_file_download(model_id_or_path, - ModelFile.CONFIGURATION) + local_path = model_file_download( + model_id_or_path, ModelFile.CONFIGURATION, revision=revision) else: local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 22fb127b..584d6d91 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -1,6 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -import shutil import tempfile import unittest @@ -47,7 +46,8 @@ class ImageMattingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub(self): - img_matting = pipeline(Tasks.image_matting, model=self.model_id) + img_matting = pipeline( + Tasks.image_matting, model=self.model_id, model_revision='beta') result = img_matting('data/test/images/image_matting.png') cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])