From de3ea0db5414872ef4262195e1f10c634b5a6226 Mon Sep 17 00:00:00 2001 From: Yingda Chen Date: Mon, 13 Jun 2022 15:19:30 +0800 Subject: [PATCH] [to #42322933]formalize image matting --- modelscope/pipelines/cv/image_matting_pipeline.py | 4 ++-- modelscope/utils/constant.py | 9 +++++++++ tests/pipelines/test_image_matting.py | 5 +++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 6f3ff5f5..3e962d85 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -7,7 +7,7 @@ import PIL from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import TF_GRAPH_FILE, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): import tensorflow as tf if tf.__version__ >= '2.0': tf = tf.compat.v1 - model_path = osp.join(self.model, 'matting_person.pb') + model_path = osp.join(self.model, TF_GRAPH_FILE) config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 0d0f2492..c51e2445 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -75,3 +75,12 @@ class Hubs(object): # in order to avoid conflict with huggingface # config file we use maas_config instead CONFIGFILE = 'maas_config.json' + +README_FILE = 'README.md' +TF_SAVED_MODEL_FILE = 'saved_model.pb' +TF_GRAPH_FILE = 'tf_graph.pb' +TF_CHECKPOINT_FOLDER = 'tf_ckpts' +TF_CHECKPOINT_FILE = 'checkpoint' +TORCH_MODEL_FILE = 'pytorch_model.bin' +TENSORFLOW = 'tensorflow' +PYTORCH = 'pytorch' diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 53006317..69195bd1 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -16,14 +16,15 @@ from modelscope.utils.hub import get_model_cache_dir class ImageMattingTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/image-matting-person' + self.model_id = 'damo/cv_unet_image-matting_damo' # switch to False if downloading everytime is not desired purge_cache = True if purge_cache: shutil.rmtree( get_model_cache_dir(self.model_id), ignore_errors=True) - def test_run(self): + @unittest.skip('deprecated, download model from model hub instead') + def test_run_with_direct_file_download(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' with tempfile.TemporaryDirectory() as tmp_dir: