yingda.chen 3 years ago
parent
commit
0acbfe1663
6 changed files with 32 additions and 15 deletions
  1. +0
    -1
      modelscope/hub/api.py
  2. +13
    -0
      modelscope/hub/constants.py
  3. +9
    -4
      modelscope/models/base.py
  4. +3
    -2
      modelscope/utils/hub.py
  5. +3
    -5
      tests/hub/test_hub_examples.py
  6. +4
    -3
      tests/hub/test_hub_operation.py

+ 0
- 1
modelscope/hub/api.py View File

@@ -1,4 +1,3 @@
import imp
import os
import pickle
import subprocess


+ 13
- 0
modelscope/hub/constants.py View File

@@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo'
MODEL_ID_SEPARATOR = '/'

LOGGER_NAME = 'ModelScopeHub'


class Licenses(object):
APACHE_V2 = 'Apache License 2.0'
GPL = 'GPL'
LGPL = 'LGPL'
MIT = 'MIT'


class ModelVisibility(object):
PRIVATE = 1
INTERNAL = 3
PUBLIC = 5

+ 9
- 4
modelscope/models/base.py View File

@@ -2,7 +2,7 @@

import os.path as osp
from abc import ABC, abstractmethod
from typing import Dict, Union
from typing import Dict, Optional, Union

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
@@ -42,13 +42,18 @@ class Model(ABC):
return input

@classmethod
def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
""" Instantiate a model from local directory or remote model repo
def from_pretrained(cls,
model_name_or_path: str,
revision: Optional[str] = 'master',
*model_args,
**kwargs):
""" Instantiate a model from local directory or remote model repo. Note
that when loading from remote, the model revision can be specified.
"""
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:
local_model_dir = snapshot_download(model_name_or_path)
local_model_dir = snapshot_download(model_name_or_path, revision)
logger.info(f'initialize model from {local_model_dir}')
cfg = Config.from_file(
osp.join(local_model_dir, ModelFile.CONFIGURATION))


+ 3
- 2
modelscope/utils/hub.py View File

@@ -6,6 +6,7 @@ from typing import List, Optional, Union

from requests import HTTPError

from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.config import Config
@@ -16,8 +17,8 @@ def create_model_if_not_exist(
api,
model_id: str,
chinese_name: str,
visibility: Optional[int] = 5, # 1-private, 5-public
license: Optional[str] = 'apache-2.0',
visibility: Optional[int] = ModelVisibility.PUBLIC,
license: Optional[str] = Licenses.APACHE_V2,
revision: Optional[str] = 'master'):
exists = True
try:


+ 3
- 5
tests/hub/test_hub_examples.py View File

@@ -1,9 +1,9 @@
import unittest

from maas_hub.maas_api import MaasApi

from modelscope.hub.api import HubApi
from modelscope.utils.hub import create_model_if_not_exist

# note this is temporary before official account management is ready
USER_NAME = 'maasadmin'
PASSWORD = '12345678'

@@ -11,8 +11,7 @@ PASSWORD = '12345678'
class HubExampleTest(unittest.TestCase):

def setUp(self):
self.api = MaasApi()
# note this is temporary before official account management is ready
self.api = HubApi()
self.api.login(USER_NAME, PASSWORD)

@unittest.skip('to be used for local test only')
@@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase):
model_chinese_name = '达摩卡通化模型'
model_org = 'damo'
model_id = '%s/%s' % (model_org, model_name)

created = create_model_if_not_exist(self.api, model_id,
model_chinese_name)
if not created:


+ 4
- 3
tests/hub/test_hub_operation.py View File

@@ -4,7 +4,8 @@ import tempfile
import unittest
import uuid

from modelscope.hub.api import HubApi
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
@@ -31,8 +32,8 @@ class HubOperationTest(unittest.TestCase):
self.api.create_model(
model_id=self.model_id,
chinese_name=model_chinese_name,
visibility=5, # 1-private, 5-public
license='apache-2.0')
visibility=ModelVisibility.PUBLIC,
license=Licenses.APACHE_V2)
temporary_dir = tempfile.mkdtemp()
self.model_dir = os.path.join(temporary_dir, self.model_name)
repo = Repository(self.model_dir, clone_from=self.model_id)


Loading…
Cancel
Save