yingda.chen 4 years ago
parent
commit
5786b9a0a1
8 changed files with 30 additions and 29 deletions
  1. +4
    -4
      modelscope/models/base.py
  2. +2
    -6
      modelscope/pipelines/builder.py
  3. +3
    -3
      modelscope/pipelines/cv/image_matting_pipeline.py
  4. +3
    -6
      modelscope/pipelines/util.py
  5. +13
    -2
      modelscope/utils/constant.py
  6. +0
    -1
      modelscope/utils/registry.py
  7. +5
    -4
      tests/pipelines/test_image_matting.py
  8. +0
    -3
      tests/utils/test_config.py

+ 4
- 4
modelscope/models/base.py View File

@@ -2,14 +2,13 @@


import os.path as osp import os.path as osp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
from typing import Dict, Union


from maas_hub.file_download import model_file_download
from maas_hub.snapshot_download import snapshot_download from maas_hub.snapshot_download import snapshot_download


from modelscope.models.builder import build_model from modelscope.models.builder import build_model
from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import CONFIGFILE
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.hub import get_model_cache_dir


Tensor = Union['torch.Tensor', 'tf.Tensor'] Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -47,7 +46,8 @@ class Model(ABC):
# raise ValueError( # raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists') # 'Remote model repo {model_name_or_path} does not exists')


cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE))
cfg = Config.from_file(
osp.join(local_model_dir, ModelFile.CONFIGURATION))
task_name = cfg.task task_name = cfg.task
model_cfg = cfg.model model_cfg = cfg.model
# TODO @wenmeng.zwm may should manually initialize model after model building # TODO @wenmeng.zwm may should manually initialize model after model building


+ 2
- 6
modelscope/pipelines/builder.py View File

@@ -3,21 +3,17 @@
import os.path as osp import os.path as osp
from typing import List, Union from typing import List, Union


import json
from maas_hub.file_download import model_file_download

from modelscope.models.base import Model from modelscope.models.base import Model
from modelscope.utils.config import Config, ConfigDict from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import CONFIGFILE, Tasks
from modelscope.utils.constant import Tasks
from modelscope.utils.registry import Registry, build_from_cfg from modelscope.utils.registry import Registry, build_from_cfg
from .base import Pipeline from .base import Pipeline
from .util import is_model_name


PIPELINES = Registry('pipelines') PIPELINES = Registry('pipelines')


DEFAULT_MODEL_FOR_PIPELINE = { DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo) # TaskName: (pipeline_module_name, model_repo)
Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting_damo'),
Tasks.text_classification: Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'), ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),


+ 3
- 3
modelscope/pipelines/cv/image_matting_pipeline.py View File

@@ -1,5 +1,5 @@
import os.path as osp import os.path as osp
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict


import cv2 import cv2
import numpy as np import numpy as np
@@ -7,7 +7,7 @@ import PIL


from modelscope.pipelines.base import Input from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image from modelscope.preprocessors import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from ..base import Pipeline from ..base import Pipeline
from ..builder import PIPELINES from ..builder import PIPELINES
@@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline):
import tensorflow as tf import tensorflow as tf
if tf.__version__ >= '2.0': if tf.__version__ >= '2.0':
tf = tf.compat.v1 tf = tf.compat.v1
model_path = osp.join(self.model, 'matting_person.pb')
model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)


config = tf.ConfigProto(allow_soft_placement=True) config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True


+ 3
- 6
modelscope/pipelines/util.py View File

@@ -1,14 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp import os.path as osp
from typing import List, Union from typing import List, Union


import json
from maas_hub.file_download import model_file_download from maas_hub.file_download import model_file_download
from matplotlib.pyplot import get


from modelscope.utils.config import Config from modelscope.utils.config import Config
from modelscope.utils.constant import CONFIGFILE
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger


logger = get_logger() logger = get_logger()
@@ -29,14 +26,14 @@ def is_model_name(model: Union[str, List]):


def is_model_name_impl(model): def is_model_name_impl(model):
if osp.exists(model): if osp.exists(model):
cfg_file = osp.join(model, CONFIGFILE)
cfg_file = osp.join(model, ModelFile.CONFIGURATION)
if osp.exists(cfg_file): if osp.exists(cfg_file):
return is_config_has_model(cfg_file) return is_config_has_model(cfg_file)
else: else:
return False return False
else: else:
try: try:
cfg_file = model_file_download(model, CONFIGFILE)
cfg_file = model_file_download(model, ModelFile.CONFIGURATION)
return is_config_has_model(cfg_file) return is_config_has_model(cfg_file)
except Exception: except Exception:
return False return False


+ 13
- 2
modelscope/utils/constant.py View File

@@ -71,5 +71,16 @@ class Hubs(object):
huggingface = 'huggingface' huggingface = 'huggingface'




# configuration filename
CONFIGFILE = 'configuration.json'
class ModelFile(object):
CONFIGURATION = 'configuration.json'
README = 'README.md'
TF_SAVED_MODEL_FILE = 'saved_model.pb'
TF_GRAPH_FILE = 'tf_graph.pb'
TF_CHECKPOINT_FOLDER = 'tf_ckpts'
TF_CKPT_PREFIX = 'ckpt-'
TORCH_MODEL_FILE = 'pytorch_model.pt'
TORCH_MODEL_BIN_FILE = 'pytorch_model.bin'


TENSORFLOW = 'tensorflow'
PYTORCH = 'pytorch'

+ 0
- 1
modelscope/utils/registry.py View File

@@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import inspect import inspect
from email.policy import default


from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger




+ 5
- 4
tests/pipelines/test_image_matting.py View File

@@ -9,25 +9,26 @@ import cv2
from modelscope.fileio import File from modelscope.fileio import File
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Tasks
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.hub import get_model_cache_dir




class ImageMattingTest(unittest.TestCase): class ImageMattingTest(unittest.TestCase):


def setUp(self) -> None: def setUp(self) -> None:
self.model_id = 'damo/image-matting-person'
self.model_id = 'damo/cv_unet_image-matting_damo'
# switch to False if downloading everytime is not desired # switch to False if downloading everytime is not desired
purge_cache = True purge_cache = True
if purge_cache: if purge_cache:
shutil.rmtree( shutil.rmtree(
get_model_cache_dir(self.model_id), ignore_errors=True) get_model_cache_dir(self.model_id), ignore_errors=True)


def test_run(self):
@unittest.skip('deprecated, download model from model hub instead')
def test_run_with_direct_file_download(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
'.com/data/test/maas/image_matting/matting_person.pb' '.com/data/test/maas/image_matting/matting_person.pb'
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model_file = osp.join(tmp_dir, 'matting_person.pb')
model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE)
with open(model_file, 'wb') as ofile: with open(model_file, 'wb') as ofile:
ofile.write(File.read(model_path)) ofile.write(File.read(model_path))
img_matting = pipeline(Tasks.image_matting, model=tmp_dir) img_matting = pipeline(Tasks.image_matting, model=tmp_dir)


+ 0
- 3
tests/utils/test_config.py View File

@@ -1,11 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import argparse import argparse
import os.path as osp
import tempfile import tempfile
import unittest import unittest
from pathlib import Path


from modelscope.fileio import dump, load
from modelscope.utils.config import Config from modelscope.utils.config import Config


obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}


Loading…
Cancel
Save