Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9019685master
| @@ -2,14 +2,13 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, List, Tuple, Union | |||||
| from typing import Dict, Union | |||||
| from maas_hub.file_download import model_file_download | |||||
| from maas_hub.snapshot_download import snapshot_download | from maas_hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import CONFIGFILE | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.hub import get_model_cache_dir | from modelscope.utils.hub import get_model_cache_dir | ||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| @@ -47,7 +46,8 @@ class Model(ABC): | |||||
| # raise ValueError( | # raise ValueError( | ||||
| # 'Remote model repo {model_name_or_path} does not exists') | # 'Remote model repo {model_name_or_path} does not exists') | ||||
| cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | |||||
| cfg = Config.from_file( | |||||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||||
| task_name = cfg.task | task_name = cfg.task | ||||
| model_cfg = cfg.model | model_cfg = cfg.model | ||||
| # TODO @wenmeng.zwm may should manually initialize model after model building | # TODO @wenmeng.zwm may should manually initialize model after model building | ||||
| @@ -3,21 +3,17 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | from typing import List, Union | ||||
| import json | |||||
| from maas_hub.file_download import model_file_download | |||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.utils.config import Config, ConfigDict | from modelscope.utils.config import Config, ConfigDict | ||||
| from modelscope.utils.constant import CONFIGFILE, Tasks | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.registry import Registry, build_from_cfg | from modelscope.utils.registry import Registry, build_from_cfg | ||||
| from .base import Pipeline | from .base import Pipeline | ||||
| from .util import is_model_name | |||||
| PIPELINES = Registry('pipelines') | PIPELINES = Registry('pipelines') | ||||
| DEFAULT_MODEL_FOR_PIPELINE = { | DEFAULT_MODEL_FOR_PIPELINE = { | ||||
| # TaskName: (pipeline_module_name, model_repo) | # TaskName: (pipeline_module_name, model_repo) | ||||
| Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | |||||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'), | |||||
| Tasks.text_classification: | Tasks.text_classification: | ||||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | ||||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | ||||
| @@ -1,5 +1,5 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Any, Dict, List, Tuple, Union | |||||
| from typing import Any, Dict | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -7,7 +7,7 @@ import PIL | |||||
| from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
| from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from ..base import Pipeline | from ..base import Pipeline | ||||
| from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
| @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): | |||||
| import tensorflow as tf | import tensorflow as tf | ||||
| if tf.__version__ >= '2.0': | if tf.__version__ >= '2.0': | ||||
| tf = tf.compat.v1 | tf = tf.compat.v1 | ||||
| model_path = osp.join(self.model, 'matting_person.pb') | |||||
| model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE) | |||||
| config = tf.ConfigProto(allow_soft_placement=True) | config = tf.ConfigProto(allow_soft_placement=True) | ||||
| config.gpu_options.allow_growth = True | config.gpu_options.allow_growth = True | ||||
| @@ -1,14 +1,11 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | from typing import List, Union | ||||
| import json | |||||
| from maas_hub.file_download import model_file_download | from maas_hub.file_download import model_file_download | ||||
| from matplotlib.pyplot import get | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import CONFIGFILE | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| logger = get_logger() | logger = get_logger() | ||||
| @@ -29,14 +26,14 @@ def is_model_name(model: Union[str, List]): | |||||
| def is_model_name_impl(model): | def is_model_name_impl(model): | ||||
| if osp.exists(model): | if osp.exists(model): | ||||
| cfg_file = osp.join(model, CONFIGFILE) | |||||
| cfg_file = osp.join(model, ModelFile.CONFIGURATION) | |||||
| if osp.exists(cfg_file): | if osp.exists(cfg_file): | ||||
| return is_config_has_model(cfg_file) | return is_config_has_model(cfg_file) | ||||
| else: | else: | ||||
| return False | return False | ||||
| else: | else: | ||||
| try: | try: | ||||
| cfg_file = model_file_download(model, CONFIGFILE) | |||||
| cfg_file = model_file_download(model, ModelFile.CONFIGURATION) | |||||
| return is_config_has_model(cfg_file) | return is_config_has_model(cfg_file) | ||||
| except Exception: | except Exception: | ||||
| return False | return False | ||||
| @@ -71,5 +71,16 @@ class Hubs(object): | |||||
| huggingface = 'huggingface' | huggingface = 'huggingface' | ||||
| # configuration filename | |||||
| CONFIGFILE = 'configuration.json' | |||||
| class ModelFile(object): | |||||
| CONFIGURATION = 'configuration.json' | |||||
| README = 'README.md' | |||||
| TF_SAVED_MODEL_FILE = 'saved_model.pb' | |||||
| TF_GRAPH_FILE = 'tf_graph.pb' | |||||
| TF_CHECKPOINT_FOLDER = 'tf_ckpts' | |||||
| TF_CKPT_PREFIX = 'ckpt-' | |||||
| TORCH_MODEL_FILE = 'pytorch_model.pt' | |||||
| TORCH_MODEL_BIN_FILE = 'pytorch_model.bin' | |||||
| TENSORFLOW = 'tensorflow' | |||||
| PYTORCH = 'pytorch' | |||||
| @@ -1,7 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import inspect | import inspect | ||||
| from email.policy import default | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| @@ -9,25 +9,26 @@ import cv2 | |||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.pipelines import pipeline | from modelscope.pipelines import pipeline | ||||
| from modelscope.pydatasets import PyDataset | from modelscope.pydatasets import PyDataset | ||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.hub import get_model_cache_dir | from modelscope.utils.hub import get_model_cache_dir | ||||
| class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/image-matting-person' | |||||
| self.model_id = 'damo/cv_unet_image-matting_damo' | |||||
| # switch to False if downloading everytime is not desired | # switch to False if downloading everytime is not desired | ||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| shutil.rmtree( | shutil.rmtree( | ||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | get_model_cache_dir(self.model_id), ignore_errors=True) | ||||
| def test_run(self): | |||||
| @unittest.skip('deprecated, download model from model hub instead') | |||||
| def test_run_with_direct_file_download(self): | |||||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | ||||
| '.com/data/test/maas/image_matting/matting_person.pb' | '.com/data/test/maas/image_matting/matting_person.pb' | ||||
| with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||
| model_file = osp.join(tmp_dir, 'matting_person.pb') | |||||
| model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE) | |||||
| with open(model_file, 'wb') as ofile: | with open(model_file, 'wb') as ofile: | ||||
| ofile.write(File.read(model_path)) | ofile.write(File.read(model_path)) | ||||
| img_matting = pipeline(Tasks.image_matting, model=tmp_dir) | img_matting = pipeline(Tasks.image_matting, model=tmp_dir) | ||||
| @@ -1,11 +1,8 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import argparse | import argparse | ||||
| import os.path as osp | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from pathlib import Path | |||||
| from modelscope.fileio import dump, load | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | ||||