diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f8443a1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + args: ["-l 120"] + +- repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"] \ No newline at end of file diff --git a/examples/dataset_text_workflow/get_data.py b/examples/dataset_text_workflow/get_data.py index 0d1412d..770fd80 100644 --- a/examples/dataset_text_workflow/get_data.py +++ b/examples/dataset_text_workflow/get_data.py @@ -1,4 +1,3 @@ -import torch from torchtext.datasets import SST2 diff --git a/examples/dataset_text_workflow/requirements.txt b/examples/dataset_text_workflow/requirements.txt new file mode 100644 index 0000000..66500d8 --- /dev/null +++ b/examples/dataset_text_workflow/requirements.txt @@ -0,0 +1 @@ +torchtext>=0.14.1 diff --git a/learnware/__init__.py b/learnware/__init__.py index a57329c..b140a86 100644 --- a/learnware/__init__.py +++ b/learnware/__init__.py @@ -1,7 +1,10 @@ -__version__ = "0.1.1.99" +__version__ = "0.1.2.99" import os from .logger import get_module_logger +from .utils import is_torch_avaliable + +logger = get_module_logger("Initialization") def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): @@ -10,9 +13,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): C.reset() C.update(**kwargs) - logger = get_module_logger("Initialization") logger.info(f"init learnware market with {kwargs}") - ## make dirs if make_dir: os.makedirs(C.root_path, exist_ok=True) @@ -25,3 +26,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): ## ignore tensorflow warning # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel # logger.info(f"The tensorflow log level is setted to {tf_loglevel}") + + +if not is_torch_avaliable(verbose=False): + 
logger.warning("The functionality of learnware is limited due to 'torch' is not installed!") diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 7bc23bd..55816f4 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -18,7 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker from ..logger import get_module_logger from ..specification import Specification from ..learnware import get_learnware_from_dirpath -from ..test import get_semantic_specification +from ..tests import get_semantic_specification CHUNK_SIZE = 1024 * 1024 logger = get_module_logger(module_name="LearnwareClient") diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index b850f5c..a040444 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -1,4 +1,4 @@ -from .anchor import AnchoredUserInfo, AnchoredOrganizer +from .anchor import AnchoredUserInfo, AnchoredSearcher, AnchoredOrganizer from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher from .evolve_anchor import EvolvedAnchoredOrganizer from .evolve import EvolvedOrganizer diff --git a/learnware/market/anchor/__init__.py b/learnware/market/anchor/__init__.py index f2b453c..c005c0c 100644 --- a/learnware/market/anchor/__init__.py +++ b/learnware/market/anchor/__init__.py @@ -1,2 +1,13 @@ from .organizer import AnchoredOrganizer -from .searcher import AnchoredUserInfo +from .user_info import AnchoredUserInfo + +from ...utils import is_torch_avaliable +from ...logger import get_module_logger + +logger = get_module_logger("market_anchor") + +if not is_torch_avaliable(verbose=False): + AnchoredSearcher = None + logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!") +else: + from .searcher import AnchoredSearcher diff --git a/learnware/market/anchor/searcher.py b/learnware/market/anchor/searcher.py index b0d826f..34d326d 100644 --- 
a/learnware/market/anchor/searcher.py +++ b/learnware/market/anchor/searcher.py @@ -1,5 +1,6 @@ from typing import List, Dict, Tuple, Any, Union +from .user_info import AnchoredUserInfo from ..base import BaseUserInfo from ..easy.searcher import EasySearcher from ...logger import get_module_logger @@ -8,45 +9,6 @@ from ...learnware import Learnware logger = get_module_logger("anchor_searcher") -class AnchoredUserInfo(BaseUserInfo): - """ - User Information for searching learnware (add the anchor design) - - - UserInfo contains the anchor id list acquired from the market - - UserInfo can update stat_info based on anchors - """ - - def __init__( - self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None - ): - super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) - self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids - - def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): - """Add the anchor learnware ids acquired from the market - - Parameters - ---------- - learnware_ids : Union[str, List[str]] - Anchor learnware ids - """ - if isinstance(learnware_ids, str): - learnware_ids = [learnware_ids] - self.anchor_learnware_ids += learnware_ids - - def update_stat_info(self, name: str, item: Any): - """Update stat_info based on anchor learnwares - - Parameters - ---------- - name : str - Name of stat_info - item : Any - Statistical information calculated on anchor learnwares - """ - self.stat_info[name] = item - - class AnchoredSearcher(EasySearcher): def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: """Search anchor Learnwares from anchor_learnware_list based on user_info diff --git a/learnware/market/anchor/user_info.py b/learnware/market/anchor/user_info.py new file mode 100644 index 0000000..7ae4737 --- /dev/null +++ b/learnware/market/anchor/user_info.py @@ -0,0 +1,41 @@ +from typing import List, Any, 
Union +from ..base import BaseUserInfo + + +class AnchoredUserInfo(BaseUserInfo): + """ + User Information for searching learnware (add the anchor design) + + - UserInfo contains the anchor id list acquired from the market + - UserInfo can update stat_info based on anchors + """ + + def __init__( + self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None + ): + super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) + self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids + + def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): + """Add the anchor learnware ids acquired from the market + + Parameters + ---------- + learnware_ids : Union[str, List[str]] + Anchor learnware ids + """ + if isinstance(learnware_ids, str): + learnware_ids = [learnware_ids] + self.anchor_learnware_ids += learnware_ids + + def update_stat_info(self, name: str, item: Any): + """Update stat_info based on anchor learnwares + + Parameters + ---------- + name : str + Name of stat_info + item : Any + Statistical information calculated on anchor learnwares + """ + self.stat_info[name] = item diff --git a/learnware/market/base.py b/learnware/market/base.py index a422479..d061d7f 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -227,7 +227,7 @@ class LearnwareMarket: def reload_learnware(self, learnware_id: str): self.learnware_organizer.reload_learnware(learnware_id) - + def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) diff --git a/learnware/market/easy/__init__.py b/learnware/market/easy/__init__.py index 2835871..988a5c6 100644 --- a/learnware/market/easy/__init__.py +++ b/learnware/market/easy/__init__.py @@ -1,3 +1,15 @@ from .organizer import EasyOrganizer -from .searcher import EasySearcher -from .checker import 
EasySemanticChecker, EasyStatChecker + +from ...utils import is_torch_avaliable +from ...logger import get_module_logger + +logger = get_module_logger("market_easy") + +if not is_torch_avaliable(verbose=False): + EasySearcher = None + EasySemanticChecker = None + EasyStatChecker = None + logger.warning("EasySearcher and EasyChecker are skipped because 'torch' is not installed!") +else: + from .searcher import EasySearcher + from .checker import EasySemanticChecker, EasyStatChecker diff --git a/learnware/market/easy/database_ops.py b/learnware/market/easy/database_ops.py index 077a04e..c9fb3de 100644 --- a/learnware/market/easy/database_ops.py +++ b/learnware/market/easy/database_ops.py @@ -166,7 +166,7 @@ class DatabaseOperations(object): return int(row[0]) pass pass - + def load_market(self): with self.engine.connect() as conn: cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) diff --git a/learnware/market/easy/organizer.py b/learnware/market/easy/organizer.py index 3d6d6b6..9337841 100644 --- a/learnware/market/easy/organizer.py +++ b/learnware/market/easy/organizer.py @@ -387,7 +387,8 @@ class EasyOrganizer(BaseOrganizer): self.learnware_folder_list[learnware_id] = target_folder_dir semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) self.learnware_list[learnware_id] = get_learnware_from_dirpath( - id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir) + id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir + ) self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) pass diff --git a/learnware/market/module.py b/learnware/market/module.py index d48e03e..290755a 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -2,25 +2,29 @@ from .base import LearnwareMarket from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker from .heterogeneous import
HeteroMapTableOrganizer, HeteroSearcher -MARKET_CONFIG = { - "easy": { - "organizer": EasyOrganizer(), - "searcher": EasySearcher(), - "checker_list": [EasySemanticChecker(), EasyStatChecker()], - }, - "hetero": { + +def get_market_config(): + market_config = { + "easy": { + "organizer": EasyOrganizer(), + "searcher": EasySearcher(), + "checker_list": [EasySemanticChecker(), EasyStatChecker()], + }, + "hetero": { "organizer": HeteroMapTableOrganizer(), "searcher": HeteroSearcher(), "checker_list": [] + } } -} + return market_config def instantiate_learnware_market(market_id="default", name="easy", **kwargs): + market_config = get_market_config() return LearnwareMarket( market_id=market_id, - organizer=MARKET_CONFIG[name]["organizer"], - searcher=MARKET_CONFIG[name]["searcher"], - checker_list=MARKET_CONFIG[name]["checker_list"], + organizer=market_config[name]["organizer"], + searcher=market_config[name]["searcher"], + checker_list=market_config[name]["checker_list"], **kwargs ) diff --git a/learnware/model/base.py b/learnware/model/base.py index 26ec380..e54d858 100644 --- a/learnware/model/base.py +++ b/learnware/model/base.py @@ -1,5 +1,4 @@ import numpy as np -import torch from typing import Union @@ -19,7 +18,7 @@ class BaseModel: self.input_shape = input_shape self.output_shape = output_shape - def predict(self, X: Union[np.ndarray, torch.tensor]) -> Union[np.ndarray, torch.tensor]: + def predict(self, X: np.ndarray) -> np.ndarray: """The prediction method for model in learnware, which will be checked when learnware is submitted into the market. 
Parameters @@ -33,10 +32,10 @@ class BaseModel: """ pass - def fit(self, X: Union[np.ndarray, torch.tensor], y: Union[np.ndarray, torch.tensor]): + def fit(self, X: np.ndarray, y: np.ndarray): pass - def finetune(self, X: Union[np.ndarray, torch.tensor], y: np.ndarray): + def finetune(self, X: np.ndarray, y: np.ndarray): """The finetune method for continuing train the model searched by market Parameters diff --git a/learnware/reuse/__init__.py b/learnware/reuse/__init__.py index 8e9621b..d8040a9 100644 --- a/learnware/reuse/__init__.py +++ b/learnware/reuse/__init__.py @@ -1,5 +1,42 @@ -from .ensemble_pruning import EnsemblePruningReuser -from .averaging import AveragingReuser -from .job_selector import JobSelectorReuser +from ..logger import get_module_logger +from ..utils import is_torch_avaliable +from .utils import is_geatpy_avaliable, is_lightgbm_avaliable + +logger = get_module_logger("reuse") + +if not is_geatpy_avaliable(verbose=False): + EnsemblePruningReuser = None + logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!") +else: + from .ensemble_pruning import EnsemblePruningReuser + +if not is_torch_avaliable(verbose=False): + AveragingReuser = None + logger.warning("AveragingReuser is skipped due to 'torch' is not installed!") +else: + from .averaging import AveragingReuser + +if not is_lightgbm_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): + JobSelectorReuser = None + uninstall_packages = [ + value + for flag, value in zip( + [ + is_lightgbm_avaliable(verbose=False), + is_torch_avaliable(verbose=False), + ], + ["lightgbm", "torch"], + ) + if flag is False + ] + logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!") +else: + from .job_selector import JobSelectorReuser + +if not is_torch_avaliable(verbose=False): + HeteroMapTableReuser = None + logger.warning("HeteroMapTableReuser and FeatureAugmentReuser are skipped due to 'torch' is not installed!") +else: + from .hetero_reuser import
HeteroMapTableReuser + from .feature_augment_reuser import FeatureAugmentReuser -from .hetero_reuser import HeteroMapTableReuser \ No newline at end of file diff --git a/learnware/reuse/ensemble_pruning.py b/learnware/reuse/ensemble_pruning.py index d20664c..6001880 100644 --- a/learnware/reuse/ensemble_pruning.py +++ b/learnware/reuse/ensemble_pruning.py @@ -2,7 +2,6 @@ import torch import random import numpy as np import geatpy as ea - from typing import List from learnware.learnware import Learnware diff --git a/learnware/reuse/utils.py b/learnware/reuse/utils.py new file mode 100644 index 0000000..d0ab3f8 --- /dev/null +++ b/learnware/reuse/utils.py @@ -0,0 +1,25 @@ +from ..logger import get_module_logger + +logger = get_module_logger("reuse_utils") + + +def is_geatpy_avaliable(verbose=False): + try: + import geatpy + except ModuleNotFoundError as err: + if verbose is True: + logger.warning( + "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!" 
+ ) + return False + return True + + +def is_lightgbm_avaliable(verbose=False): + try: + import lightgbm + except ModuleNotFoundError as err: + if verbose is True: + logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!") + return False + return True diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 90c8758..fae0c7c 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,4 +1,3 @@ -from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec from .base import Specification, BaseStatSpecification from .regular import ( RegularStatsSpecification, @@ -7,4 +6,15 @@ from .regular import ( RKMEImageSpecification, RKMETextSpecification, ) + from .system import HeteroSpecification + +from ..utils import is_torch_avaliable + +if not is_torch_avaliable(verbose=False): + generate_stat_spec = None + generate_rkme_spec = None + generate_rkme_image_spec = None + generate_rkme_text_spec = None +else: + from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 9007e4d..9d46114 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,4 +1,6 @@ +from .base import RegularStatsSpecification +from ...utils import is_torch_avaliable + from .text import RKMETextSpecification from .table import RKMETableSpecification, RKMEStatSpecification from .image import RKMEImageSpecification -from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/image/__init__.py b/learnware/specification/regular/image/__init__.py index 0a18ded..e883f99 100644 --- a/learnware/specification/regular/image/__init__.py +++ b/learnware/specification/regular/image/__init__.py @@ -1 +1,29 @@ -from .rkme import 
RKMEImageSpecification +from .utils import is_torch_optimizer_avaliable, is_torch_vision_avaliable +from ....utils import is_torch_avaliable +from ....logger import get_module_logger + + +logger = get_module_logger("regular_image_spec") + +if ( + not is_torch_vision_avaliable(verbose=False) + or not is_torch_optimizer_avaliable(verbose=False) + or not is_torch_avaliable(verbose=False) +): + RKMEImageSpecification = None + uninstall_packages = [ + value + for flag, value in zip( + [ + is_torch_vision_avaliable(verbose=False), + is_torch_optimizer_avaliable(verbose=False), + is_torch_avaliable(verbose=False), + ], + ["torchvision", "torch-optimizer", "torch"], + ) + if flag is False + ] + + logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!") +else: + from .rkme import RKMEImageSpecification diff --git a/learnware/specification/regular/image/utils.py b/learnware/specification/regular/image/utils.py new file mode 100644 index 0000000..80c97f2 --- /dev/null +++ b/learnware/specification/regular/image/utils.py @@ -0,0 +1,23 @@ +from ....logger import get_module_logger + +logger = get_module_logger("regular_image_spec_utils") + + +def is_torch_optimizer_avaliable(verbose=False): + try: + import torch_optimizer + except ModuleNotFoundError as err: + if verbose is True: + logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!") + return False + return True + + +def is_torch_vision_avaliable(verbose=False): + try: + import torchvision + except ModuleNotFoundError as err: + if verbose is True: + logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!") + return False + return True diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index 19fa956..e8ec903 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -1 
+1,11 @@ -from .rkme import RKMETableSpecification, RKMEStatSpecification +from ....utils import is_torch_avaliable +from ....logger import get_module_logger + +logger = get_module_logger("regular_table_spec") + +if not is_torch_avaliable(verbose=False): + RKMETableSpecification = None + RKMEStatSpecification = None + logger.warning("RKMETableSpecification is skipped because torch is not installed!") +else: + from .rkme import RKMETableSpecification, RKMEStatSpecification diff --git a/learnware/specification/regular/text/__init__.py b/learnware/specification/regular/text/__init__.py index 35b8b0a..3d8c830 100644 --- a/learnware/specification/regular/text/__init__.py +++ b/learnware/specification/regular/text/__init__.py @@ -1 +1,23 @@ -from .rkme import RKMETextSpecification +from .utils import is_sentence_transformers_avaliable + +from ....utils import is_torch_avaliable +from ....logger import get_module_logger + +logger = get_module_logger("regular_text_spec") + +if not is_sentence_transformers_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): + RKMETextSpecification = None + uninstall_packages = [ + value + for flag, value in zip( + [ + is_sentence_transformers_avaliable(verbose=False), + is_torch_avaliable(verbose=False), + ], + ["sentence_transformers", "torch"], + ) + if flag is False + ] + logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!") +else: + from .rkme import RKMETextSpecification diff --git a/learnware/specification/regular/text/utils.py b/learnware/specification/regular/text/utils.py new file mode 100644 index 0000000..2052c02 --- /dev/null +++ b/learnware/specification/regular/text/utils.py @@ -0,0 +1,15 @@ +from ....logger import get_module_logger + +logger = get_module_logger("regular_text_spec_utils") + + +def is_sentence_transformers_avaliable(verbose=False): + try: + import sentence_transformers + except ModuleNotFoundError as err: + if verbose is True: + logger.warning( + 
"ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!" + ) + return False + return True diff --git a/learnware/test/__init__.py b/learnware/tests/__init__.py similarity index 100% rename from learnware/test/__init__.py rename to learnware/tests/__init__.py diff --git a/learnware/test/data.py b/learnware/tests/data.py similarity index 100% rename from learnware/test/data.py rename to learnware/tests/data.py diff --git a/learnware/test/module.py b/learnware/tests/module.py similarity index 100% rename from learnware/test/module.py rename to learnware/tests/module.py diff --git a/learnware/utils.py b/learnware/utils.py deleted file mode 100644 index 9bea664..0000000 --- a/learnware/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import sys - -import re -import yaml -import importlib -import importlib.util -from typing import Union -from types import ModuleType -import zipfile -from .logger import get_module_logger - -logger = get_module_logger("utils") - - -def get_module_by_module_path(module_path: Union[str, ModuleType]): - if module_path is None: - raise ModuleNotFoundError("None is passed in as parameters as module_path") - - if isinstance(module_path, ModuleType): - module = module_path - else: - if module_path.endswith(".py"): - module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) - module_spec = importlib.util.spec_from_file_location(module_name, module_path) - module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module - module_spec.loader.exec_module(module) - else: - module = importlib.import_module(module_path) - return module - - -def save_dict_to_yaml(dict_value: dict, save_path: str): - """save dict object into yaml file""" - with open(save_path, "w") as file: - file.write(yaml.dump(dict_value, allow_unicode=True)) - - -def read_yaml_to_dict(yaml_path: str): - """load yaml file into dict object""" - with open(yaml_path, 
"r") as file: - dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) - return dict_value - - -def zip_learnware_folder(path: str, output_name: str): - with zipfile.ZipFile(output_name, "w") as zip_ref: - for root, dirs, files in os.walk(path): - for file in files: - full_path = os.path.join(root, file) - if file.endswith(".pyc") or os.path.islink(full_path): - continue - zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) - pass - pass - pass - pass diff --git a/learnware/utils/__init__.py b/learnware/utils/__init__.py new file mode 100644 index 0000000..60f2b46 --- /dev/null +++ b/learnware/utils/__init__.py @@ -0,0 +1,16 @@ +import os +import zipfile + +from .import_utils import is_torch_avaliable +from .module import get_module_by_module_path +from .file import read_yaml_to_dict, save_dict_to_yaml + + +def zip_learnware_folder(path: str, output_name: str): + with zipfile.ZipFile(output_name, "w") as zip_ref: + for root, dirs, files in os.walk(path): + for file in files: + full_path = os.path.join(root, file) + if file.endswith(".pyc") or os.path.islink(full_path): + continue + zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) diff --git a/learnware/utils/file.py b/learnware/utils/file.py new file mode 100644 index 0000000..27ba5f5 --- /dev/null +++ b/learnware/utils/file.py @@ -0,0 +1,14 @@ +import yaml + + +def save_dict_to_yaml(dict_value: dict, save_path: str): + """save dict object into yaml file""" + with open(save_path, "w") as file: + file.write(yaml.dump(dict_value, allow_unicode=True)) + + +def read_yaml_to_dict(yaml_path: str): + """load yaml file into dict object""" + with open(yaml_path, "r") as file: + dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) + return dict_value diff --git a/learnware/utils/import_utils.py b/learnware/utils/import_utils.py new file mode 100644 index 0000000..9f4406c --- /dev/null +++ b/learnware/utils/import_utils.py @@ -0,0 +1,13 @@ +from ..logger import get_module_logger + 
+logger = get_module_logger("import_utils") + + +def is_torch_avaliable(verbose=False): + try: + import torch + except ModuleNotFoundError as err: + if verbose is True: + logger.warning("ModuleNotFoundError: torch is not installed, please install pytorch!") + return False + return True diff --git a/learnware/utils/module.py b/learnware/utils/module.py new file mode 100644 index 0000000..6f1b414 --- /dev/null +++ b/learnware/utils/module.py @@ -0,0 +1,24 @@ +import sys +import re +import importlib +import importlib.util +from typing import Union +from types import ModuleType + + +def get_module_by_module_path(module_path: Union[str, ModuleType]): + if module_path is None: + raise ModuleNotFoundError("None is passed in as parameters as module_path") + + if isinstance(module_path, ModuleType): + module = module_path + else: + if module_path.endswith(".py"): + module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) + module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(module_spec) + sys.modules[module_name] = module + module_spec.loader.exec_module(module) + else: + module = importlib.import_module(module_path) + return module diff --git a/setup.py b/setup.py index b37dfab..8f2f812 100644 --- a/setup.py +++ b/setup.py @@ -51,31 +51,23 @@ def get_platform(): # What packages are required for this module to be executed? # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. 
REQUIRED = [ - # "numpy>=1.20.0", - # "pandas>=0.25.1", - # "scipy>=1.0.0", - # "matplotlib>=3.1.3", - # "torch>=1.11.0", - # "cvxopt>=1.3.0", - # "tqdm>=4.65.0", - # "scikit-learn>=0.22", - # "joblib>=1.2.0", - # "pyyaml>=6.0", - # "fire>=0.3.1", - # "lightgbm>=3.3.0", - # "psutil>=5.9.4", - # "torchvision>=0.15.1", - # "sqlalchemy>=2.0.21", - # "shortuuid>=1.0.11", - # "geatpy>=2.7.0", - # "docker>=6.1.3", - # "rapidfuzz>=3.4.0", - # "torchtext>=0.16.0", - # "sentence_transformers>=2.2.2", - # "torch-optimizer>=0.3.0", - # "langdetect>=1.0.9", - # "huggingface-hub<0.18", - # "portalocker>=2.0.0", + "numpy>=1.20.0", + "pandas>=0.25.1", + "scipy>=1.0.0", + "cvxopt>=1.3.0", + "tqdm>=4.65.0", + "scikit-learn>=0.22", + "joblib>=1.2.0", + "pyyaml>=6.0", + "fire>=0.3.1", + "psutil>=5.9.4", + "sqlalchemy>=2.0.21", + "shortuuid>=1.0.11", + "docker>=6.1.3", + "rapidfuzz>=3.4.0", + "langdetect>=1.0.9", + "huggingface-hub<0.18", + "portalocker>=2.0.0", ] if get_platform() != MACOS: @@ -99,6 +91,23 @@ if __name__ == "__main__": long_description_content_type="text/markdown", python_requires=REQUIRES_PYTHON, install_requires=REQUIRED, + extras_require={ + "dev": [ + # For documentations + "sphinx", + "sphinx_rtd_theme", + # CI dependencies + "pytest>=3", + "wheel", + "setuptools", + "pylint", + # For static analysis + "mypy<0.981", + "flake8", + "black==23.1.0", + "pre-commit", + ], + }, classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", @@ -108,8 +117,10 @@ if __name__ == "__main__": "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", "Operating System :: MacOS", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], ) diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py 
index 492c8a9..fac8348 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -30,7 +30,7 @@ user_semantic = { } -class TestMarket(unittest.TestCase): +class TestWorkflow(unittest.TestCase): @classmethod def setUpClass(cls) -> None: np.random.seed(2023) @@ -226,11 +226,11 @@ class TestMarket(unittest.TestCase): def suite(): _suite = unittest.TestSuite() - _suite.addTest(TestMarket("test_prepare_learnware_randomly")) - _suite.addTest(TestMarket("test_upload_delete_learnware")) - _suite.addTest(TestMarket("test_search_semantics")) - _suite.addTest(TestMarket("test_stat_search")) - _suite.addTest(TestMarket("test_learnware_reuse")) + _suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) + _suite.addTest(TestWorkflow("test_upload_delete_learnware")) + _suite.addTest(TestWorkflow("test_search_semantics")) + _suite.addTest(TestWorkflow("test_stat_search")) + _suite.addTest(TestWorkflow("test_learnware_reuse")) return _suite