| @@ -7,7 +7,7 @@ import PIL | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import TF_GRAPH_FILE, Tasks | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| if tf.__version__ >= '2.0': | if tf.__version__ >= '2.0': | ||||
| tf = tf.compat.v1 | tf = tf.compat.v1 | ||||
| model_path = osp.join(self.model, TF_GRAPH_FILE) | |||||
| model_path = osp.join(self.model, 'matting_person.pb') | |||||
| config = tf.ConfigProto(allow_soft_placement=True) | config = tf.ConfigProto(allow_soft_placement=True) | ||||
| config.gpu_options.allow_growth = True | config.gpu_options.allow_growth = True | ||||
| @@ -75,12 +75,3 @@ class Hubs(object): | |||||
| # in order to avoid conflict with huggingface | # in order to avoid conflict with huggingface | ||||
| # config file we use maas_config instead | # config file we use maas_config instead | ||||
| CONFIGFILE = 'maas_config.json' | CONFIGFILE = 'maas_config.json' | ||||
| README_FILE = 'README.md' | |||||
| TF_SAVED_MODEL_FILE = 'saved_model.pb' | |||||
| TF_GRAPH_FILE = 'tf_graph.pb' | |||||
| TF_CHECKPOINT_FOLDER = 'tf_ckpts' | |||||
| TF_CHECKPOINT_FILE = 'checkpoint' | |||||
| TORCH_MODEL_FILE = 'pytorch_model.bin' | |||||
| TENSORFLOW = 'tensorflow' | |||||
| PYTORCH = 'pytorch' | |||||
| @@ -16,15 +16,14 @@ from modelscope.utils.hub import get_model_cache_dir | |||||
| class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/cv_unet_image-matting_damo' | |||||
| self.model_id = 'damo/image-matting-person' | |||||
| # switch to False if downloading everytime is not desired | # switch to False if downloading everytime is not desired | ||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| shutil.rmtree( | shutil.rmtree( | ||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | get_model_cache_dir(self.model_id), ignore_errors=True) | ||||
| @unittest.skip('deprecated, download model from model hub instead') | |||||
| def test_run_with_direct_file_download(self): | |||||
| def test_run(self): | |||||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | ||||
| '.com/data/test/maas/image_matting/matting_person.pb' | '.com/data/test/maas/image_matting/matting_person.pb' | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||