yingda.chen 3 years ago
parent
commit
0acbfe1663
6 changed files with 32 additions and 15 deletions
  1. +0
    -1
      modelscope/hub/api.py
  2. +13
    -0
      modelscope/hub/constants.py
  3. +9
    -4
      modelscope/models/base.py
  4. +3
    -2
      modelscope/utils/hub.py
  5. +3
    -5
      tests/hub/test_hub_examples.py
  6. +4
    -3
      tests/hub/test_hub_operation.py

+ 0
- 1
modelscope/hub/api.py View File

@@ -1,4 +1,3 @@
import imp
import os import os
import pickle import pickle
import subprocess import subprocess


+ 13
- 0
modelscope/hub/constants.py View File

@@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/' MODEL_ID_SEPARATOR = '/'


LOGGER_NAME = 'ModelScopeHub' LOGGER_NAME = 'ModelScopeHub'


class Licenses(object):
APACHE_V2 = 'Apache License 2.0'
GPL = 'GPL'
LGPL = 'LGPL'
MIT = 'MIT'


class ModelVisibility(object):
PRIVATE = 1
INTERNAL = 3
PUBLIC = 5

+ 9
- 4
modelscope/models/base.py View File

@@ -2,7 +2,7 @@


import os.path as osp import os.path as osp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Union
from typing import Dict, Optional, Union


from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model from modelscope.models.builder import build_model
@@ -42,13 +42,18 @@ class Model(ABC):
return input return input


@classmethod @classmethod
def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
""" Instantiate a model from local directory or remote model repo
def from_pretrained(cls,
model_name_or_path: str,
revision: Optional[str] = 'master',
*model_args,
**kwargs):
""" Instantiate a model from local directory or remote model repo. Note
that when loading from remote, the model revision can be specified.
""" """
if osp.exists(model_name_or_path): if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path local_model_dir = model_name_or_path
else: else:
local_model_dir = snapshot_download(model_name_or_path)
local_model_dir = snapshot_download(model_name_or_path, revision)
logger.info(f'initialize model from {local_model_dir}') logger.info(f'initialize model from {local_model_dir}')
cfg = Config.from_file( cfg = Config.from_file(
osp.join(local_model_dir, ModelFile.CONFIGURATION)) osp.join(local_model_dir, ModelFile.CONFIGURATION))


+ 3
- 2
modelscope/utils/hub.py View File

@@ -6,6 +6,7 @@ from typing import List, Optional, Union


from requests import HTTPError from requests import HTTPError


from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.config import Config from modelscope.utils.config import Config
@@ -16,8 +17,8 @@ def create_model_if_not_exist(
api, api,
model_id: str, model_id: str,
chinese_name: str, chinese_name: str,
visibility: Optional[int] = 5, # 1-private, 5-public
license: Optional[str] = 'apache-2.0',
visibility: Optional[int] = ModelVisibility.PUBLIC,
license: Optional[str] = Licenses.APACHE_V2,
revision: Optional[str] = 'master'): revision: Optional[str] = 'master'):
exists = True exists = True
try: try:


+ 3
- 5
tests/hub/test_hub_examples.py View File

@@ -1,9 +1,9 @@
import unittest import unittest


from maas_hub.maas_api import MaasApi

from modelscope.hub.api import HubApi
from modelscope.utils.hub import create_model_if_not_exist from modelscope.utils.hub import create_model_if_not_exist


# note this is temporary before official account management is ready
USER_NAME = 'maasadmin' USER_NAME = 'maasadmin'
PASSWORD = '12345678' PASSWORD = '12345678'


@@ -11,8 +11,7 @@ PASSWORD = '12345678'
class HubExampleTest(unittest.TestCase): class HubExampleTest(unittest.TestCase):


def setUp(self): def setUp(self):
self.api = MaasApi()
# note this is temporary before official account management is ready
self.api = HubApi()
self.api.login(USER_NAME, PASSWORD) self.api.login(USER_NAME, PASSWORD)


@unittest.skip('to be used for local test only') @unittest.skip('to be used for local test only')
@@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase):
model_chinese_name = '达摩卡通化模型' model_chinese_name = '达摩卡通化模型'
model_org = 'damo' model_org = 'damo'
model_id = '%s/%s' % (model_org, model_name) model_id = '%s/%s' % (model_org, model_name)

created = create_model_if_not_exist(self.api, model_id, created = create_model_if_not_exist(self.api, model_id,
model_chinese_name) model_chinese_name)
if not created: if not created:


+ 4
- 3
tests/hub/test_hub_operation.py View File

@@ -4,7 +4,8 @@ import tempfile
import unittest import unittest
import uuid import uuid


from modelscope.hub.api import HubApi
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
@@ -31,8 +32,8 @@ class HubOperationTest(unittest.TestCase):
self.api.create_model( self.api.create_model(
model_id=self.model_id, model_id=self.model_id,
chinese_name=model_chinese_name, chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
temporary_dir = tempfile.mkdtemp() temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name) self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id) repo = Repository(self.model_dir, clone_from=self.model_id)


Loading…
Cancel
Save