
Merge branch 'master' into nlp/space/gen

ly119399 · 3 years ago · commit ccbbd959a7

5 changed files with 46 additions and 22 deletions
  1. modelscope/pipelines/cv/image_matting_pipeline.py  (+2, -2)
  2. modelscope/pydatasets/py_dataset.py  (+17, -12)
  3. modelscope/utils/constant.py  (+17, -1)
  4. tests/pipelines/test_image_matting.py  (+3, -2)
  5. tests/pipelines/test_text_classification.py  (+7, -5)

modelscope/pipelines/cv/image_matting_pipeline.py  (+2, -2)

@@ -7,7 +7,7 @@ import PIL
 
 from modelscope.pipelines.base import Input
 from modelscope.preprocessors import load_image
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import TF_GRAPH_FILE, Tasks
 from modelscope.utils.logger import get_logger
 from ..base import Pipeline
 from ..builder import PIPELINES
@@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline):
         import tensorflow as tf
         if tf.__version__ >= '2.0':
             tf = tf.compat.v1
-        model_path = osp.join(self.model, 'matting_person.pb')
+        model_path = osp.join(self.model, TF_GRAPH_FILE)
 
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
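
Note: TF_GRAPH_FILE names a frozen TF1 GraphDef. As a hedged sketch of how such a file is typically loaded in TF1-compat mode (the helper below is illustrative and not part of this commit; only TF_GRAPH_FILE and the version shim come from the diff):

import os.path as osp

import tensorflow as tf
if tf.__version__ >= '2.0':
    tf = tf.compat.v1

from modelscope.utils.constant import TF_GRAPH_FILE  # 'tf_graph.pb'


def load_frozen_graph(model_dir: str):
    # Resolve the weights file by the shared naming convention instead
    # of the hard-coded 'matting_person.pb' this hunk removes.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(osp.join(model_dir, TF_GRAPH_FILE), 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph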


modelscope/pydatasets/py_dataset.py  (+17, -12)

@@ -1,9 +1,9 @@
-import logging
 from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
                     Union)
 
 from datasets import Dataset, load_dataset
 
+from modelscope.utils.constant import Hubs
 from modelscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -41,17 +41,17 @@ class PyDataset:
         return dataset
 
     @staticmethod
-    def load(
-        path: Union[str, list],
-        target: Optional[str] = None,
-        version: Optional[str] = None,
-        name: Optional[str] = None,
-        split: Optional[str] = None,
-        data_dir: Optional[str] = None,
-        data_files: Optional[Union[str, Sequence[str],
-                                   Mapping[str, Union[str,
-                                                      Sequence[str]]]]] = None
-    ) -> 'PyDataset':
+    def load(path: Union[str, list],
+             target: Optional[str] = None,
+             version: Optional[str] = None,
+             name: Optional[str] = None,
+             split: Optional[str] = None,
+             data_dir: Optional[str] = None,
+             data_files: Optional[Union[str, Sequence[str],
+                                        Mapping[str,
+                                                Union[str,
+                                                      Sequence[str]]]]] = None,
+             hub: Optional[Hubs] = None) -> 'PyDataset':
         """Load a PyDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
         Args:
 
@@ -62,10 +62,15 @@ class PyDataset:
             data_dir (str, optional): Defining the data_dir of the dataset configuration. I
             data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s).
             split (str, optional): Which split of the data to load.
+            hub (Hubs, optional): When loading from a remote hub, where it is from
 
         Returns:
             PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset.
         """
+        if Hubs.modelscope == hub:
+            # TODO: parse data meta information from modelscope hub
+            # and possibly download data files to local (and update path)
+            print('getting data from modelscope hub')
         if isinstance(path, str):
             dataset = load_dataset(
                 path,
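
The new hub argument is exercised by the updated tests below. A minimal usage sketch (dataset names mirror those tests; note that in this commit the Hubs.modelscope branch is still a stub that only prints a message before falling through to the normal load path):

from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs

# Load the GLUE/SST-2 'sentence' column, stating the source hub
# explicitly instead of relying on the Hugging Face default.
dataset = PyDataset.load(
    'glue', name='sst2', target='sentence', hub=Hubs.huggingface)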


modelscope/utils/constant.py  (+17, -1)

@@ -59,14 +59,30 @@ class Tasks(object):
 
 
 class InputFields(object):
-    """ Names for input data fileds in the input data for pipelines
+    """ Names for input data fields in the input data for pipelines
     """
     img = 'img'
     text = 'text'
     audio = 'audio'
 
 
+class Hubs(object):
+    """ Source from which an entity (such as a Dataset or Model) is stored
+    """
+    modelscope = 'modelscope'
+    huggingface = 'huggingface'
+
+
 # configuration filename
 # in order to avoid conflict with huggingface
 # config file we use maas_config instead
 CONFIGFILE = 'maas_config.json'
+
+README_FILE = 'README.md'
+TF_SAVED_MODEL_FILE = 'saved_model.pb'
+TF_GRAPH_FILE = 'tf_graph.pb'
+TF_CHECKPOINT_FOLDER = 'tf_ckpts'
+TF_CHECKPOINT_FILE = 'checkpoint'
+TORCH_MODEL_FILE = 'pytorch_model.bin'
+TENSORFLOW = 'tensorflow'
+PYTORCH = 'pytorch'
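
These constants centralize the per-framework file-naming conventions. A hypothetical helper (not part of this commit) illustrates the intended kind of use, assuming a caller that knows its framework string:

from modelscope.utils.constant import (PYTORCH, TENSORFLOW, TF_GRAPH_FILE,
                                       TORCH_MODEL_FILE)


def default_weights_file(framework: str) -> str:
    # Map a framework name onto its conventional weights filename,
    # e.g. 'tensorflow' -> 'tf_graph.pb', 'pytorch' -> 'pytorch_model.bin'.
    if framework == TENSORFLOW:
        return TF_GRAPH_FILE
    if framework == PYTORCH:
        return TORCH_MODEL_FILE
    raise ValueError(f'unsupported framework: {framework}')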

tests/pipelines/test_image_matting.py  (+3, -2)

@@ -16,14 +16,15 @@ from modelscope.utils.hub import get_model_cache_dir
 class ImageMattingTest(unittest.TestCase):
 
     def setUp(self) -> None:
-        self.model_id = 'damo/image-matting-person'
+        self.model_id = 'damo/cv_unet_image-matting_damo'
         # switch to False if downloading everytime is not desired
         purge_cache = True
         if purge_cache:
             shutil.rmtree(
                 get_model_cache_dir(self.model_id), ignore_errors=True)
 
-    def test_run(self):
+    @unittest.skip('deprecated, download model from model hub instead')
+    def test_run_with_direct_file_download(self):
         model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                      '.com/data/test/maas/image_matting/matting_person.pb'
         with tempfile.TemporaryDirectory() as tmp_dir:
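
After this change, the preferred path is to resolve the model by id from the model hub rather than fetching the .pb file from a raw URL. A hedged sketch of that flow (the task attribute and input path are assumptions, not taken from this diff):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolve 'damo/cv_unet_image-matting_damo' through the model hub;
# the pipeline downloads and caches the weights itself.
img_matting = pipeline(
    Tasks.image_matting, model='damo/cv_unet_image-matting_damo')
result = img_matting('data/test_image.png')  # hypothetical local input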


tests/pipelines/test_text_classification.py  (+7, -5)

@@ -10,7 +10,7 @@ from modelscope.models.nlp import BertForSequenceClassification
 from modelscope.pipelines import SequenceClassificationPipeline, pipeline
 from modelscope.preprocessors import SequenceClassificationPreprocessor
 from modelscope.pydatasets import PyDataset
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import Hubs, Tasks
 from modelscope.utils.hub import get_model_cache_dir
 
 
@@ -81,13 +81,15 @@ class SequenceClassificationTest(unittest.TestCase):
         text_classification = pipeline(
             task=Tasks.text_classification, model=self.model_id)
         result = text_classification(
-            PyDataset.load('glue', name='sst2', target='sentence'))
+            PyDataset.load(
+                'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
         self.printDataset(result)
 
     def test_run_with_default_model(self):
         text_classification = pipeline(task=Tasks.text_classification)
         result = text_classification(
-            PyDataset.load('glue', name='sst2', target='sentence'))
+            PyDataset.load(
+                'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
         self.printDataset(result)
 
     def test_run_with_dataset(self):
@@ -97,9 +99,9 @@ class SequenceClassificationTest(unittest.TestCase):
         text_classification = pipeline(
             Tasks.text_classification, model=model, preprocessor=preprocessor)
         # loaded from huggingface dataset
-        # TODO: add load_from parameter (an enum) LOAD_FROM.hugging_face
         # TODO: rename parameter as dataset_name and subset_name
-        dataset = PyDataset.load('glue', name='sst2', target='sentence')
+        dataset = PyDataset.load(
+            'glue', name='sst2', target='sentence', hub=Hubs.huggingface)
         result = text_classification(dataset)
         self.printDataset(result)


