Browse Source

merge with master

master
智丞 3 years ago
parent
commit
dadc07f492
41 changed files with 905 additions and 147 deletions
  1. +3
    -0
      data/test/images/image_mplug_vqa.jpg
  2. +61
    -23
      modelscope/hub/api.py
  3. +1
    -1
      modelscope/hub/constants.py
  4. +19
    -0
      modelscope/hub/errors.py
  5. +9
    -7
      modelscope/hub/file_download.py
  6. +8
    -0
      modelscope/hub/git.py
  7. +8
    -4
      modelscope/hub/repository.py
  8. +7
    -9
      modelscope/hub/snapshot_download.py
  9. +6
    -2
      modelscope/hub/utils/caching.py
  10. +7
    -1
      modelscope/metainfo.py
  11. +2
    -2
      modelscope/models/__init__.py
  12. +2
    -0
      modelscope/models/multi_modal/__init__.py
  13. +46
    -0
      modelscope/models/multi_modal/mplug_for_visual_question_answering.py
  14. +1
    -0
      modelscope/models/nlp/__init__.py
  15. +50
    -0
      modelscope/models/nlp/sbert_for_zero_shot_classification.py
  16. +1
    -1
      modelscope/msdatasets/config.py
  17. +41
    -15
      modelscope/msdatasets/ms_dataset.py
  18. +33
    -15
      modelscope/msdatasets/utils/ms_api.py
  19. +40
    -15
      modelscope/pipelines/base.py
  20. +7
    -1
      modelscope/pipelines/builder.py
  21. +1
    -0
      modelscope/pipelines/multi_modal/__init__.py
  22. +65
    -0
      modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
  23. +1
    -0
      modelscope/pipelines/nlp/__init__.py
  24. +97
    -0
      modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  25. +7
    -0
      modelscope/pipelines/outputs.py
  26. +1
    -1
      modelscope/preprocessors/__init__.py
  27. +45
    -0
      modelscope/preprocessors/multi_modal.py
  28. +45
    -1
      modelscope/preprocessors/nlp.py
  29. +11
    -1
      modelscope/utils/constant.py
  30. +3
    -2
      modelscope/utils/hub.py
  31. +1
    -1
      modelscope/version.py
  32. +1
    -1
      requirements/nlp.txt
  33. +35
    -7
      tests/hub/test_hub_operation.py
  34. +85
    -0
      tests/hub/test_hub_private_files.py
  35. +4
    -5
      tests/hub/test_hub_private_repository.py
  36. +5
    -19
      tests/hub/test_hub_repository.py
  37. +14
    -10
      tests/msdatasets/test_ms_dataset.py
  38. +2
    -1
      tests/pipelines/test_image_matting.py
  39. +6
    -2
      tests/pipelines/test_text_classification.py
  40. +60
    -0
      tests/pipelines/test_visual_question_answering.py
  41. +64
    -0
      tests/pipelines/test_zero_shot_classification.py

+ 3
- 0
data/test/images/image_mplug_vqa.jpg View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692
size 181725

+ 61
- 23
modelscope/hub/api.py View File

@@ -9,7 +9,7 @@ import requests

from modelscope.utils.logger import get_logger
from .constants import MODELSCOPE_URL_SCHEME
from .errors import NotExistError, is_ok, raise_on_error
from .errors import InvalidParameter, NotExistError, is_ok, raise_on_error
from .utils.utils import (get_endpoint, get_gitlab_domain,
model_id_to_group_owner_name)

@@ -61,17 +61,21 @@ class HubApi:

return d['Data']['AccessToken'], cookies

def create_model(self, model_id: str, chinese_name: str, visibility: int,
license: str) -> str:
def create_model(
self,
model_id: str,
visibility: str,
license: str,
chinese_name: Optional[str] = None,
) -> str:
"""
Create model repo at ModelScopeHub

Args:
model_id:(`str`): The model id
chinese_name(`str`): chinese name of the model
visibility(`int`): visibility of the model(1-private, 3-internal, 5-public)
license(`str`): license of the model, candidates can be found at: TBA

visibility(`int`): visibility of the model(1-private, 5-public), default public.
license(`str`): license of the model, default none.
chinese_name(`str`, *optional*): chinese name of the model
Returns:
name of the model created

@@ -79,6 +83,8 @@ class HubApi:
model_id = {owner}/{name}
</Tip>
"""
if model_id is None:
raise InvalidParameter('model_id is required!')
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
@@ -151,11 +157,33 @@ class HubApi:
else:
r.raise_for_status()

def _check_cookie(self,
use_cookies: Union[bool,
CookieJar] = False) -> CookieJar:
cookies = None
if isinstance(use_cookies, CookieJar):
cookies = use_cookies
elif use_cookies:
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
return cookies

def get_model_branches_and_tags(
self,
model_id: str,
use_cookies: Union[bool, CookieJar] = False
) -> Tuple[List[str], List[str]]:
cookies = ModelScopeConfig.get_cookies()
"""Get model branch and tags.

Args:
model_id (str): The model id
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will
will load cookie from local. Defaults to False.
Returns:
Tuple[List[str], List[str]]: _description_
"""
cookies = self._check_cookie(use_cookies)

path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
r = requests.get(path, cookies=cookies)
@@ -169,23 +197,33 @@ class HubApi:
] if info['RevisionMap']['Tags'] else []
return branches, tags

def get_model_files(
self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False) -> List[dict]:
def get_model_files(self,
model_id: str,
revision: Optional[str] = 'master',
root: Optional[str] = None,
recursive: Optional[str] = False,
use_cookies: Union[bool, CookieJar] = False,
is_snapshot: Optional[bool] = True) -> List[dict]:
"""List the models files.

cookies = None
if isinstance(use_cookies, CookieJar):
cookies = use_cookies
elif use_cookies:
cookies = ModelScopeConfig.get_cookies()
if cookies is None:
raise ValueError('Token does not exist, please login first.')
Args:
model_id (str): The model id
revision (Optional[str], optional): The branch or tag name. Defaults to 'master'.
root (Optional[str], optional): The root path. Defaults to None.
recursive (Optional[str], optional): Is recurive list files. Defaults to False.
use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True, will
will load cookie from local. Defaults to False.
is_snapshot(Optional[bool], optional): when snapshot_download set to True, otherwise False.

path = f'{self.endpoint}/api/v1/models/{model_id}/repo/files?Revision={revision}&Recursive={recursive}'
Raises:
ValueError: If user_cookies is True, but no local cookie.

Returns:
List[dict]: Model file list.
"""
path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s&Snapshot=%s' % (
self.endpoint, model_id, revision, recursive, is_snapshot)
cookies = self._check_cookie(use_cookies)
if root is not None:
path = path + f'&Root={root}'



+ 1
- 1
modelscope/hub/constants.py View File

@@ -1,5 +1,5 @@
MODELSCOPE_URL_SCHEME = 'http://'
DEFAULT_MODELSCOPE_DOMAIN = '101.201.119.157:32330'
DEFAULT_MODELSCOPE_DOMAIN = '47.94.223.21:31090'
DEFAULT_MODELSCOPE_GITLAB_DOMAIN = '101.201.119.157:31102'

DEFAULT_MODELSCOPE_GROUP = 'damo'


+ 19
- 0
modelscope/hub/errors.py View File

@@ -10,6 +10,10 @@ class GitError(Exception):
pass


class InvalidParameter(Exception):
pass


def is_ok(rsp):
""" Check the request is ok

@@ -32,3 +36,18 @@ def raise_on_error(rsp):
return True
else:
raise RequestError(rsp['Message'])


# TODO use raise_on_error instead if modelhub and datahub response have uniform structures,
def datahub_raise_on_error(url, rsp):
"""If response error, raise exception

Args:
rsp (_type_): The server response
"""
if rsp.get('Code') == 200:
return True
else:
raise RequestError(
f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}"
)

+ 9
- 7
modelscope/hub/file_download.py View File

@@ -7,6 +7,7 @@ import tempfile
import time
from functools import partial
from hashlib import sha256
from http.cookiejar import CookieJar
from pathlib import Path
from typing import BinaryIO, Dict, Optional, Union
from uuid import uuid4
@@ -107,7 +108,9 @@ def model_file_download(

_api = HubApi()
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
branches, tags = _api.get_model_branches_and_tags(model_id)
cookies = ModelScopeConfig.get_cookies()
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
file_to_download_info = None
is_commit_id = False
if revision in branches or revision in tags: # The revision is version or tag,
@@ -117,18 +120,19 @@ def model_file_download(
model_id=model_id,
revision=revision,
recursive=True,
)
use_cookies=False if cookies is None else cookies,
is_snapshot=False)

for model_file in model_files:
if model_file['Type'] == 'tree':
continue

if model_file['Path'] == file_path:
model_file['Branch'] = revision
if cache.exists(model_file):
return cache.get_file_by_info(model_file)
else:
file_to_download_info = model_file
break

if file_to_download_info is None:
raise NotExistError('The file path: %s not exist in: %s' %
@@ -141,8 +145,6 @@ def model_file_download(
return cached_file_path # the file is in cache.
is_commit_id = True
# we need to download again
# TODO: skip using JWT for authorization, use cookie instead
cookies = ModelScopeConfig.get_cookies()
url_to_download = get_file_download_url(model_id, file_path, revision)
file_to_download_info = {
'Path': file_path,
@@ -202,7 +204,7 @@ def http_get_file(
url: str,
local_dir: str,
file_name: str,
cookies: Dict[str, str],
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
):
"""
@@ -217,7 +219,7 @@ def http_get_file(
local directory where the downloaded file stores
file_name(`str`):
name of the file stored in `local_dir`
cookies(`Dict[str, str]`):
cookies(`CookieJar`):
cookies used to authentication the user, which is used for downloading private repos
headers(`Optional[Dict[str, str]] = None`):
http headers to carry necessary info when requesting the remote file


+ 8
- 0
modelscope/hub/git.py View File

@@ -70,6 +70,14 @@ class GitCommandWrapper(metaclass=Singleton):
except GitError:
return False

def git_lfs_install(self, repo_dir):
cmd = ['git', '-C', repo_dir, 'lfs', 'install']
try:
self._run_git_command(*cmd)
return True
except GitError:
return False

def clone(self,
repo_base_dir: str,
token: str,


+ 8
- 4
modelscope/hub/repository.py View File

@@ -1,7 +1,7 @@
import os
from typing import List, Optional

from modelscope.hub.errors import GitError
from modelscope.hub.errors import GitError, InvalidParameter
from modelscope.utils.logger import get_logger
from .api import ModelScopeConfig
from .constants import MODELSCOPE_URL_SCHEME
@@ -49,6 +49,8 @@ class Repository:
git_wrapper = GitCommandWrapper()
if not git_wrapper.is_lfs_installed():
logger.error('git lfs is not installed, please install.')
else:
git_wrapper.git_lfs_install(self.model_dir) # init repo lfs

self.git_wrapper = GitCommandWrapper(git_path)
os.makedirs(self.model_dir, exist_ok=True)
@@ -74,8 +76,6 @@ class Repository:

def push(self,
commit_message: str,
files: List[str] = list(),
all_files: bool = False,
branch: Optional[str] = 'master',
force: bool = False):
"""Push local to remote, this method will do.
@@ -86,8 +86,12 @@ class Repository:
commit_message (str): commit message
revision (Optional[str], optional): which branch to push. Defaults to 'master'.
"""
if commit_message is None:
msg = 'commit_message must be provided!'
raise InvalidParameter(msg)
url = self.git_wrapper.get_repo_remote_url(self.model_dir)
self.git_wrapper.add(self.model_dir, files, all_files)
self.git_wrapper.pull(self.model_dir)
self.git_wrapper.add(self.model_dir, all_files=True)
self.git_wrapper.commit(self.model_dir, commit_message)
self.git_wrapper.push(
repo_dir=self.model_dir,


+ 7
- 9
modelscope/hub/snapshot_download.py View File

@@ -20,8 +20,7 @@ def snapshot_download(model_id: str,
revision: Optional[str] = 'master',
cache_dir: Union[str, Path, None] = None,
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
private: Optional[bool] = False) -> str:
local_files_only: Optional[bool] = False) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
is useful when you want all files from a repo, because you don't know which
@@ -79,8 +78,10 @@ def snapshot_download(model_id: str,
# make headers
headers = {'user-agent': http_user_agent(user_agent=user_agent, )}
_api = HubApi()
cookies = ModelScopeConfig.get_cookies()
# get file list from model repo
branches, tags = _api.get_model_branches_and_tags(model_id)
branches, tags = _api.get_model_branches_and_tags(
model_id, use_cookies=False if cookies is None else cookies)
if revision not in branches and revision not in tags:
raise NotExistError('The specified branch or tag : %s not exist!'
% revision)
@@ -89,11 +90,8 @@ def snapshot_download(model_id: str,
model_id=model_id,
revision=revision,
recursive=True,
use_cookies=private)

cookies = None
if private:
cookies = ModelScopeConfig.get_cookies()
use_cookies=False if cookies is None else cookies,
is_snapshot=True)

for model_file in model_files:
if model_file['Type'] == 'tree':
@@ -116,7 +114,7 @@ def snapshot_download(model_id: str,
local_dir=tempfile.gettempdir(),
file_name=model_file['Name'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
cookies=cookies)
# put file to cache
cache.put_file(
model_file,


+ 6
- 2
modelscope/hub/utils/caching.py View File

@@ -101,8 +101,9 @@ class FileSystemCache(object):
Args:
key (dict): The cache key.
"""
self.cached_files.remove(key)
self.save_cached_files()
if key in self.cached_files:
self.cached_files.remove(key)
self.save_cached_files()

def exists(self, key):
for cache_file in self.cached_files:
@@ -204,6 +205,7 @@ class ModelFileSystemCache(FileSystemCache):
return orig_path
else:
self.remove_key(cached_file)
break

return None

@@ -230,6 +232,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_key['Revision'].startswith(key['Revision'])
or key['Revision'].startswith(cached_key['Revision'])):
is_exists = True
break
file_path = os.path.join(self.cache_root_location,
model_file_info['Path'])
if is_exists:
@@ -253,6 +256,7 @@ class ModelFileSystemCache(FileSystemCache):
cached_file['Path'])
if os.path.exists(file_path):
os.remove(file_path)
break

def put_file(self, model_file_info, model_file_location):
"""Put model on model_file_location to cache, the model first download to /tmp, and move to cache.


+ 7
- 1
modelscope/metainfo.py View File

@@ -28,6 +28,7 @@ class Models(object):
# multi-modal models
ofa = 'ofa'
clip = 'clip-multi-modal-embedding'
mplug = 'mplug'


class Pipelines(object):
@@ -57,6 +58,7 @@ class Pipelines(object):
nli = 'nli'
dialog_intent_prediction = 'dialog-intent-prediction'
dialog_modeling = 'dialog-modeling'
zero_shot_classification = 'zero-shot-classification'

# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
@@ -65,8 +67,9 @@ class Pipelines(object):
kws_kwsbp = 'kws-kwsbp'

# multi-modal tasks
image_caption = 'image-caption'
image_caption = 'image-captioning'
multi_modal_embedding = 'multi-modal-embedding'
visual_question_answering = 'visual-question-answering'


class Trainers(object):
@@ -104,6 +107,8 @@ class Preprocessors(object):
sen_cls_tokenizer = 'sen-cls-tokenizer'
dialog_intent_preprocessor = 'dialog-intent-preprocessor'
dialog_modeling_preprocessor = 'dialog-modeling-preprocessor'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'
@@ -112,3 +117,4 @@ class Preprocessors(object):

# multi-modal
ofa_image_caption = 'ofa-image-caption'
mplug_visual_question_answering = 'mplug-visual-question-answering'

+ 2
- 2
modelscope/models/__init__.py View File

@@ -9,5 +9,5 @@ from .builder import MODELS, build_model
from .multi_modal import OfaForImageCaptioning
from .nlp import (BertForMaskedLM, BertForSequenceClassification, SbertForNLI,
SbertForSentenceSimilarity, SbertForSentimentClassification,
SbertForTokenClassification, StructBertForMaskedLM,
VecoForMaskedLM)
SbertForTokenClassification, SbertForZeroShotClassification,
StructBertForMaskedLM, VecoForMaskedLM)

+ 2
- 0
modelscope/models/multi_modal/__init__.py View File

@@ -1,2 +1,4 @@
from .clip.clip_model import CLIPForMultiModalEmbedding
from .image_captioning_model import OfaForImageCaptioning
from .mplug_for_visual_question_answering import \
MPlugForVisualQuestionAnswering

+ 46
- 0
modelscope/models/multi_modal/mplug_for_visual_question_answering.py View File

@@ -0,0 +1,46 @@
from typing import Dict

from ...metainfo import Models
from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['MPlugForVisualQuestionAnswering']


@MODELS.register_module(
Tasks.visual_question_answering, module_name=Models.mplug)
class MPlugForVisualQuestionAnswering(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the mplug model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""

super().__init__(model_dir, *args, **kwargs)
from sofa.models.mplug import MPlugForVisualQuestionAnswering
self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
self.tokenizer = self.model.tokenizer

def train(self):
return self.model.train()

def eval(self):
return self.model.eval()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Tensor]: results
Example:
{
'predictions': Tensor([[1377, 4959, 2785, 6392...])]),
}
"""

return self.model(**input)[0]

+ 1
- 0
modelscope/models/nlp/__init__.py View File

@@ -5,5 +5,6 @@ from .sbert_for_nli import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_sentiment_classification import * # noqa F403
from .sbert_for_token_classification import * # noqa F403
from .sbert_for_zero_shot_classification import * # noqa F403
from .space.dialog_intent_prediction_model import * # noqa F403
from .space.dialog_modeling_model import * # noqa F403

+ 50
- 0
modelscope/models/nlp/sbert_for_zero_shot_classification.py View File

@@ -0,0 +1,50 @@
from typing import Any, Dict

import numpy as np

from modelscope.utils.constant import Tasks
from ...metainfo import Models
from ..base import Model
from ..builder import MODELS

__all__ = ['SbertForZeroShotClassification']


@MODELS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForZeroShotClassification(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the zero shot classification model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""

super().__init__(model_dir, *args, **kwargs)
from sofa import SbertForSequenceClassification
self.model = SbertForSequenceClassification.from_pretrained(model_dir)

def train(self):
return self.model.train()

def eval(self):
return self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
outputs = self.model(**input)
logits = outputs['logits'].numpy()
res = {'logits': logits}
return res

+ 1
- 1
modelscope/msdatasets/config.py View File

@@ -19,4 +19,4 @@ DOWNLOADED_DATASETS_PATH = Path(
os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH))

MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT',
'http://101.201.119.157:31752')
'http://47.94.223.21:31752')

+ 41
- 15
modelscope/msdatasets/ms_dataset.py View File

@@ -3,7 +3,7 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)

import numpy as np
from datasets import Dataset
from datasets import Dataset, DatasetDict
from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
@@ -12,7 +12,7 @@ from datasets.utils.file_utils import (is_relative_path,

from modelscope.msdatasets.config import MS_DATASETS_CACHE
from modelscope.msdatasets.utils.ms_api import MsApi
from modelscope.utils.constant import Hubs
from modelscope.utils.constant import DownloadMode, Hubs
from modelscope.utils.logger import get_logger

logger = get_logger()
@@ -34,6 +34,10 @@ class MsDataset:

def __init__(self, hf_ds: Dataset, target: Optional[str] = None):
self._hf_ds = hf_ds
if target is not None and target not in self._hf_ds.features:
raise TypeError(
f'"target" must be a column of the dataset({list(self._hf_ds.features.keys())}, but got {target}'
)
self.target = target

def __iter__(self):
@@ -48,17 +52,23 @@ class MsDataset:

@classmethod
def from_hf_dataset(cls,
hf_ds: Dataset,
hf_ds: Union[Dataset, DatasetDict],
target: str = None) -> Union[dict, 'MsDataset']:
if isinstance(hf_ds, Dataset):
return cls(hf_ds, target)
if len(hf_ds.keys()) == 1:
return cls(next(iter(hf_ds.values())), target)
return {k: cls(v, target) for k, v in hf_ds.items()}
elif isinstance(hf_ds, DatasetDict):
if len(hf_ds.keys()) == 1:
return cls(next(iter(hf_ds.values())), target)
return {k: cls(v, target) for k, v in hf_ds.items()}
else:
raise TypeError(
f'"hf_ds" must be a Dataset or DatasetDict, but got {type(hf_ds)}'
)

@staticmethod
def load(
dataset_name: Union[str, list],
namespace: Optional[str] = None,
target: Optional[str] = None,
version: Optional[str] = None,
hub: Optional[Hubs] = Hubs.modelscope,
@@ -67,23 +77,32 @@ class MsDataset:
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
Sequence[str]]]]] = None,
download_mode: Optional[DownloadMode] = DownloadMode.
REUSE_DATASET_IF_EXISTS
) -> Union[dict, 'MsDataset']:
"""Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
Args:

dataset_name (str): Path or name of the dataset.
namespace(str, optional): Namespace of the dataset. It should not be None, if you load a remote dataset
from Hubs.modelscope,
target (str, optional): Name of the column to output.
version (str, optional): Version of the dataset script to load:
subset_name (str, optional): Defining the subset_name of the dataset.
data_dir (str, optional): Defining the data_dir of the dataset configuration. I
data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s).
split (str, optional): Which split of the data to load.
hub (Hubs, optional): When loading from a remote hub, where it is from
hub (Hubs or str, optional): When loading from a remote hub, where it is from. default Hubs.modelscope
download_mode (DownloadMode or str, optional): How to treat existing datasets. default
DownloadMode.REUSE_DATASET_IF_EXISTS

Returns:
MsDataset (obj:`MsDataset`): MsDataset object for a certain dataset.
"""
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
hub = Hubs(hub or Hubs.modelscope)
if hub == Hubs.huggingface:
dataset = hf_load_dataset(
dataset_name,
@@ -91,21 +110,25 @@ class MsDataset:
revision=version,
split=split,
data_dir=data_dir,
data_files=data_files)
data_files=data_files,
download_mode=download_mode.value)
return MsDataset.from_hf_dataset(dataset, target=target)
else:
elif hub == Hubs.modelscope:
return MsDataset._load_ms_dataset(
dataset_name,
namespace=namespace,
target=target,
subset_name=subset_name,
version=version,
split=split,
data_dir=data_dir,
data_files=data_files)
data_files=data_files,
download_mode=download_mode)

@staticmethod
def _load_ms_dataset(
dataset_name: Union[str, list],
namespace: Optional[str] = None,
target: Optional[str] = None,
version: Optional[str] = None,
subset_name: Optional[str] = None,
@@ -113,17 +136,19 @@ class MsDataset:
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
Sequence[str]]]]] = None,
download_mode: Optional[DownloadMode] = None
) -> Union[dict, 'MsDataset']:
if isinstance(dataset_name, str):
use_hf = False
if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
(os.path.isfile(dataset_name) and dataset_name.endswith('.py')):
use_hf = True
elif is_relative_path(dataset_name):
elif is_relative_path(dataset_name) and dataset_name.count(
'/') == 0:
ms_api = MsApi()
dataset_scripts = ms_api.fetch_dataset_scripts(
dataset_name, version)
dataset_name, namespace, download_mode, version)
if 'py' in dataset_scripts: # dataset copied from hf datasets
dataset_name = dataset_scripts['py'][0]
use_hf = True
@@ -140,7 +165,8 @@ class MsDataset:
split=split,
data_dir=data_dir,
data_files=data_files,
cache_dir=MS_DATASETS_CACHE)
cache_dir=MS_DATASETS_CACHE,
download_mode=download_mode.value)
else:
# TODO load from ms datahub
raise NotImplementedError(


+ 33
- 15
modelscope/msdatasets/utils/ms_api.py View File

@@ -1,11 +1,14 @@
import os
import shutil
from collections import defaultdict
from typing import Optional

import requests

from modelscope.hub.errors import NotExistError, datahub_raise_on_error
from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
MS_HUB_ENDPOINT)
from modelscope.utils.constant import DownloadMode
from modelscope.utils.logger import get_logger

logger = get_logger()
@@ -27,23 +30,38 @@ class MsApi:

def fetch_dataset_scripts(self,
dataset_name: str,
version: Optional[str] = 'master',
force_download=False):
datahub_url = f'{self.endpoint}/api/v1/datasets?Query={dataset_name}'
r = requests.get(datahub_url)
r.raise_for_status()
dataset_list = r.json()['Data']
if len(dataset_list) == 0:
return None
dataset_id = dataset_list[0]['Id']
namespace: str,
download_mode: Optional[DownloadMode],
version: Optional[str] = 'master'):
if namespace is None:
raise ValueError(
f'Dataset from Hubs.modelscope should have a valid "namespace", but get {namespace}'
)
version = version or 'master'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}'
r = requests.get(datahub_url)
r.raise_for_status()
file_list = r.json()['Data']['Files']
cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name,
version)
namespace, version)
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
cache_dir):
shutil.rmtree(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
r = requests.get(datahub_url)
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
dataset_id = resp['Data']['Id']
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}'
r = requests.get(datahub_url)
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
file_list = resp['Data']
if file_list is None:
raise NotExistError(
f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, '
f'version = {version}] dose not exist')

file_list = file_list['Files']
local_paths = defaultdict(list)
for file_info in file_list:
file_path = file_info['Path']
@@ -54,7 +72,7 @@ class MsApi:
r.raise_for_status()
content = r.json()['Data']['Content']
local_path = os.path.join(cache_dir, file_path)
if os.path.exists(local_path) and not force_download:
if os.path.exists(local_path):
logger.warning(
f"Reusing dataset {dataset_name}'s python file ({local_path})"
)


+ 40
- 15
modelscope/pipelines/base.py View File

@@ -74,33 +74,57 @@ class Pipeline(ABC):
self.preprocessor = preprocessor

def __call__(self, input: Union[Input, List[Input]], *args,
**post_kwargs) -> Union[Dict[str, Any], Generator]:
**kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function

# simple showcase, need to support iterator type for both tensorflow and pytorch
# input_dict = self._handle_input(input)

# sanitize the parameters
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
**kwargs)
kwargs['preprocess_params'] = preprocess_params
kwargs['forward_params'] = forward_params
kwargs['postprocess_params'] = postprocess_params

if isinstance(input, list):
output = []
for ele in input:
output.append(self._process_single(ele, *args, **post_kwargs))
output.append(self._process_single(ele, *args, **kwargs))

elif isinstance(input, MsDataset):
return self._process_iterator(input, *args, **post_kwargs)
return self._process_iterator(input, *args, **kwargs)

else:
output = self._process_single(input, *args, **post_kwargs)
output = self._process_single(input, *args, **kwargs)
return output

def _process_iterator(self, input: Input, *args, **post_kwargs):
def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
considered to be a normal classmethod with default implementation / output

Default Returns:
Dict[str, str]: preprocess_params = {}
Dict[str, str]: forward_params = {}
Dict[str, str]: postprocess_params = pipeline_parameters
"""
return {}, {}, pipeline_parameters

def _process_iterator(self, input: Input, *args, **kwargs):
for ele in input:
yield self._process_single(ele, *args, **post_kwargs)
yield self._process_single(ele, *args, **kwargs)

def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

def _process_single(self, input: Input, *args,
**post_kwargs) -> Dict[str, Any]:
out = self.preprocess(input)
out = self.forward(out)
out = self.postprocess(out, **post_kwargs)
out = self.preprocess(input, **preprocess_params)
out = self.forward(out, **forward_params)
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
return out

@@ -120,20 +144,21 @@ class Pipeline(ABC):
raise ValueError(f'expected output keys are {output_keys}, '
f'those {missing_keys} are missing')

def preprocess(self, inputs: Input) -> Dict[str, Any]:
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it
"""
assert self.preprocessor is not None, 'preprocess method should be implemented'
assert not isinstance(self.preprocessor, List),\
'default implementation does not support using multiple preprocessors.'
return self.preprocessor(inputs)
return self.preprocessor(inputs, **preprocess_params)

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
assert self.model is not None, 'forward method should be implemented'
assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
return self.model(inputs)
return self.model(inputs, **forward_params)

@abstractmethod
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+ 7
- 1
modelscope/pipelines/builder.py View File

@@ -33,6 +33,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/bert-base-sst2'),
Tasks.text_generation: (Pipelines.text_generation,
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.zero_shot_classification:
(Pipelines.zero_shot_classification,
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.image_captioning: (Pipelines.image_caption,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:
@@ -45,7 +48,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_TAdaConv_action-recognition'),
Tasks.multi_modal_embedding:
(Pipelines.multi_modal_embedding,
'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'),
Tasks.visual_question_answering:
(Pipelines.visual_question_answering,
'damo/mplug_visual-question-answering_coco_large_en'),
}




+ 1
- 0
modelscope/pipelines/multi_modal/__init__.py View File

@@ -1,2 +1,3 @@
from .image_captioning_pipeline import ImageCaptionPipeline
from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline

+ 65
- 0
modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py View File

@@ -0,0 +1,65 @@
from typing import Any, Dict, Optional, Union

import torch

from ...metainfo import Pipelines
from ...models import Model
from ...models.multi_modal import MPlugForVisualQuestionAnswering
from ...preprocessors import MPlugVisualQuestionAnsweringPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['VisualQuestionAnsweringPipeline']


@PIPELINES.register_module(
Tasks.visual_question_answering,
module_name=Pipelines.visual_question_answering)
class VisualQuestionAnsweringPipeline(Pipeline):

def __init__(self,
model: Union[MPlugForVisualQuestionAnswering, str],
preprocessor: Optional[
MPlugVisualQuestionAnsweringPreprocessor] = None,
**kwargs):
"""use `model` and `preprocessor` to create a visual question answering pipeline for prediction

Args:
model (MPlugForVisualQuestionAnswering): a model instance
preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance
"""
model = model if isinstance(
model,
MPlugForVisualQuestionAnswering) else Model.from_pretrained(model)
if preprocessor is None:
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
model.model_dir)
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.tokenizer = model.tokenizer

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)

def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

pred_string = self.tokenizer.decode(inputs[0][0])
for _old, _new in replace_tokens_bert:
pred_string = pred_string.replace(_old, _new)
pred_string.strip()
return {'answer': pred_string}

+ 1
- 0
modelscope/pipelines/nlp/__init__.py View File

@@ -7,3 +7,4 @@ from .sentiment_classification_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403
from .zero_shot_classification_pipeline import * # noqa F403

+ 97
- 0
modelscope/pipelines/nlp/zero_shot_classification_pipeline.py View File

@@ -0,0 +1,97 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np
import torch
from scipy.special import softmax

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SbertForZeroShotClassification
from ...preprocessors import ZeroShotClassificationPreprocessor
from ...utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['ZeroShotClassificationPipeline']


@PIPELINES.register_module(
Tasks.zero_shot_classification,
module_name=Pipelines.zero_shot_classification)
class ZeroShotClassificationPipeline(Pipeline):

def __init__(self,
model: Union[SbertForZeroShotClassification, str],
preprocessor: ZeroShotClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

Args:
model (SbertForSentimentClassification): a model instance
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
"""
assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \
'model must be a single str or SbertForZeroShotClassification'
model = model if isinstance(
model,
SbertForZeroShotClassification) else Model.from_pretrained(model)

self.entailment_id = 0
self.contradiction_id = 2

if preprocessor is None:
preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
postprocess_params = {}

if 'candidate_labels' in kwargs:
candidate_labels = kwargs.pop('candidate_labels')
preprocess_params['candidate_labels'] = candidate_labels
postprocess_params['candidate_labels'] = candidate_labels
else:
raise ValueError('You must include at least one label.')
preprocess_params['hypothesis_template'] = kwargs.pop(
'hypothesis_template', '{}')

postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
return preprocess_params, {}, postprocess_params

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)

def postprocess(self,
inputs: Dict[str, Any],
candidate_labels,
multi_label=False) -> Dict[str, Any]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, Any]: the prediction results
"""

logits = inputs['logits']
if multi_label or len(candidate_labels) == 1:
logits = logits[..., [self.contradiction_id, self.entailment_id]]
scores = softmax(logits, axis=-1)[..., 1]
else:
logits = logits[..., self.entailment_id]
scores = softmax(logits, axis=-1)

reversed_index = list(reversed(scores.argsort()))
result = {
'labels': [candidate_labels[i] for i in reversed_index],
'scores': [scores[i].item() for i in reversed_index],
}
return result

+ 7
- 0
modelscope/pipelines/outputs.py View File

@@ -108,6 +108,13 @@ TASK_OUTPUTS = {
# }
Tasks.sentiment_classification: ['scores', 'labels'],

# zero-shot classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.zero_shot_classification: ['scores', 'labels'],

# nli result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],


+ 1
- 1
modelscope/preprocessors/__init__.py View File

@@ -6,7 +6,7 @@ from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .kws import WavToLists
from .multi_modal import OfaImageCaptionPreprocessor
from .multi_modal import * # noqa F403
from .nlp import * # noqa F403
from .space.dialog_intent_prediction_preprocessor import * # noqa F403
from .space.dialog_modeling_preprocessor import * # noqa F403


+ 45
- 0
modelscope/preprocessors/multi_modal.py View File

@@ -16,6 +16,7 @@ from .image import load_image

__all__ = [
'OfaImageCaptionPreprocessor',
'MPlugVisualQuestionAnsweringPreprocessor',
]


@@ -110,3 +111,47 @@ class OfaImageCaptionPreprocessor(Preprocessor):
}
}
return sample


@PREPROCESSORS.register_module(
Fields.multi_modal,
module_name=Preprocessors.mplug_visual_question_answering)
class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via 'bert-base-uncased' tokenizer and configuration

"""
super().__init__(*args, **kwargs)

# tokenizer
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# load configuration
from sofa.models.mplug import CONFIG_NAME, MPlugConfig
config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))

# Initialize transform
from torchvision import transforms
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

self.patch_resize_transform = transforms.Compose([
transforms.Resize((config.image_res, config.image_res),
interpolation=Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])

def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image, question = data['image'], data['question']
image = Image.open(image).convert('RGB') if isinstance(image,
str) else image
image = self.patch_resize_transform(image)
image = torch.stack([image], dim=0)
question = self.tokenizer([question.lower()],
padding='longest',
return_tensors='pt')

return {'image': image, 'question': question, 'train': False}

+ 45
- 1
modelscope/preprocessors/nlp.py View File

@@ -15,7 +15,7 @@ __all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
'NLIPreprocessor', 'SentimentClassificationPreprocessor',
'FillMaskPreprocessor'
'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor'
]


@@ -421,3 +421,47 @@ class TokenClassifcationPreprocessor(Preprocessor):
'attention_mask': attention_mask,
'token_type_ids': token_type_ids
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
class ZeroShotClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

@type_assert(object, str)
def __call__(self, data: str, hypothesis_template: str,
candidate_labels: list) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
pairs = [[data, hypothesis_template.format(label)]
for label in candidate_labels]

features = self.tokenizer(
pairs,
padding=True,
truncation=True,
max_length=self.sequence_length,
return_tensors='pt',
truncation_strategy='only_first')
return features

+ 11
- 1
modelscope/utils/constant.py View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import enum


class Fields(object):
@@ -51,6 +52,7 @@ class Tasks(object):
fill_mask = 'fill-mask'
summarization = 'summarization'
question_answering = 'question-answering'
zero_shot_classification = 'zero-shot-classification'

# audio tasks
auto_speech_recognition = 'auto-speech-recognition'
@@ -63,6 +65,7 @@ class Tasks(object):
visual_grounding = 'visual-grounding'
text_to_image_synthesis = 'text-to-image-synthesis'
multi_modal_embedding = 'multi-modal-embedding'
visual_question_answering = 'visual-question-answering'


class InputFields(object):
@@ -73,13 +76,20 @@ class InputFields(object):
audio = 'audio'


class Hubs(object):
class Hubs(enum.Enum):
""" Source from which an entity (such as a Dataset or Model) is stored
"""
modelscope = 'modelscope'
huggingface = 'huggingface'


class DownloadMode(enum.Enum):
""" How to treat existing datasets
"""
REUSE_DATASET_IF_EXISTS = 'reuse_dataset_if_exists'
FORCE_REDOWNLOAD = 'force_redownload'


class ModelFile(object):
CONFIGURATION = 'configuration.json'
README = 'README.md'


+ 3
- 2
modelscope/utils/hub.py View File

@@ -31,9 +31,10 @@ def create_model_if_not_exist(
else:
api.create_model(
model_id=model_id,
chinese_name=chinese_name,
visibility=visibility,
license=license)
license=license,
chinese_name=chinese_name,
)
print(f'model {model_id} successfully created.')
return True



+ 1
- 1
modelscope/version.py View File

@@ -1 +1 @@
__version__ = '0.1.1'
__version__ = '0.2.1'

+ 1
- 1
requirements/nlp.txt View File

@@ -1,3 +1,3 @@
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.2-py3-none-any.whl
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
spacy>=2.3.5

+ 35
- 7
tests/hub/test_hub_operation.py View File

@@ -3,6 +3,7 @@ import os
import tempfile
import unittest
import uuid
from shutil import rmtree

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
@@ -23,7 +24,6 @@ download_model_file_name = 'test.bin'
class HubOperationTest(unittest.TestCase):

def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi()
# note this is temporary before official account management is ready
self.api.login(USER_NAME, PASSWORD)
@@ -31,19 +31,18 @@ class HubOperationTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
chinese_name=model_chinese_name,
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
)
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)
os.chdir(self.model_dir)
os.system("echo 'testtest'>%s"
% os.path.join(self.model_dir, 'test.bin'))
repo.push('add model', all_files=True)
% os.path.join(self.model_dir, download_model_file_name))
repo.push('add model')

def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id)

def test_model_repo_creation(self):
@@ -79,6 +78,35 @@ class HubOperationTest(unittest.TestCase):
mdtime2 = os.path.getmtime(downloaded_file_path)
assert mdtime1 == mdtime2

def test_download_public_without_login(self):
rmtree(ModelScopeConfig.path_credential)
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
temporary_dir = tempfile.mkdtemp()
downloaded_file = model_file_download(
model_id=self.model_id,
file_path=download_model_file_name,
cache_dir=temporary_dir)
assert os.path.exists(downloaded_file)
self.api.login(USER_NAME, PASSWORD)

def test_snapshot_delete_download_cache_file(self):
snapshot_path = snapshot_download(model_id=self.model_id)
downloaded_file_path = os.path.join(snapshot_path,
download_model_file_name)
assert os.path.exists(downloaded_file_path)
os.remove(downloaded_file_path)
# download again in cache
file_download_path = model_file_download(
model_id=self.model_id, file_path='README.md')
assert os.path.exists(file_download_path)
# deleted file need download again
file_download_path = model_file_download(
model_id=self.model_id, file_path=download_model_file_name)
assert os.path.exists(file_download_path)


if __name__ == '__main__':
unittest.main()

+ 85
- 0
tests/hub/test_hub_private_files.py View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
import uuid

from requests.exceptions import HTTPError

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile

USER_NAME = 'maasadmin'
PASSWORD = '12345678'
USER_NAME2 = 'sdkdev'

model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'


class HubPrivateFileDownloadTest(unittest.TestCase):

def setUp(self):
self.old_cwd = os.getcwd()
self.api = HubApi()
# note this is temporary before official account management is ready
self.token, _ = self.api.login(USER_NAME, PASSWORD)
self.model_name = uuid.uuid4().hex
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
)

def tearDown(self):
os.chdir(self.old_cwd)
self.api.delete_model(model_id=self.model_id)

def test_snapshot_download_private_model(self):
snapshot_path = snapshot_download(self.model_id)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))

def test_snapshot_download_private_model_no_permission(self):
self.token, _ = self.api.login(USER_NAME2, PASSWORD)
with self.assertRaises(HTTPError):
snapshot_download(self.model_id)
self.api.login(USER_NAME, PASSWORD)

def test_download_file_private_model(self):
file_path = model_file_download(self.model_id, ModelFile.README)
assert os.path.exists(file_path)

def test_download_file_private_model_no_permission(self):
self.token, _ = self.api.login(USER_NAME2, PASSWORD)
with self.assertRaises(HTTPError):
model_file_download(self.model_id, ModelFile.README)
self.api.login(USER_NAME, PASSWORD)

def test_snapshot_download_local_only(self):
with self.assertRaises(ValueError):
snapshot_download(self.model_id, local_files_only=True)
snapshot_path = snapshot_download(self.model_id)
assert os.path.exists(os.path.join(snapshot_path, ModelFile.README))
snapshot_path = snapshot_download(self.model_id, local_files_only=True)
assert os.path.exists(snapshot_path)

def test_file_download_local_only(self):
with self.assertRaises(ValueError):
model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
file_path = model_file_download(self.model_id, ModelFile.README)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id, ModelFile.README, local_files_only=True)
assert os.path.exists(file_path)


if __name__ == '__main__':
unittest.main()

+ 4
- 5
tests/hub/test_hub_private_repository.py View File

@@ -5,6 +5,7 @@ import unittest
import uuid

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import GitError
from modelscope.hub.repository import Repository

@@ -16,9 +17,6 @@ model_chinese_name = '达摩卡通化模型'
model_org = 'unittest'
DEFAULT_GIT_PATH = 'git'

sample_model_url = 'https://mindscope.oss-cn-hangzhou.aliyuncs.com/test_models/mnist-12.onnx'
download_model_file_name = 'mnist-12.onnx'


class HubPrivateRepositoryTest(unittest.TestCase):

@@ -31,9 +29,10 @@ class HubPrivateRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PRIVATE, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
visibility=1, # 1-private, 5-public
license='apache-2.0')
)

def tearDown(self):
self.api.login(USER_NAME, PASSWORD)


+ 5
- 19
tests/hub/test_hub_repository.py View File

@@ -2,7 +2,6 @@
import os
import shutil
import tempfile
import time
import unittest
import uuid
from os.path import expanduser
@@ -10,6 +9,7 @@ from os.path import expanduser
from requests import delete

from modelscope.hub.api import HubApi
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
@@ -55,9 +55,10 @@ class HubRepositoryTest(unittest.TestCase):
self.model_id = '%s/%s' % (model_org, self.model_name)
self.api.create_model(
model_id=self.model_id,
visibility=ModelVisibility.PUBLIC, # 1-private, 5-public
license=Licenses.APACHE_V2,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
)
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)

@@ -81,27 +82,12 @@ class HubRepositoryTest(unittest.TestCase):
os.chdir(self.model_dir)
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
repo.push('test', all_files=True)
repo.push('test')
add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2)

def test_push_files(self):
repo = Repository(self.model_dir, clone_from=self.model_id)
assert os.path.exists(os.path.join(self.model_dir, 'README.md'))
os.system("echo '111'>%s" % os.path.join(self.model_dir, 'add1.py'))
os.system("echo '222'>%s" % os.path.join(self.model_dir, 'add2.py'))
os.system("echo '333'>%s" % os.path.join(self.model_dir, 'add3.py'))
repo.push('test', files=['add1.py', 'add2.py'], all_files=False)
add1 = model_file_download(self.model_id, 'add1.py')
assert os.path.exists(add1)
add2 = model_file_download(self.model_id, 'add2.py')
assert os.path.exists(add2)
with self.assertRaises(NotExistError) as cm:
model_file_download(self.model_id, 'add3.py')
print(cm.exception)


if __name__ == '__main__':
unittest.main()

+ 14
- 10
tests/msdatasets/test_ms_dataset.py View File

@@ -32,11 +32,12 @@ class ImgPreprocessor(Preprocessor):

class MsDatasetTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ds_basic(self):
ms_ds_full = MsDataset.load('squad')
ms_ds_full = MsDataset.load('squad', namespace='damotest')
ms_ds_full_hf = hfdata.load_dataset('squad')
ms_ds_train = MsDataset.load('squad', split='train')
ms_ds_train = MsDataset.load(
'squad', namespace='damotest', split='train')
ms_ds_train_hf = hfdata.load_dataset('squad', split='train')
ms_image_train = MsDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
@@ -48,7 +49,7 @@ class MsDatasetTest(unittest.TestCase):
print(next(iter(ms_ds_train)))
print(next(iter(ms_image_train)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_text(self):
model_id = 'damo/bert-base-sst2'
@@ -57,13 +58,14 @@ class MsDatasetTest(unittest.TestCase):
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = MsDataset.load('squad', split='train')
ms_ds_train = MsDataset.load(
'squad', namespace='damotest', split='train')
pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
print(next(iter(dataloader)))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_tf
def test_to_tf_dataset_text(self):
import tensorflow as tf
@@ -74,7 +76,8 @@ class MsDatasetTest(unittest.TestCase):
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = MsDataset.load('squad', split='train')
ms_ds_train = MsDataset.load(
'squad', namespace='damotest', split='train')
tf_dataset = ms_ds_train.to_tf_dataset(
batch_size=5,
shuffle=True,
@@ -85,8 +88,8 @@ class MsDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_img(self):
ms_image_train = MsDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
ms_image_train = MsDataset.load(
'beans', namespace='damotest', split='train')
pt_dataset = ms_image_train.to_torch_dataset(
preprocessors=ImgPreprocessor(
image_path='image_file_path', label='labels'))
@@ -99,7 +102,8 @@ class MsDatasetTest(unittest.TestCase):
def test_to_tf_dataset_img(self):
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
ms_image_train = MsDataset.load('beans', split='train')
ms_image_train = MsDataset.load(
'beans', namespace='damotest', split='train')
tf_dataset = ms_image_train.to_tf_dataset(
batch_size=5,
shuffle=True,


+ 2
- 1
tests/pipelines/test_image_matting.py View File

@@ -62,7 +62,8 @@ class ImageMattingTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
dataset = MsDataset.load('beans', split='train', target='image')
dataset = MsDataset.load(
'beans', namespace='damotest', split='train', target='image')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
result = img_matting(dataset)
for i in range(10):


+ 6
- 2
tests/pipelines/test_text_classification.py View File

@@ -87,12 +87,16 @@ class SequenceClassificationTest(unittest.TestCase):
result = text_classification(dataset)
self.printDataset(result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
text_classification = pipeline(task=Tasks.text_classification)
# loaded from modelscope dataset
dataset = MsDataset.load(
'squad', split='train', target='context', hub=Hubs.modelscope)
'squad',
namespace='damotest',
split='train',
target='context',
hub=Hubs.modelscope)
result = text_classification(dataset)
self.printDataset(result)



+ 60
- 0
tests/pipelines/test_visual_question_answering.py View File

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VisualQuestionAnsweringTest(unittest.TestCase):
model_id = 'damo/mplug_visual-question-answering_coco_large_en'
input_vqa = {
'image': 'data/test/images/image_mplug_vqa.jpg',
'question': 'What is the woman doing?',
}

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
cache_path = snapshot_download(self.model_id)
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
model = MPlugForVisualQuestionAnswering(cache_path)
pipeline1 = VisualQuestionAnsweringPipeline(
model, preprocessor=preprocessor)
pipeline2 = pipeline(
Tasks.visual_question_answering,
model=model,
preprocessor=preprocessor)
print(f"question: {self.input_vqa['question']}")
print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}")
print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}")

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
model.model_dir)
pipeline_vqa = pipeline(
task=Tasks.visual_question_answering,
model=model,
preprocessor=preprocessor)
print(pipeline_vqa(self.input_vqa))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_vqa = pipeline(
Tasks.visual_question_answering, model=self.model_id)
print(pipeline_vqa(self.input_vqa))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
print(pipeline_vqa(self.input_vqa))


if __name__ == '__main__':
unittest.main()

+ 64
- 0
tests/pipelines/test_zero_shot_classification.py View File

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ZeroShotClassificationTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
sentence = '全新突破 解放军运20版空中加油机曝光'
labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
template = '这篇文章的标题是{}'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = ZeroShotClassificationPreprocessor(cache_path)
model = SbertForZeroShotClassification(cache_path, tokenizer=tokenizer)
pipeline1 = ZeroShotClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.zero_shot_classification,
model=model,
preprocessor=tokenizer)

print(
f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
)
print()
print(
f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification, model=self.model_id)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save