Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9300163master
| @@ -141,7 +141,9 @@ def model_file_download( | |||||
| cached_file_path = cache.get_file_by_path_and_commit_id( | cached_file_path = cache.get_file_by_path_and_commit_id( | ||||
| file_path, revision) | file_path, revision) | ||||
| if cached_file_path is not None: | if cached_file_path is not None: | ||||
| logger.info('The specified file is in cache, skip downloading!') | |||||
| file_name = os.path.basename(cached_file_path) | |||||
| logger.info( | |||||
| f'File {file_name} already in cache, skip downloading!') | |||||
| return cached_file_path # the file is in cache. | return cached_file_path # the file is in cache. | ||||
| is_commit_id = True | is_commit_id = True | ||||
| # we need to download again | # we need to download again | ||||
| @@ -1,13 +1,11 @@ | |||||
| import os | import os | ||||
| import tempfile | import tempfile | ||||
| from glob import glob | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict, Optional, Union | from typing import Dict, Optional, Union | ||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from .api import HubApi, ModelScopeConfig | from .api import HubApi, ModelScopeConfig | ||||
| from .constants import DEFAULT_MODELSCOPE_GROUP, MODEL_ID_SEPARATOR | |||||
| from .errors import NotExistError, RequestError, raise_on_error | |||||
| from .errors import NotExistError | |||||
| from .file_download import (get_file_download_url, http_get_file, | from .file_download import (get_file_download_url, http_get_file, | ||||
| http_user_agent) | http_user_agent) | ||||
| from .utils.caching import ModelFileSystemCache | from .utils.caching import ModelFileSystemCache | ||||
| @@ -98,8 +96,9 @@ def snapshot_download(model_id: str, | |||||
| continue | continue | ||||
| # check model_file is exist in cache, if exist, skip download, otherwise download | # check model_file is exist in cache, if exist, skip download, otherwise download | ||||
| if cache.exists(model_file): | if cache.exists(model_file): | ||||
| file_name = os.path.basename(model_file['Name']) | |||||
| logger.info( | logger.info( | ||||
| 'The specified file is in cache, skip downloading!') | |||||
| f'File {file_name} already in cache, skip downloading!') | |||||
| continue | continue | ||||
| # get download url | # get download url | ||||
| @@ -33,7 +33,7 @@ class Model(ABC): | |||||
| standard model outputs. | standard model outputs. | ||||
| Args: | Args: | ||||
| inputs: input data | |||||
| input: input data | |||||
| Return: | Return: | ||||
| dict of results: a dict containing outputs of model, each | dict of results: a dict containing outputs of model, each | ||||
| @@ -50,9 +50,17 @@ class Model(ABC): | |||||
| """ Instantiate a model from local directory or remote model repo. Note | """ Instantiate a model from local directory or remote model repo. Note | ||||
| that when loading from remote, the model revision can be specified. | that when loading from remote, the model revision can be specified. | ||||
| """ | """ | ||||
| prefetched = kwargs.get('model_prefetched') | |||||
| if prefetched is not None: | |||||
| kwargs.pop('model_prefetched') | |||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| if prefetched is True: | |||||
| raise RuntimeError( | |||||
| 'Expecting model is pre-fetched locally, but is not found.' | |||||
| ) | |||||
| local_model_dir = snapshot_download(model_name_or_path, revision) | local_model_dir = snapshot_download(model_name_or_path, revision) | ||||
| logger.info(f'initialize model from {local_model_dir}') | logger.info(f'initialize model from {local_model_dir}') | ||||
| cfg = Config.from_file( | cfg = Config.from_file( | ||||
| @@ -37,7 +37,8 @@ class ANSPipeline(Pipeline): | |||||
| SAMPLE_RATE = 16000 | SAMPLE_RATE = 16000 | ||||
| def __init__(self, model): | def __init__(self, model): | ||||
| r""" | |||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| @@ -1,7 +1,4 @@ | |||||
| import io | |||||
| import os | import os | ||||
| import shutil | |||||
| import stat | |||||
| import subprocess | import subprocess | ||||
| from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
| @@ -28,7 +25,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| model: Union[Model, str] = None, | model: Union[Model, str] = None, | ||||
| preprocessor: WavToLists = None, | preprocessor: WavToLists = None, | ||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | """ | ||||
| model = model if isinstance(model, | model = model if isinstance(model, | ||||
| @@ -39,6 +39,7 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||||
| model=model, | model=model, | ||||
| preprocessor=preprocessor, | preprocessor=preprocessor, | ||||
| **kwargs) | **kwargs) | ||||
| assert model is not None, 'kws model should be provided' | assert model is not None, 'kws model should be provided' | ||||
| self._preprocessor = preprocessor | self._preprocessor = preprocessor | ||||
| @@ -63,7 +63,8 @@ class LinearAECPipeline(Pipeline): | |||||
| """ | """ | ||||
| def __init__(self, model): | def __init__(self, model): | ||||
| r""" | |||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | Args: | ||||
| model: model id on modelscope hub. | model: model id on modelscope hub. | ||||
| """ | """ | ||||
| @@ -1,17 +1,12 @@ | |||||
| import time | |||||
| from typing import Any, Dict, List | from typing import Any, Dict, List | ||||
| import numpy as np | import numpy as np | ||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models import Model | |||||
| from modelscope.models.audio.tts.am import SambertNetHifi16k | |||||
| from modelscope.models.audio.tts.vocoder import Hifigan16k | |||||
| from modelscope.pipelines.base import Pipeline | from modelscope.pipelines.base import Pipeline | ||||
| from modelscope.pipelines.builder import PIPELINES | from modelscope.pipelines.builder import PIPELINES | ||||
| from modelscope.preprocessors import (Preprocessor, TextToTacotronSymbols, | |||||
| build_preprocessor) | |||||
| from modelscope.utils.constant import Fields, Tasks | |||||
| from modelscope.preprocessors import Preprocessor, TextToTacotronSymbols | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['TextToSpeechSambertHifigan16kPipeline'] | __all__ = ['TextToSpeechSambertHifigan16kPipeline'] | ||||
| @@ -24,6 +19,11 @@ class TextToSpeechSambertHifigan16kPipeline(Pipeline): | |||||
| model: List[str] = None, | model: List[str] = None, | ||||
| preprocessor: Preprocessor = None, | preprocessor: Preprocessor = None, | ||||
| **kwargs): | **kwargs): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| assert len(model) == 3, 'model number should be 3' | assert len(model) == 3, 'model number should be 3' | ||||
| if preprocessor is None: | if preprocessor is None: | ||||
| lang_type = 'pinyin' | lang_type = 'pinyin' | ||||
| @@ -17,9 +17,6 @@ Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||||
| Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray'] | Input = Union[str, tuple, MsDataset, 'PIL.Image.Image', 'numpy.ndarray'] | ||||
| InputModel = Union[str, Model] | InputModel = Union[str, Model] | ||||
| output_keys = [ | |||||
| ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | |||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -28,9 +25,9 @@ class Pipeline(ABC): | |||||
| def initiate_single_model(self, model): | def initiate_single_model(self, model): | ||||
| logger.info(f'initiate model from {model}') | logger.info(f'initiate model from {model}') | ||||
| if isinstance(model, str) and is_official_hub_path(model): | if isinstance(model, str) and is_official_hub_path(model): | ||||
| model = snapshot_download( | |||||
| model) if not osp.exists(model) else model | |||||
| return Model.from_pretrained(model) if is_model(model) else model | |||||
| # expecting model has been prefetched to local cache beforehand | |||||
| return Model.from_pretrained( | |||||
| model, model_prefetched=True) if is_model(model) else model | |||||
| elif isinstance(model, Model): | elif isinstance(model, Model): | ||||
| return model | return model | ||||
| else: | else: | ||||
| @@ -1,7 +1,8 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import List, Union | |||||
| from typing import List, Optional, Union | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| @@ -67,6 +68,21 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| } | } | ||||
| def normalize_model_input(model, model_revision): | |||||
| """ normalize the input model, to ensure that a model str is a valid local path: in other words, | |||||
| for model represented by a model id, the model shall be downloaded locally | |||||
| """ | |||||
| if isinstance(model, str) and is_official_hub_path(model, model_revision): | |||||
| # note that if there is already a local copy, snapshot_download will check and skip downloading | |||||
| model = snapshot_download(model, revision=model_revision) | |||||
| elif isinstance(model, list) and isinstance(model[0], str): | |||||
| for idx in range(len(model)): | |||||
| if is_official_hub_path(model[idx], model_revision): | |||||
| model[idx] = snapshot_download( | |||||
| model[idx], revision=model_revision) | |||||
| return model | |||||
| def build_pipeline(cfg: ConfigDict, | def build_pipeline(cfg: ConfigDict, | ||||
| task_name: str = None, | task_name: str = None, | ||||
| default_args: dict = None): | default_args: dict = None): | ||||
| @@ -89,8 +105,9 @@ def pipeline(task: str = None, | |||||
| pipeline_name: str = None, | pipeline_name: str = None, | ||||
| framework: str = None, | framework: str = None, | ||||
| device: int = -1, | device: int = -1, | ||||
| model_revision: Optional[str] = 'master', | |||||
| **kwargs) -> Pipeline: | **kwargs) -> Pipeline: | ||||
| """ Factory method to build a obj:`Pipeline`. | |||||
| """ Factory method to build an obj:`Pipeline`. | |||||
| Args: | Args: | ||||
| @@ -100,6 +117,8 @@ def pipeline(task: str = None, | |||||
| config_file (str, optional): path to config file. | config_file (str, optional): path to config file. | ||||
| pipeline_name (str, optional): pipeline class name or alias name. | pipeline_name (str, optional): pipeline class name or alias name. | ||||
| framework (str, optional): framework type. | framework (str, optional): framework type. | ||||
| model_revision: revision of model(s) if getting from model hub, for multiple models, expecting | |||||
| all models to have the same revision | |||||
| device (int, optional): which device is used to do inference. | device (int, optional): which device is used to do inference. | ||||
| Return: | Return: | ||||
| @@ -123,14 +142,18 @@ def pipeline(task: str = None, | |||||
| assert isinstance(model, (type(None), str, Model, list)), \ | assert isinstance(model, (type(None), str, Model, list)), \ | ||||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | ||||
| model = normalize_model_input(model, model_revision) | |||||
| if pipeline_name is None: | if pipeline_name is None: | ||||
| # get default pipeline for this task | # get default pipeline for this task | ||||
| if isinstance(model, str) \ | if isinstance(model, str) \ | ||||
| or (isinstance(model, list) and isinstance(model[0], str)): | or (isinstance(model, list) and isinstance(model[0], str)): | ||||
| if is_official_hub_path(model): | |||||
| if is_official_hub_path(model, revision=model_revision): | |||||
| # read config file from hub and parse | # read config file from hub and parse | ||||
| cfg = read_config(model) if isinstance( | |||||
| model, str) else read_config(model[0]) | |||||
| cfg = read_config( | |||||
| model, revision=model_revision) if isinstance( | |||||
| model, str) else read_config( | |||||
| model[0], revision=model_revision) | |||||
| assert hasattr( | assert hasattr( | ||||
| cfg, | cfg, | ||||
| 'pipeline'), 'pipeline config is missing from config file.' | 'pipeline'), 'pipeline config is missing from config file.' | ||||
| @@ -152,7 +175,7 @@ def pipeline(task: str = None, | |||||
| pipeline_name = first_model.pipeline.type | pipeline_name = first_model.pipeline.type | ||||
| else: | else: | ||||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | pipeline_name, default_model_repo = get_default_pipeline_info(task) | ||||
| model = default_model_repo | |||||
| model = normalize_model_input(default_model_repo, model_revision) | |||||
| cfg = ConfigDict(type=pipeline_name, model=model) | cfg = ConfigDict(type=pipeline_name, model=model) | ||||
| @@ -23,6 +23,11 @@ logger = get_logger() | |||||
| class ActionRecognitionPipeline(Pipeline): | class ActionRecognitionPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | ||||
| logger.info(f'loading model from {model_path}') | logger.info(f'loading model from {model_path}') | ||||
| @@ -1,5 +1,4 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| import tempfile | |||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| import cv2 | import cv2 | ||||
| @@ -8,13 +7,12 @@ import torch | |||||
| from PIL import Image | from PIL import Image | ||||
| from torchvision import transforms | from torchvision import transforms | ||||
| from modelscope.fileio import File | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.models.cv.animal_recognition import resnet | from modelscope.models.cv.animal_recognition import resnet | ||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| @@ -28,6 +26,11 @@ logger = get_logger() | |||||
| class AnimalRecogPipeline(Pipeline): | class AnimalRecogPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| import torch | import torch | ||||
| @@ -1,11 +1,8 @@ | |||||
| import math | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict | from typing import Any, Dict | ||||
| import cv2 | |||||
| import decord | import decord | ||||
| import numpy as np | import numpy as np | ||||
| import PIL | |||||
| import torch | import torch | ||||
| import torchvision.transforms.functional as TF | import torchvision.transforms.functional as TF | ||||
| from decord import VideoReader, cpu | from decord import VideoReader, cpu | ||||
| @@ -30,6 +27,11 @@ logger = get_logger() | |||||
| class CMDSSLVideoEmbeddingPipeline(Pipeline): | class CMDSSLVideoEmbeddingPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | ||||
| logger.info(f'loading model from {model_path}') | logger.info(f'loading model from {model_path}') | ||||
| @@ -31,6 +31,11 @@ logger = get_logger() | |||||
| class ImageCartoonPipeline(Pipeline): | class ImageCartoonPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| self.facer = FaceAna(self.model) | self.facer = FaceAna(self.model) | ||||
| self.sess_anime_head = self.load_sess( | self.sess_anime_head = self.load_sess( | ||||
| @@ -22,6 +22,11 @@ logger = get_logger() | |||||
| class ImageMattingPipeline(Pipeline): | class ImageMattingPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| import tensorflow as tf | import tensorflow as tf | ||||
| if tf.__version__ >= '2.0': | if tf.__version__ >= '2.0': | ||||
| @@ -1,8 +1,5 @@ | |||||
| import math | |||||
| import os | |||||
| import os.path as osp | import os.path as osp | ||||
| import sys | |||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| from typing import Any, Dict | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -48,6 +45,11 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, | |||||
| class OCRDetectionPipeline(Pipeline): | class OCRDetectionPipeline(Pipeline): | ||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| tf.reset_default_graph() | tf.reset_default_graph() | ||||
| model_path = osp.join( | model_path = osp.join( | ||||
| @@ -18,7 +18,12 @@ class ImageCaptionPipeline(Pipeline): | |||||
| model: Union[Model, str], | model: Union[Model, str], | ||||
| preprocessor: [Preprocessor] = None, | preprocessor: [Preprocessor] = None, | ||||
| **kwargs): | **kwargs): | ||||
| super().__init__() | |||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | |||||
| assert isinstance(model, str) or isinstance(model, Model), \ | assert isinstance(model, str) or isinstance(model, Model), \ | ||||
| 'model must be a single str or OfaForImageCaptioning' | 'model must be a single str or OfaForImageCaptioning' | ||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| @@ -1,4 +1,4 @@ | |||||
| from typing import Any, Dict, Union | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| @@ -15,6 +15,11 @@ logger = get_logger() | |||||
| class MultiModalEmbeddingPipeline(Pipeline): | class MultiModalEmbeddingPipeline(Pipeline): | ||||
| def __init__(self, model: str, device_id: int = -1): | def __init__(self, model: str, device_id: int = -1): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| pipe_model = Model.from_pretrained(model) | pipe_model = Model.from_pretrained(model) | ||||
| elif isinstance(model, Model): | elif isinstance(model, Model): | ||||
| @@ -19,6 +19,11 @@ logger = get_logger() | |||||
| class TextToImageSynthesisPipeline(Pipeline): | class TextToImageSynthesisPipeline(Pipeline): | ||||
| def __init__(self, model: str, **kwargs): | def __init__(self, model: str, **kwargs): | ||||
| """ | |||||
| use `model` and `preprocessor` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| device_id = 0 if torch.cuda.is_available() else -1 | device_id = 0 if torch.cuda.is_available() else -1 | ||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| pipe_model = Model.from_pretrained(model, device_id=device_id) | pipe_model = Model.from_pretrained(model, device_id=device_id) | ||||
| @@ -1,6 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from typing import Any, Dict, Union | |||||
| from typing import Dict, Union | |||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import Model | from ...models import Model | ||||
| @@ -22,7 +22,7 @@ class DialogModelingPipeline(Pipeline): | |||||
| model: Union[SpaceForDialogModeling, str], | model: Union[SpaceForDialogModeling, str], | ||||
| preprocessor: DialogModelingPreprocessor = None, | preprocessor: DialogModelingPreprocessor = None, | ||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a dialog modleing pipeline for dialog response generation | |||||
| """use `model` and `preprocessor` to create a dialog modeling pipeline for dialog response generation | |||||
| Args: | Args: | ||||
| model (SpaceForDialogModeling): a model instance | model (SpaceForDialogModeling): a model instance | ||||
| @@ -1,8 +1,5 @@ | |||||
| import os | |||||
| import uuid | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| import json | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| @@ -11,7 +8,7 @@ from ...models import Model | |||||
| from ...models.nlp import SbertForSentimentClassification | from ...models.nlp import SbertForSentimentClassification | ||||
| from ...preprocessors import SentimentClassificationPreprocessor | from ...preprocessors import SentimentClassificationPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| from ..base import Input, Pipeline | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| from ..outputs import OutputKeys | from ..outputs import OutputKeys | ||||
| @@ -1,12 +1,11 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict, Optional, Union | |||||
| from typing import Any, Dict | |||||
| import numpy as np | import numpy as np | ||||
| import tensorflow as tf | import tensorflow as tf | ||||
| from ...hub.snapshot_download import snapshot_download | from ...hub.snapshot_download import snapshot_download | ||||
| from ...metainfo import Pipelines | from ...metainfo import Pipelines | ||||
| from ...models import Model | |||||
| from ...models.nlp import CsanmtForTranslation | from ...models.nlp import CsanmtForTranslation | ||||
| from ...utils.constant import ModelFile, Tasks | from ...utils.constant import ModelFile, Tasks | ||||
| from ...utils.logger import get_logger | from ...utils.logger import get_logger | ||||
| @@ -1,9 +1,5 @@ | |||||
| import os | |||||
| import uuid | |||||
| from typing import Any, Dict, Union | from typing import Any, Dict, Union | ||||
| import json | |||||
| import numpy as np | |||||
| import torch | import torch | ||||
| from scipy.special import softmax | from scipy.special import softmax | ||||
| @@ -12,7 +8,7 @@ from ...models import Model | |||||
| from ...models.nlp import SbertForZeroShotClassification | from ...models.nlp import SbertForZeroShotClassification | ||||
| from ...preprocessors import ZeroShotClassificationPreprocessor | from ...preprocessors import ZeroShotClassificationPreprocessor | ||||
| from ...utils.constant import Tasks | from ...utils.constant import Tasks | ||||
| from ..base import Input, Pipeline | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| from ..outputs import OutputKeys | from ..outputs import OutputKeys | ||||
| @@ -30,7 +26,7 @@ class ZeroShotClassificationPipeline(Pipeline): | |||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
| Args: | Args: | ||||
| model (SbertForSentimentClassification): a model instance | |||||
| model (SbertForZeroShotClassification): a model instance | |||||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | ||||
| """ | """ | ||||
| assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \ | assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \ | ||||
| @@ -1,6 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | |||||
| from typing import List, Optional, Union | |||||
| from modelscope.hub.api import HubApi | from modelscope.hub.api import HubApi | ||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| @@ -20,8 +20,9 @@ def is_config_has_model(cfg_file): | |||||
| return False | return False | ||||
| def is_official_hub_path(path: Union[str, List]): | |||||
| """ Whether path is a official hub name or a valid local | |||||
| def is_official_hub_path(path: Union[str, List], | |||||
| revision: Optional[str] = 'master'): | |||||
| """ Whether path is an official hub name or a valid local | |||||
| path to official hub directory. | path to official hub directory. | ||||
| """ | """ | ||||
| @@ -31,7 +32,7 @@ def is_official_hub_path(path: Union[str, List]): | |||||
| return osp.exists(cfg_file) | return osp.exists(cfg_file) | ||||
| else: | else: | ||||
| try: | try: | ||||
| _ = HubApi().get_model(path) | |||||
| _ = HubApi().get_model(path, revision=revision) | |||||
| return True | return True | ||||
| except Exception: | except Exception: | ||||
| return False | return False | ||||
| @@ -42,7 +42,7 @@ def create_model_if_not_exist( | |||||
| return True | return True | ||||
| def read_config(model_id_or_path: str): | |||||
| def read_config(model_id_or_path: str, revision: Optional[str] = 'master'): | |||||
| """ Read config from hub or local path | """ Read config from hub or local path | ||||
| Args: | Args: | ||||
| @@ -52,8 +52,8 @@ def read_config(model_id_or_path: str): | |||||
| config (:obj:`Config`): config object | config (:obj:`Config`): config object | ||||
| """ | """ | ||||
| if not os.path.exists(model_id_or_path): | if not os.path.exists(model_id_or_path): | ||||
| local_path = model_file_download(model_id_or_path, | |||||
| ModelFile.CONFIGURATION) | |||||
| local_path = model_file_download( | |||||
| model_id_or_path, ModelFile.CONFIGURATION, revision=revision) | |||||
| else: | else: | ||||
| local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) | local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) | ||||
| @@ -1,6 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | import os.path as osp | ||||
| import shutil | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| @@ -47,7 +46,8 @@ class ImageMattingTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_modelhub(self): | def test_run_modelhub(self): | ||||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||||
| img_matting = pipeline( | |||||
| Tasks.image_matting, model=self.model_id, model_revision='beta') | |||||
| result = img_matting('data/test/images/image_matting.png') | result = img_matting('data/test/images/image_matting.png') | ||||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | ||||