Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9159678master
| @@ -1,4 +1,3 @@ | |||||
| import imp | |||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import subprocess | import subprocess | ||||
| @@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo' | |||||
| MODEL_ID_SEPARATOR = '/' | MODEL_ID_SEPARATOR = '/' | ||||
| LOGGER_NAME = 'ModelScopeHub' | LOGGER_NAME = 'ModelScopeHub' | ||||
| class Licenses(object): | |||||
| APACHE_V2 = 'Apache License 2.0' | |||||
| GPL = 'GPL' | |||||
| LGPL = 'LGPL' | |||||
| MIT = 'MIT' | |||||
| class ModelVisibility(object): | |||||
| PRIVATE = 1 | |||||
| INTERNAL = 3 | |||||
| PUBLIC = 5 | |||||
| @@ -2,7 +2,7 @@ | |||||
| import os.path as osp | import os.path as osp | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, Union | |||||
| from typing import Dict, Optional, Union | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| @@ -42,13 +42,18 @@ class Model(ABC): | |||||
| return input | return input | ||||
| @classmethod | @classmethod | ||||
| def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs): | |||||
| """ Instantiate a model from local directory or remote model repo | |||||
| def from_pretrained(cls, | |||||
| model_name_or_path: str, | |||||
| revision: Optional[str] = 'master', | |||||
| *model_args, | |||||
| **kwargs): | |||||
| """ Instantiate a model from local directory or remote model repo. Note | |||||
| that when loading from remote, the model revision can be specified. | |||||
| """ | """ | ||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| local_model_dir = snapshot_download(model_name_or_path) | |||||
| local_model_dir = snapshot_download(model_name_or_path, revision) | |||||
| logger.info(f'initialize model from {local_model_dir}') | logger.info(f'initialize model from {local_model_dir}') | ||||
| cfg = Config.from_file( | cfg = Config.from_file( | ||||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | osp.join(local_model_dir, ModelFile.CONFIGURATION)) | ||||
| @@ -6,6 +6,7 @@ from typing import List, Optional, Union | |||||
| from requests import HTTPError | from requests import HTTPError | ||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| @@ -16,8 +17,8 @@ def create_model_if_not_exist( | |||||
| api, | api, | ||||
| model_id: str, | model_id: str, | ||||
| chinese_name: str, | chinese_name: str, | ||||
| visibility: Optional[int] = 5, # 1-private, 5-public | |||||
| license: Optional[str] = 'apache-2.0', | |||||
| visibility: Optional[int] = ModelVisibility.PUBLIC, | |||||
| license: Optional[str] = Licenses.APACHE_V2, | |||||
| revision: Optional[str] = 'master'): | revision: Optional[str] = 'master'): | ||||
| exists = True | exists = True | ||||
| try: | try: | ||||
| @@ -1,9 +1,9 @@ | |||||
| import unittest | import unittest | ||||
| from maas_hub.maas_api import MaasApi | |||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.utils.hub import create_model_if_not_exist | from modelscope.utils.hub import create_model_if_not_exist | ||||
| # note this is temporary before official account management is ready | |||||
| USER_NAME = 'maasadmin' | USER_NAME = 'maasadmin' | ||||
| PASSWORD = '12345678' | PASSWORD = '12345678' | ||||
| @@ -11,8 +11,7 @@ PASSWORD = '12345678' | |||||
| class HubExampleTest(unittest.TestCase): | class HubExampleTest(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| self.api = MaasApi() | |||||
| # note this is temporary before official account management is ready | |||||
| self.api = HubApi() | |||||
| self.api.login(USER_NAME, PASSWORD) | self.api.login(USER_NAME, PASSWORD) | ||||
| @unittest.skip('to be used for local test only') | @unittest.skip('to be used for local test only') | ||||
| @@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase): | |||||
| model_chinese_name = '达摩卡通化模型' | model_chinese_name = '达摩卡通化模型' | ||||
| model_org = 'damo' | model_org = 'damo' | ||||
| model_id = '%s/%s' % (model_org, model_name) | model_id = '%s/%s' % (model_org, model_name) | ||||
| created = create_model_if_not_exist(self.api, model_id, | created = create_model_if_not_exist(self.api, model_id, | ||||
| model_chinese_name) | model_chinese_name) | ||||
| if not created: | if not created: | ||||
| @@ -4,7 +4,8 @@ import tempfile | |||||
| import unittest | import unittest | ||||
| import uuid | import uuid | ||||
| from modelscope.hub.api import HubApi | |||||
| from modelscope.hub.api import HubApi, ModelScopeConfig | |||||
| from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.repository import Repository | from modelscope.hub.repository import Repository | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| @@ -31,8 +32,8 @@ class HubOperationTest(unittest.TestCase): | |||||
| self.api.create_model( | self.api.create_model( | ||||
| model_id=self.model_id, | model_id=self.model_id, | ||||
| chinese_name=model_chinese_name, | chinese_name=model_chinese_name, | ||||
| visibility=5, # 1-private, 5-public | |||||
| license='apache-2.0') | |||||
| visibility=ModelVisibility.PUBLIC, | |||||
| license=Licenses.APACHE_V2) | |||||
| temporary_dir = tempfile.mkdtemp() | temporary_dir = tempfile.mkdtemp() | ||||
| self.model_dir = os.path.join(temporary_dir, self.model_name) | self.model_dir = os.path.join(temporary_dir, self.model_name) | ||||
| repo = Repository(self.model_dir, clone_from=self.model_id) | repo = Repository(self.model_dir, clone_from=self.model_id) | ||||