diff --git a/modelscope/models/base.py b/modelscope/models/base.py index e641236d..3e361f91 100644 --- a/modelscope/models/base.py +++ b/modelscope/models/base.py @@ -2,14 +2,13 @@ import os.path as osp from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union +from typing import Dict, Union -from maas_hub.file_download import model_file_download from maas_hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model from modelscope.utils.config import Config -from modelscope.utils.constant import CONFIGFILE +from modelscope.utils.constant import ModelFile from modelscope.utils.hub import get_model_cache_dir Tensor = Union['torch.Tensor', 'tf.Tensor'] @@ -47,7 +46,8 @@ class Model(ABC): # raise ValueError( # 'Remote model repo {model_name_or_path} does not exists') - cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) + cfg = Config.from_file( + osp.join(local_model_dir, ModelFile.CONFIGURATION)) task_name = cfg.task model_cfg = cfg.model # TODO @wenmeng.zwm may should manually initialize model after model building diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 6495a5db..ad3511cb 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -3,21 +3,17 @@ import os.path as osp from typing import List, Union -import json -from maas_hub.file_download import model_file_download - from modelscope.models.base import Model from modelscope.utils.config import Config, ConfigDict -from modelscope.utils.constant import CONFIGFILE, Tasks +from modelscope.utils.constant import Tasks from modelscope.utils.registry import Registry, build_from_cfg from .base import Pipeline -from .util import is_model_name PIPELINES = Registry('pipelines') DEFAULT_MODEL_FOR_PIPELINE = { # TaskName: (pipeline_module_name, model_repo) - Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), + Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'), Tasks.text_classification: ('bert-sentiment-analysis', 'damo/bert-base-sst2'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 6f3ff5f5..0c60dfa7 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -1,5 +1,5 @@ import os.path as osp -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict import cv2 import numpy as np @@ -7,7 +7,7 @@ import PIL from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): import tensorflow as tf if tf.__version__ >= '2.0': tf = tf.compat.v1 - model_path = osp.join(self.model, 'matting_person.pb') + model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py index 43a7ac5a..37c9c929 100644 --- a/modelscope/pipelines/util.py +++ b/modelscope/pipelines/util.py @@ -1,14 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os import os.path as osp from typing import List, Union -import json from maas_hub.file_download import model_file_download -from matplotlib.pyplot import get from modelscope.utils.config import Config -from modelscope.utils.constant import CONFIGFILE +from modelscope.utils.constant import ModelFile from modelscope.utils.logger import get_logger logger = get_logger() @@ -29,14 +26,14 @@ def is_model_name(model: Union[str, List]): def is_model_name_impl(model): if osp.exists(model): - cfg_file = osp.join(model, CONFIGFILE) + cfg_file = osp.join(model, ModelFile.CONFIGURATION) if osp.exists(cfg_file): return is_config_has_model(cfg_file) else: return False else: try: - cfg_file = model_file_download(model, CONFIGFILE) + cfg_file = model_file_download(model, ModelFile.CONFIGURATION) return is_config_has_model(cfg_file) except Exception: return False diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index fa30dd2a..c6eb6385 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -71,5 +71,16 @@ class Hubs(object): huggingface = 'huggingface' -# configuration filename -CONFIGFILE = 'configuration.json' +class ModelFile(object): + CONFIGURATION = 'configuration.json' + README = 'README.md' + TF_SAVED_MODEL_FILE = 'saved_model.pb' + TF_GRAPH_FILE = 'tf_graph.pb' + TF_CHECKPOINT_FOLDER = 'tf_ckpts' + TF_CKPT_PREFIX = 'ckpt-' + TORCH_MODEL_FILE = 'pytorch_model.pt' + TORCH_MODEL_BIN_FILE = 'pytorch_model.bin' + + +TENSORFLOW = 'tensorflow' +PYTORCH = 'pytorch' diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 73a938ea..888564c7 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -1,7 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import inspect -from email.policy import default from modelscope.utils.logger import get_logger diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 53006317..f1a627a0 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -9,25 +9,26 @@ import cv2 from modelscope.fileio import File from modelscope.pipelines import pipeline from modelscope.pydatasets import PyDataset -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.hub import get_model_cache_dir class ImageMattingTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/image-matting-person' + self.model_id = 'damo/cv_unet_image-matting_damo' # switch to False if downloading everytime is not desired purge_cache = True if purge_cache: shutil.rmtree( get_model_cache_dir(self.model_id), ignore_errors=True) - def test_run(self): + @unittest.skip('deprecated, download model from model hub instead') + def test_run_with_direct_file_download(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' with tempfile.TemporaryDirectory() as tmp_dir: - model_file = osp.join(tmp_dir, 'matting_person.pb') + model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE) with open(model_file, 'wb') as ofile: ofile.write(File.read(model_path)) img_matting = pipeline(Tasks.image_matting, model=tmp_dir) diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index fb7044e8..a3770f0d 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -1,11 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import argparse -import os.path as osp import tempfile import unittest -from pathlib import Path -from modelscope.fileio import dump, load from modelscope.utils.config import Config obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}