From 2be98edcd2dde73098bad0f8fe324ebe36b8406d Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 21 Nov 2023 15:29:12 +0800 Subject: [PATCH] [FIX] fix bugs for mac cannot build learnware root dir --- learnware/config.py | 25 +++++++++++++++++++++++-- learnware/utils/__init__.py | 2 +- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/learnware/config.py b/learnware/config.py index 9489e61..094dc2c 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -1,7 +1,7 @@ import os import copy import logging - +from enum import Enum class Config: def __init__(self, default_conf): @@ -48,7 +48,28 @@ class Config: self.__dict__["_config"].update(*args, **kwargs) -ROOT_DIRPATH = os.path.join(os.path.expanduser("~"), ".learnware") +class SystemType(Enum): + LINUX = 0 + MACOS = 1 + WINDOWS = 2 + +def get_platform(): + import platform + + sys_platform = platform.platform().lower() + if "windows" in sys_platform: + return SystemType.WINDOWS + elif "macos" in sys_platform: + return SystemType.MACOS + elif "linux" in sys_platform: + return SystemType.LINUX + raise SystemError("Learnware only support MACOS/Linux/Windows") + +if get_platform() == SystemType.MACOS: + ROOT_DIRPATH = os.path.join(os.path.expanduser("~"), "Library", "Learnware") +else: + ROOT_DIRPATH = os.path.join(os.path.expanduser("~"), ".learnware") + PACKAGE_DIRPATH = os.path.dirname(os.path.abspath(__file__)) DATABASE_PATH = os.path.join(ROOT_DIRPATH, "database") diff --git a/learnware/utils/__init__.py b/learnware/utils/__init__.py index d98e60b..5357aaf 100644 --- a/learnware/utils/__init__.py +++ b/learnware/utils/__init__.py @@ -5,7 +5,7 @@ from .import_utils import is_torch_available from .module import get_module_by_module_path from .file import read_yaml_to_dict, save_dict_to_yaml from .gpu import setup_seed, choose_device, allocate_cuda_idx - +from ..config import get_platform, SystemType def zip_learnware_folder(path: str, output_name: str): with zipfile.ZipFile(output_name, "w") as zip_ref: