| @@ -7,7 +7,7 @@ import PIL | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import TF_GRAPH_FILE, Tasks | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| if tf.__version__ >= '2.0': | if tf.__version__ >= '2.0': | ||||
| tf = tf.compat.v1 | tf = tf.compat.v1 | ||||
| model_path = osp.join(self.model, TF_GRAPH_FILE) | |||||
| model_path = osp.join(self.model, 'matting_person.pb') | |||||
| config = tf.ConfigProto(allow_soft_placement=True) | config = tf.ConfigProto(allow_soft_placement=True) | ||||
| config.gpu_options.allow_growth = True | config.gpu_options.allow_growth = True | ||||
| @@ -5,8 +5,22 @@ from typing import List, Union | |||||
| import json | import json | ||||
| from maas_hub.file_download import model_file_download | from maas_hub.file_download import model_file_download | ||||
| from matplotlib.pyplot import get | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import CONFIGFILE | from modelscope.utils.constant import CONFIGFILE | ||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| def is_config_has_model(cfg_file): | |||||
| try: | |||||
| cfg = Config.from_file(cfg_file) | |||||
| return hasattr(cfg, 'model') | |||||
| except Exception as e: | |||||
| logger.error(f'parse config file {cfg_file} failed: {e}') | |||||
| return False | |||||
| def is_model_name(model: Union[str, List]): | def is_model_name(model: Union[str, List]): | ||||
| @@ -15,24 +29,17 @@ def is_model_name(model: Union[str, List]): | |||||
| def is_model_name_impl(model): | def is_model_name_impl(model): | ||||
| if osp.exists(model): | if osp.exists(model): | ||||
| if osp.exists(osp.join(model, CONFIGFILE)): | |||||
| return True | |||||
| cfg_file = osp.join(model, CONFIGFILE) | |||||
| if osp.exists(cfg_file): | |||||
| return is_config_has_model(cfg_file) | |||||
| else: | else: | ||||
| return False | return False | ||||
| else: | else: | ||||
| # try: | |||||
| # cfg_file = model_file_download(model, CONFIGFILE) | |||||
| # except Exception: | |||||
| # cfg_file = None | |||||
| # TODO @wenmeng.zwm use exception instead of | |||||
| # following tricky logic | |||||
| cfg_file = model_file_download(model, CONFIGFILE) | |||||
| with open(cfg_file, 'r') as infile: | |||||
| cfg = json.load(infile) | |||||
| if 'Code' in cfg: | |||||
| try: | |||||
| cfg_file = model_file_download(model, CONFIGFILE) | |||||
| return is_config_has_model(cfg_file) | |||||
| except Exception: | |||||
| return False | return False | ||||
| else: | |||||
| return True | |||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| return is_model_name_impl(model) | return is_model_name_impl(model) | ||||
| @@ -9,7 +9,7 @@ from modelscope.utils.constant import Fields | |||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| @PREPROCESSORS.register_module(Fields.image) | |||||
| @PREPROCESSORS.register_module(Fields.cv) | |||||
| class LoadImage: | class LoadImage: | ||||
| """Load an image from file or url. | """Load an image from file or url. | ||||
| Added or updated keys are "filename", "img", "img_shape", | Added or updated keys are "filename", "img", "img_shape", | ||||
| @@ -74,17 +74,17 @@ class Config: | |||||
| {'c': [1, 2, 3], 'd': 'dd'} | {'c': [1, 2, 3], 'd': 'dd'} | ||||
| >>> cfg.b.d | >>> cfg.b.d | ||||
| 'dd' | 'dd' | ||||
| >>> cfg = Config.from_file('configs/examples/config.json') | |||||
| >>> cfg = Config.from_file('configs/examples/configuration.json') | |||||
| >>> cfg.filename | >>> cfg.filename | ||||
| 'configs/examples/config.json' | |||||
| 'configs/examples/configuration.json' | |||||
| >>> cfg.b | >>> cfg.b | ||||
| {'c': [1, 2, 3], 'd': 'dd'} | {'c': [1, 2, 3], 'd': 'dd'} | ||||
| >>> cfg = Config.from_file('configs/examples/config.py') | |||||
| >>> cfg = Config.from_file('configs/examples/configuration.py') | |||||
| >>> cfg.filename | >>> cfg.filename | ||||
| "configs/examples/config.py" | |||||
| >>> cfg = Config.from_file('configs/examples/config.yaml') | |||||
| "configs/examples/configuration.py" | |||||
| >>> cfg = Config.from_file('configs/examples/configuration.yaml') | |||||
| >>> cfg.filename | >>> cfg.filename | ||||
| "configs/examples/config.yaml" | |||||
| "configs/examples/configuration.yaml" | |||||
| """ | """ | ||||
| @staticmethod | @staticmethod | ||||
| @@ -4,8 +4,8 @@ | |||||
| class Fields(object): | class Fields(object): | ||||
| """ Names for different application fields | """ Names for different application fields | ||||
| """ | """ | ||||
| image = 'image' | |||||
| video = 'video' | |||||
| # image = 'image' | |||||
| # video = 'video' | |||||
| cv = 'cv' | cv = 'cv' | ||||
| nlp = 'nlp' | nlp = 'nlp' | ||||
| audio = 'audio' | audio = 'audio' | ||||
| @@ -73,15 +73,4 @@ class Hubs(object): | |||||
| # configuration filename | # configuration filename | ||||
| # in order to avoid conflict with huggingface | |||||
| # config file we use maas_config instead | |||||
| CONFIGFILE = 'maas_config.json' | |||||
| README_FILE = 'README.md' | |||||
| TF_SAVED_MODEL_FILE = 'saved_model.pb' | |||||
| TF_GRAPH_FILE = 'tf_graph.pb' | |||||
| TF_CHECKPOINT_FOLDER = 'tf_ckpts' | |||||
| TF_CHECKPOINT_FILE = 'checkpoint' | |||||
| TORCH_MODEL_FILE = 'pytorch_model.bin' | |||||
| TENSORFLOW = 'tensorflow' | |||||
| PYTORCH = 'pytorch' | |||||
| CONFIGFILE = 'configuration.json' | |||||
| @@ -16,15 +16,14 @@ from modelscope.utils.hub import get_model_cache_dir | |||||
| class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/cv_unet_image-matting_damo' | |||||
| self.model_id = 'damo/image-matting-person' | |||||
| # switch to False if downloading everytime is not desired | # switch to False if downloading everytime is not desired | ||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| shutil.rmtree( | shutil.rmtree( | ||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | get_model_cache_dir(self.model_id), ignore_errors=True) | ||||
| @unittest.skip('deprecated, download model from model hub instead') | |||||
| def test_run_with_direct_file_download(self): | |||||
| def test_run(self): | |||||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | ||||
| '.com/data/test/maas/image_matting/matting_person.pb' | '.com/data/test/maas/image_matting/matting_person.pb' | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| @@ -14,25 +14,25 @@ obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | |||||
| class ConfigTest(unittest.TestCase): | class ConfigTest(unittest.TestCase): | ||||
| def test_json(self): | def test_json(self): | ||||
| config_file = 'configs/examples/config.json' | |||||
| config_file = 'configs/examples/configuration.json' | |||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| self.assertEqual(cfg.a, 1) | self.assertEqual(cfg.a, 1) | ||||
| self.assertEqual(cfg.b, obj['b']) | self.assertEqual(cfg.b, obj['b']) | ||||
| def test_yaml(self): | def test_yaml(self): | ||||
| config_file = 'configs/examples/config.yaml' | |||||
| config_file = 'configs/examples/configuration.yaml' | |||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| self.assertEqual(cfg.a, 1) | self.assertEqual(cfg.a, 1) | ||||
| self.assertEqual(cfg.b, obj['b']) | self.assertEqual(cfg.b, obj['b']) | ||||
| def test_py(self): | def test_py(self): | ||||
| config_file = 'configs/examples/config.py' | |||||
| config_file = 'configs/examples/configuration.py' | |||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| self.assertEqual(cfg.a, 1) | self.assertEqual(cfg.a, 1) | ||||
| self.assertEqual(cfg.b, obj['b']) | self.assertEqual(cfg.b, obj['b']) | ||||
| def test_dump(self): | def test_dump(self): | ||||
| config_file = 'configs/examples/config.py' | |||||
| config_file = 'configs/examples/configuration.py' | |||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| self.assertEqual(cfg.a, 1) | self.assertEqual(cfg.a, 1) | ||||
| self.assertEqual(cfg.b, obj['b']) | self.assertEqual(cfg.b, obj['b']) | ||||
| @@ -53,7 +53,7 @@ class ConfigTest(unittest.TestCase): | |||||
| self.assertEqual(yaml_str, infile.read()) | self.assertEqual(yaml_str, infile.read()) | ||||
| def test_to_dict(self): | def test_to_dict(self): | ||||
| config_file = 'configs/examples/config.json' | |||||
| config_file = 'configs/examples/configuration.json' | |||||
| cfg = Config.from_file(config_file) | cfg = Config.from_file(config_file) | ||||
| d = cfg.to_dict() | d = cfg.to_dict() | ||||
| print(d) | print(d) | ||||