diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index f4f31280..d102219b 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -1,4 +1,3 @@ -import imp import os import pickle import subprocess diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index a38f9afb..08f7c31d 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo' MODEL_ID_SEPARATOR = '/' LOGGER_NAME = 'ModelScopeHub' + + +class Licenses(object): + APACHE_V2 = 'Apache License 2.0' + GPL = 'GPL' + LGPL = 'LGPL' + MIT = 'MIT' + + +class ModelVisibility(object): + PRIVATE = 1 + INTERNAL = 3 + PUBLIC = 5 diff --git a/modelscope/models/base.py b/modelscope/models/base.py index cb6d2b0e..40929a21 100644 --- a/modelscope/models/base.py +++ b/modelscope/models/base.py @@ -2,7 +2,7 @@ import os.path as osp from abc import ABC, abstractmethod -from typing import Dict, Union +from typing import Dict, Optional, Union from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model @@ -42,13 +42,18 @@ class Model(ABC): return input @classmethod - def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): - """ Instantiate a model from local directory or remote model repo + def from_pretrained(cls, + model_name_or_path: str, + revision: Optional[str] = 'master', + *model_args, + **kwargs): + """ Instantiate a model from local directory or remote model repo. Note + that when loading from remote, the model revision can be specified. """ if osp.exists(model_name_or_path): local_model_dir = model_name_or_path else: - local_model_dir = snapshot_download(model_name_or_path) + local_model_dir = snapshot_download(model_name_or_path, revision) logger.info(f'initialize model from {local_model_dir}') cfg = Config.from_file( osp.join(local_model_dir, ModelFile.CONFIGURATION)) diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 868e751b..c427b7a3 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -6,6 +6,7 @@ from typing import List, Optional, Union from requests import HTTPError +from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.file_download import model_file_download from modelscope.hub.snapshot_download import snapshot_download from modelscope.utils.config import Config @@ -16,8 +17,8 @@ def create_model_if_not_exist( api, model_id: str, chinese_name: str, - visibility: Optional[int] = 5, # 1-private, 5-public - license: Optional[str] = 'apache-2.0', + visibility: Optional[int] = ModelVisibility.PUBLIC, + license: Optional[str] = Licenses.APACHE_V2, revision: Optional[str] = 'master'): exists = True try: diff --git a/tests/hub/test_hub_examples.py b/tests/hub/test_hub_examples.py index b63445af..b21cae51 100644 --- a/tests/hub/test_hub_examples.py +++ b/tests/hub/test_hub_examples.py @@ -1,9 +1,9 @@ import unittest -from maas_hub.maas_api import MaasApi - +from modelscope.hub.api import HubApi from modelscope.utils.hub import create_model_if_not_exist +# note this is temporary before official account management is ready USER_NAME = 'maasadmin' PASSWORD = '12345678' @@ -11,8 +11,7 @@ PASSWORD = '12345678' class HubExampleTest(unittest.TestCase): def setUp(self): - self.api = MaasApi() - # note this is temporary before official account management is ready + self.api = HubApi() self.api.login(USER_NAME, PASSWORD) @unittest.skip('to be used for local test only') @@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase): model_chinese_name = '达摩卡通化模型' model_org = 'damo' model_id = '%s/%s' % (model_org, model_name) - created = create_model_if_not_exist(self.api, model_id, model_chinese_name) if not created: diff --git a/tests/hub/test_hub_operation.py b/tests/hub/test_hub_operation.py index e0adc013..035b183e 100644 --- a/tests/hub/test_hub_operation.py +++ b/tests/hub/test_hub_operation.py @@ -4,7 +4,8 @@ import tempfile import unittest import uuid -from modelscope.hub.api import HubApi +from modelscope.hub.api import HubApi, ModelScopeConfig +from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.file_download import model_file_download from modelscope.hub.repository import Repository from modelscope.hub.snapshot_download import snapshot_download @@ -31,8 +32,8 @@ class HubOperationTest(unittest.TestCase): self.api.create_model( model_id=self.model_id, chinese_name=model_chinese_name, - visibility=5, # 1-private, 5-public - license='apache-2.0') + visibility=ModelVisibility.PUBLIC, + license=Licenses.APACHE_V2) temporary_dir = tempfile.mkdtemp() self.model_dir = os.path.join(temporary_dir, self.model_name) repo = Repository(self.model_dir, clone_from=self.model_id)