| @@ -1,5 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import jsonplus | |||||
| import numpy as np | import numpy as np | ||||
| from .base import FormatHandler | from .base import FormatHandler | ||||
| @@ -25,11 +24,13 @@ class JsonHandler(FormatHandler): | |||||
| """Use jsonplus, serialization of Python types to JSON that "just works".""" | """Use jsonplus, serialization of Python types to JSON that "just works".""" | ||||
| def load(self, file): | def load(self, file): | ||||
| import jsonplus | |||||
| return jsonplus.loads(file.read()) | return jsonplus.loads(file.read()) | ||||
| def dump(self, obj, file, **kwargs): | def dump(self, obj, file, **kwargs): | ||||
| file.write(self.dumps(obj, **kwargs)) | file.write(self.dumps(obj, **kwargs)) | ||||
| def dumps(self, obj, **kwargs): | def dumps(self, obj, **kwargs): | ||||
| import jsonplus | |||||
| kwargs.setdefault('default', set_default) | kwargs.setdefault('default', set_default) | ||||
| return jsonplus.dumps(obj, **kwargs) | return jsonplus.dumps(obj, **kwargs) | ||||
| @@ -1,5 +1,6 @@ | |||||
| import hashlib | import hashlib | ||||
| import os | import os | ||||
| from typing import Optional | |||||
| from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | ||||
| DEFAULT_MODELSCOPE_DOMAIN, | DEFAULT_MODELSCOPE_DOMAIN, | ||||
| @@ -23,14 +24,16 @@ def model_id_to_group_owner_name(model_id): | |||||
| return group_or_owner, name | return group_or_owner, name | ||||
| def get_cache_dir(): | |||||
| def get_cache_dir(model_id: Optional[str] = None): | |||||
| """ | """ | ||||
| cache dir precedence: | cache dir precedence: | ||||
| function parameter > enviroment > ~/.cache/modelscope/hub | function parameter > enviroment > ~/.cache/modelscope/hub | ||||
| """ | """ | ||||
| default_cache_dir = get_default_cache_dir() | default_cache_dir = get_default_cache_dir() | ||||
| return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir, | |||||
| 'hub')) | |||||
| base_path = os.getenv('MODELSCOPE_CACHE', | |||||
| os.path.join(default_cache_dir, 'hub')) | |||||
| return base_path if model_id is None else os.path.join( | |||||
| base_path, model_id + '/') | |||||
| def get_endpoint(): | def get_endpoint(): | ||||