| @@ -0,0 +1,12 @@ | |||||
| repos: | |||||
| - repo: https://github.com/psf/black | |||||
| rev: 23.1.0 | |||||
| hooks: | |||||
| - id: black | |||||
| args: ["-l 120"] | |||||
| - repo: https://github.com/PyCQA/flake8 | |||||
| rev: 4.0.1 | |||||
| hooks: | |||||
| - id: flake8 | |||||
| args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"] | |||||
| @@ -1,4 +1,3 @@ | |||||
| import torch | |||||
| from torchtext.datasets import SST2 | from torchtext.datasets import SST2 | ||||
| @@ -0,0 +1 @@ | |||||
| torchtext>=0.14.1 | |||||
| @@ -1,7 +1,10 @@ | |||||
| __version__ = "0.1.1.99" | |||||
| __version__ = "0.1.2.99" | |||||
| import os | import os | ||||
| from .logger import get_module_logger | from .logger import get_module_logger | ||||
| from .utils import is_torch_avaliable | |||||
| logger = get_module_logger("Initialization") | |||||
| def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | ||||
| @@ -10,9 +13,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | |||||
| C.reset() | C.reset() | ||||
| C.update(**kwargs) | C.update(**kwargs) | ||||
| logger = get_module_logger("Initialization") | |||||
| logger.info(f"init learnware market with {kwargs}") | logger.info(f"init learnware market with {kwargs}") | ||||
| ## make dirs | ## make dirs | ||||
| if make_dir: | if make_dir: | ||||
| os.makedirs(C.root_path, exist_ok=True) | os.makedirs(C.root_path, exist_ok=True) | ||||
| @@ -25,3 +26,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | |||||
| ## ignore tensorflow warning | ## ignore tensorflow warning | ||||
| # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel | # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel | ||||
| # logger.info(f"The tensorflow log level is setted to {tf_loglevel}") | # logger.info(f"The tensorflow log level is setted to {tf_loglevel}") | ||||
| if not is_torch_avaliable(verbose=False): | |||||
| logger.warning("The functionality of learnware is limited due to 'torch' is not installed!") | |||||
| @@ -18,7 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker | |||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| from ..specification import Specification | from ..specification import Specification | ||||
| from ..learnware import get_learnware_from_dirpath | from ..learnware import get_learnware_from_dirpath | ||||
| from ..test import get_semantic_specification | |||||
| from ..tests import get_semantic_specification | |||||
| CHUNK_SIZE = 1024 * 1024 | CHUNK_SIZE = 1024 * 1024 | ||||
| logger = get_module_logger(module_name="LearnwareClient") | logger = get_module_logger(module_name="LearnwareClient") | ||||
| @@ -1,4 +1,4 @@ | |||||
| from .anchor import AnchoredUserInfo, AnchoredOrganizer | |||||
| from .anchor import AnchoredUserInfo, AnchoredSearcher, AnchoredOrganizer | |||||
| from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | ||||
| from .evolve_anchor import EvolvedAnchoredOrganizer | from .evolve_anchor import EvolvedAnchoredOrganizer | ||||
| from .evolve import EvolvedOrganizer | from .evolve import EvolvedOrganizer | ||||
| @@ -1,2 +1,13 @@ | |||||
| from .organizer import AnchoredOrganizer | from .organizer import AnchoredOrganizer | ||||
| from .searcher import AnchoredUserInfo | |||||
| from .user_info import AnchoredUserInfo | |||||
| from ...utils import is_torch_avaliable | |||||
| from ...logger import get_module_logger | |||||
| logger = get_module_logger("market_anchor") | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| AnchoredSearcher = None | |||||
| logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!") | |||||
| else: | |||||
| from .searcher import AnchoredSearcher | |||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import List, Dict, Tuple, Any, Union | from typing import List, Dict, Tuple, Any, Union | ||||
| from .user_info import AnchoredUserInfo | |||||
| from ..base import BaseUserInfo | from ..base import BaseUserInfo | ||||
| from ..easy.searcher import EasySearcher | from ..easy.searcher import EasySearcher | ||||
| from ...logger import get_module_logger | from ...logger import get_module_logger | ||||
| @@ -8,45 +9,6 @@ from ...learnware import Learnware | |||||
| logger = get_module_logger("anchor_searcher") | logger = get_module_logger("anchor_searcher") | ||||
| class AnchoredUserInfo(BaseUserInfo): | |||||
| """ | |||||
| User Information for searching learnware (add the anchor design) | |||||
| - UserInfo contains the anchor id list acquired from the market | |||||
| - UserInfo can update stat_info based on anchors | |||||
| """ | |||||
| def __init__( | |||||
| self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None | |||||
| ): | |||||
| super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) | |||||
| self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids | |||||
| def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): | |||||
| """Add the anchor learnware ids acquired from the market | |||||
| Parameters | |||||
| ---------- | |||||
| learnware_ids : Union[str, List[str]] | |||||
| Anchor learnware ids | |||||
| """ | |||||
| if isinstance(learnware_ids, str): | |||||
| learnware_ids = [learnware_ids] | |||||
| self.anchor_learnware_ids += learnware_ids | |||||
| def update_stat_info(self, name: str, item: Any): | |||||
| """Update stat_info based on anchor learnwares | |||||
| Parameters | |||||
| ---------- | |||||
| name : str | |||||
| Name of stat_info | |||||
| item : Any | |||||
| Statistical information calculated on anchor learnwares | |||||
| """ | |||||
| self.stat_info[name] = item | |||||
| class AnchoredSearcher(EasySearcher): | class AnchoredSearcher(EasySearcher): | ||||
| def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: | def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: | ||||
| """Search anchor Learnwares from anchor_learnware_list based on user_info | """Search anchor Learnwares from anchor_learnware_list based on user_info | ||||
| @@ -0,0 +1,41 @@ | |||||
| from typing import List, Any, Union | |||||
| from ..base import BaseUserInfo | |||||
| class AnchoredUserInfo(BaseUserInfo): | |||||
| """ | |||||
| User Information for searching learnware (add the anchor design) | |||||
| - UserInfo contains the anchor id list acquired from the market | |||||
| - UserInfo can update stat_info based on anchors | |||||
| """ | |||||
| def __init__( | |||||
| self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None | |||||
| ): | |||||
| super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) | |||||
| self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids | |||||
| def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): | |||||
| """Add the anchor learnware ids acquired from the market | |||||
| Parameters | |||||
| ---------- | |||||
| learnware_ids : Union[str, List[str]] | |||||
| Anchor learnware ids | |||||
| """ | |||||
| if isinstance(learnware_ids, str): | |||||
| learnware_ids = [learnware_ids] | |||||
| self.anchor_learnware_ids += learnware_ids | |||||
| def update_stat_info(self, name: str, item: Any): | |||||
| """Update stat_info based on anchor learnwares | |||||
| Parameters | |||||
| ---------- | |||||
| name : str | |||||
| Name of stat_info | |||||
| item : Any | |||||
| Statistical information calculated on anchor learnwares | |||||
| """ | |||||
| self.stat_info[name] = item | |||||
| @@ -227,7 +227,7 @@ class LearnwareMarket: | |||||
| def reload_learnware(self, learnware_id: str): | def reload_learnware(self, learnware_id: str): | ||||
| self.learnware_organizer.reload_learnware(learnware_id) | self.learnware_organizer.reload_learnware(learnware_id) | ||||
| def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: | def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: | ||||
| return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) | return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) | ||||
| @@ -1,3 +1,15 @@ | |||||
| from .organizer import EasyOrganizer | from .organizer import EasyOrganizer | ||||
| from .searcher import EasySearcher | |||||
| from .checker import EasySemanticChecker, EasyStatChecker | |||||
| from ...utils import is_torch_avaliable | |||||
| from ...logger import get_module_logger | |||||
| logger = get_module_logger("market_easy") | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| EasySearcher = None | |||||
| EasySemanticChecker = None | |||||
| EasyStatChecker = None | |||||
| logger.warning("EasySeacher and EasyChecker are skipped because 'torch' is not installed!") | |||||
| else: | |||||
| from .searcher import EasySearcher | |||||
| from .checker import EasySemanticChecker, EasyStatChecker | |||||
| @@ -166,7 +166,7 @@ class DatabaseOperations(object): | |||||
| return int(row[0]) | return int(row[0]) | ||||
| pass | pass | ||||
| pass | pass | ||||
| def load_market(self): | def load_market(self): | ||||
| with self.engine.connect() as conn: | with self.engine.connect() as conn: | ||||
| cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | ||||
| @@ -387,7 +387,8 @@ class EasyOrganizer(BaseOrganizer): | |||||
| self.learnware_folder_list[learnware_id] = target_folder_dir | self.learnware_folder_list[learnware_id] = target_folder_dir | ||||
| semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) | semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) | ||||
| self.learnware_list[learnware_id] = get_learnware_from_dirpath( | self.learnware_list[learnware_id] = get_learnware_from_dirpath( | ||||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir) | |||||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir | |||||
| ) | |||||
| self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) | self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) | ||||
| pass | pass | ||||
| @@ -2,25 +2,29 @@ from .base import LearnwareMarket | |||||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | ||||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | ||||
| MARKET_CONFIG = { | |||||
| "easy": { | |||||
| "organizer": EasyOrganizer(), | |||||
| "searcher": EasySearcher(), | |||||
| "checker_list": [EasySemanticChecker(), EasyStatChecker()], | |||||
| }, | |||||
| "hetero": { | |||||
| def get_market_config(): | |||||
| market_config = { | |||||
| "easy": { | |||||
| "organizer": EasyOrganizer(), | |||||
| "searcher": EasySearcher(), | |||||
| "checker_list": [EasySemanticChecker(), EasyStatChecker()], | |||||
| }, | |||||
| "hetero": { | |||||
| "organizer": HeteroMapTableOrganizer(), | "organizer": HeteroMapTableOrganizer(), | ||||
| "searcher": HeteroSearcher(), | "searcher": HeteroSearcher(), | ||||
| "checker_list": [] | "checker_list": [] | ||||
| } | |||||
| } | } | ||||
| } | |||||
| return market_config | |||||
| def instantiate_learnware_market(market_id="default", name="easy", **kwargs): | def instantiate_learnware_market(market_id="default", name="easy", **kwargs): | ||||
| market_config = get_market_config() | |||||
| return LearnwareMarket( | return LearnwareMarket( | ||||
| market_id=market_id, | market_id=market_id, | ||||
| organizer=MARKET_CONFIG[name]["organizer"], | |||||
| searcher=MARKET_CONFIG[name]["searcher"], | |||||
| checker_list=MARKET_CONFIG[name]["checker_list"], | |||||
| organizer=market_config[name]["organizer"], | |||||
| searcher=market_config[name]["searcher"], | |||||
| checker_list=market_config[name]["checker_list"], | |||||
| **kwargs | **kwargs | ||||
| ) | ) | ||||
| @@ -1,5 +1,4 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| from typing import Union | from typing import Union | ||||
| @@ -19,7 +18,7 @@ class BaseModel: | |||||
| self.input_shape = input_shape | self.input_shape = input_shape | ||||
| self.output_shape = output_shape | self.output_shape = output_shape | ||||
| def predict(self, X: Union[np.ndarray, torch.tensor]) -> Union[np.ndarray, torch.tensor]: | |||||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||||
| """The prediction method for model in learnware, which will be checked when learnware is submitted into the market. | """The prediction method for model in learnware, which will be checked when learnware is submitted into the market. | ||||
| Parameters | Parameters | ||||
| @@ -33,10 +32,10 @@ class BaseModel: | |||||
| """ | """ | ||||
| pass | pass | ||||
| def fit(self, X: Union[np.ndarray, torch.tensor], y: Union[np.ndarray, torch.tensor]): | |||||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||||
| pass | pass | ||||
| def finetune(self, X: Union[np.ndarray, torch.tensor], y: np.ndarray): | |||||
| def finetune(self, X: np.ndarray, y: np.ndarray): | |||||
| """The finetune method for continuing train the model searched by market | """The finetune method for continuing train the model searched by market | ||||
| Parameters | Parameters | ||||
| @@ -1,5 +1,42 @@ | |||||
| from .ensemble_pruning import EnsemblePruningReuser | |||||
| from .averaging import AveragingReuser | |||||
| from .job_selector import JobSelectorReuser | |||||
| from ..logger import get_module_logger | |||||
| from ..utils import is_torch_avaliable | |||||
| from .utils import is_geatpy_avaliable, is_lightgbm_avaliable | |||||
| logger = get_module_logger("reuse") | |||||
| if not is_geatpy_avaliable(verbose=False): | |||||
| EnsemblePruningReuser = None | |||||
| logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!") | |||||
| else: | |||||
| from .ensemble_pruning import EnsemblePruningReuser | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| AveragingReuser = None | |||||
| logger.warning("AveragingReuser is skipped due to 'torch' is not installed!") | |||||
| else: | |||||
| from .averaging import AveragingReuser | |||||
| if not is_lightgbm_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): | |||||
| JobSelectorReuser = None | |||||
| uninstall_packages = [ | |||||
| value | |||||
| for flag, value in zip( | |||||
| [ | |||||
| is_lightgbm_avaliable(verbose=False), | |||||
| is_torch_avaliable(verbose=False), | |||||
| ], | |||||
| ["lightgbm", "torch"], | |||||
| ) | |||||
| if flag is False | |||||
| ] | |||||
| logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!") | |||||
| else: | |||||
| from .job_selector import JobSelectorReuser | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| HeteroMapTableReuser = None | |||||
| logger.warning("FeatureAugmentReuser is skipped due to 'torch' is not installed!") | |||||
| else: | |||||
| from .hetero_reuser import HeteroMapTableReuser | |||||
| from .feature_augment_reuser import FeatureAugmentReuser | from .feature_augment_reuser import FeatureAugmentReuser | ||||
| from .hetero_reuser import HeteroMapTableReuser | |||||
| @@ -2,7 +2,6 @@ import torch | |||||
| import random | import random | ||||
| import numpy as np | import numpy as np | ||||
| import geatpy as ea | import geatpy as ea | ||||
| from typing import List | from typing import List | ||||
| from learnware.learnware import Learnware | from learnware.learnware import Learnware | ||||
| @@ -0,0 +1,25 @@ | |||||
| from ..logger import get_module_logger | |||||
| logger = get_module_logger("reuse_utils") | |||||
| def is_geatpy_avaliable(verbose=False): | |||||
| try: | |||||
| import geatpy | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning( | |||||
| "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!" | |||||
| ) | |||||
| return False | |||||
| return True | |||||
| def is_lightgbm_avaliable(verbose=False): | |||||
| try: | |||||
| import lightgbm | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!") | |||||
| return False | |||||
| return True | |||||
| @@ -1,4 +1,3 @@ | |||||
| from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec | |||||
| from .base import Specification, BaseStatSpecification | from .base import Specification, BaseStatSpecification | ||||
| from .regular import ( | from .regular import ( | ||||
| RegularStatsSpecification, | RegularStatsSpecification, | ||||
| @@ -7,4 +6,15 @@ from .regular import ( | |||||
| RKMEImageSpecification, | RKMEImageSpecification, | ||||
| RKMETextSpecification, | RKMETextSpecification, | ||||
| ) | ) | ||||
| from .system import HeteroSpecification | from .system import HeteroSpecification | ||||
| from ..utils import is_torch_avaliable | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| generate_stat_spec = None | |||||
| generate_rkme_spec = None | |||||
| generate_rkme_image_spec = None | |||||
| generate_rkme_text_spec = None | |||||
| else: | |||||
| from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec | |||||
| @@ -1,4 +1,6 @@ | |||||
| from .base import RegularStatsSpecification | |||||
| from ...utils import is_torch_avaliable | |||||
| from .text import RKMETextSpecification | from .text import RKMETextSpecification | ||||
| from .table import RKMETableSpecification, RKMEStatSpecification | from .table import RKMETableSpecification, RKMEStatSpecification | ||||
| from .image import RKMEImageSpecification | from .image import RKMEImageSpecification | ||||
| from .base import RegularStatsSpecification | |||||
| @@ -1 +1,29 @@ | |||||
| from .rkme import RKMEImageSpecification | |||||
| from .utils import is_torch_optimizer_avaliable, is_torch_vision_avaliable | |||||
| from ....utils import is_torch_avaliable | |||||
| from ....logger import get_module_logger | |||||
| logger = get_module_logger("regular_image_spec") | |||||
| if ( | |||||
| not is_torch_vision_avaliable(verbose=False) | |||||
| or not is_torch_optimizer_avaliable(verbose=False) | |||||
| or not is_torch_avaliable(verbose=False) | |||||
| ): | |||||
| RKMEImageSpecification = None | |||||
| uninstall_packages = [ | |||||
| value | |||||
| for flag, value in zip( | |||||
| [ | |||||
| is_torch_vision_avaliable(verbose=False), | |||||
| is_torch_optimizer_avaliable(verbose=False), | |||||
| is_torch_avaliable(verbose=False), | |||||
| ], | |||||
| ["torchvision", "torch-optimizer", "torch"], | |||||
| ) | |||||
| if flag is False | |||||
| ] | |||||
| logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!") | |||||
| else: | |||||
| from .rkme import RKMEImageSpecification | |||||
| @@ -0,0 +1,23 @@ | |||||
| from ....logger import get_module_logger | |||||
| logger = get_module_logger("regular_image_spec_utils") | |||||
| def is_torch_optimizer_avaliable(verbose=False): | |||||
| try: | |||||
| import torch_optimizer | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!") | |||||
| return False | |||||
| return True | |||||
| def is_torch_vision_avaliable(verbose=False): | |||||
| try: | |||||
| import torchvision | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!") | |||||
| return False | |||||
| return True | |||||
| @@ -1 +1,11 @@ | |||||
| from .rkme import RKMETableSpecification, RKMEStatSpecification | |||||
| from ....utils import is_torch_avaliable | |||||
| from ....logger import get_module_logger | |||||
| logger = get_module_logger("regular_table_spec") | |||||
| if not is_torch_avaliable(verbose=False): | |||||
| RKMETableSpecification = None | |||||
| RKMEStatSpecification = None | |||||
| logger.warning("RKMETableSpecification is skipped because torch is not installed!") | |||||
| else: | |||||
| from .rkme import RKMETableSpecification, RKMEStatSpecification | |||||
| @@ -1 +1,23 @@ | |||||
| from .rkme import RKMETextSpecification | |||||
| from .utils import is_sentence_transformers_avaliable | |||||
| from ....utils import is_torch_avaliable | |||||
| from ....logger import get_module_logger | |||||
| logger = get_module_logger("regular_text_spec") | |||||
| if not is_sentence_transformers_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): | |||||
| RKMETextSpecification = None | |||||
| uninstall_packages = [ | |||||
| value | |||||
| for flag, value in zip( | |||||
| [ | |||||
| is_sentence_transformers_avaliable(verbose=False), | |||||
| is_torch_avaliable(verbose=False), | |||||
| ], | |||||
| ["sentence_transformers", "torch"], | |||||
| ) | |||||
| if flag is False | |||||
| ] | |||||
| logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!") | |||||
| else: | |||||
| from .rkme import RKMETextSpecification | |||||
| @@ -0,0 +1,15 @@ | |||||
| from ....logger import get_module_logger | |||||
| logger = get_module_logger("regular_text_spec_utils") | |||||
| def is_sentence_transformers_avaliable(verbose=False): | |||||
| try: | |||||
| import sentence_transformers | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning( | |||||
| "ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!" | |||||
| ) | |||||
| return False | |||||
| return True | |||||
| @@ -1,58 +0,0 @@ | |||||
| import os | |||||
| import sys | |||||
| import re | |||||
| import yaml | |||||
| import importlib | |||||
| import importlib.util | |||||
| from typing import Union | |||||
| from types import ModuleType | |||||
| import zipfile | |||||
| from .logger import get_module_logger | |||||
| logger = get_module_logger("utils") | |||||
| def get_module_by_module_path(module_path: Union[str, ModuleType]): | |||||
| if module_path is None: | |||||
| raise ModuleNotFoundError("None is passed in as parameters as module_path") | |||||
| if isinstance(module_path, ModuleType): | |||||
| module = module_path | |||||
| else: | |||||
| if module_path.endswith(".py"): | |||||
| module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) | |||||
| module_spec = importlib.util.spec_from_file_location(module_name, module_path) | |||||
| module = importlib.util.module_from_spec(module_spec) | |||||
| sys.modules[module_name] = module | |||||
| module_spec.loader.exec_module(module) | |||||
| else: | |||||
| module = importlib.import_module(module_path) | |||||
| return module | |||||
| def save_dict_to_yaml(dict_value: dict, save_path: str): | |||||
| """save dict object into yaml file""" | |||||
| with open(save_path, "w") as file: | |||||
| file.write(yaml.dump(dict_value, allow_unicode=True)) | |||||
| def read_yaml_to_dict(yaml_path: str): | |||||
| """load yaml file into dict object""" | |||||
| with open(yaml_path, "r") as file: | |||||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||||
| return dict_value | |||||
| def zip_learnware_folder(path: str, output_name: str): | |||||
| with zipfile.ZipFile(output_name, "w") as zip_ref: | |||||
| for root, dirs, files in os.walk(path): | |||||
| for file in files: | |||||
| full_path = os.path.join(root, file) | |||||
| if file.endswith(".pyc") or os.path.islink(full_path): | |||||
| continue | |||||
| zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) | |||||
| pass | |||||
| pass | |||||
| pass | |||||
| pass | |||||
| @@ -0,0 +1,16 @@ | |||||
| import os | |||||
| import zipfile | |||||
| from .import_utils import is_torch_avaliable | |||||
| from .module import get_module_by_module_path | |||||
| from .file import read_yaml_to_dict, save_dict_to_yaml | |||||
| def zip_learnware_folder(path: str, output_name: str): | |||||
| with zipfile.ZipFile(output_name, "w") as zip_ref: | |||||
| for root, dirs, files in os.walk(path): | |||||
| for file in files: | |||||
| full_path = os.path.join(root, file) | |||||
| if file.endswith(".pyc") or os.path.islink(full_path): | |||||
| continue | |||||
| zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) | |||||
| @@ -0,0 +1,14 @@ | |||||
| import yaml | |||||
| def save_dict_to_yaml(dict_value: dict, save_path: str): | |||||
| """save dict object into yaml file""" | |||||
| with open(save_path, "w") as file: | |||||
| file.write(yaml.dump(dict_value, allow_unicode=True)) | |||||
| def read_yaml_to_dict(yaml_path: str): | |||||
| """load yaml file into dict object""" | |||||
| with open(yaml_path, "r") as file: | |||||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||||
| return dict_value | |||||
| @@ -0,0 +1,13 @@ | |||||
| from ..logger import get_module_logger | |||||
| logger = get_module_logger("import_utils") | |||||
| def is_torch_avaliable(verbose=False): | |||||
| try: | |||||
| import torch | |||||
| except ModuleNotFoundError as err: | |||||
| if verbose is True: | |||||
| logger.warning("ModuleNotFoundError: torch is not installed, please install pytorch!") | |||||
| return False | |||||
| return True | |||||
| @@ -0,0 +1,24 @@ | |||||
| import sys | |||||
| import re | |||||
| import importlib | |||||
| import importlib.util | |||||
| from typing import Union | |||||
| from types import ModuleType | |||||
| def get_module_by_module_path(module_path: Union[str, ModuleType]): | |||||
| if module_path is None: | |||||
| raise ModuleNotFoundError("None is passed in as parameters as module_path") | |||||
| if isinstance(module_path, ModuleType): | |||||
| module = module_path | |||||
| else: | |||||
| if module_path.endswith(".py"): | |||||
| module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) | |||||
| module_spec = importlib.util.spec_from_file_location(module_name, module_path) | |||||
| module = importlib.util.module_from_spec(module_spec) | |||||
| sys.modules[module_name] = module | |||||
| module_spec.loader.exec_module(module) | |||||
| else: | |||||
| module = importlib.import_module(module_path) | |||||
| return module | |||||
| @@ -51,31 +51,23 @@ def get_platform(): | |||||
| # What packages are required for this module to be executed? | # What packages are required for this module to be executed? | ||||
| # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. | # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. | ||||
| REQUIRED = [ | REQUIRED = [ | ||||
| # "numpy>=1.20.0", | |||||
| # "pandas>=0.25.1", | |||||
| # "scipy>=1.0.0", | |||||
| # "matplotlib>=3.1.3", | |||||
| # "torch>=1.11.0", | |||||
| # "cvxopt>=1.3.0", | |||||
| # "tqdm>=4.65.0", | |||||
| # "scikit-learn>=0.22", | |||||
| # "joblib>=1.2.0", | |||||
| # "pyyaml>=6.0", | |||||
| # "fire>=0.3.1", | |||||
| # "lightgbm>=3.3.0", | |||||
| # "psutil>=5.9.4", | |||||
| # "torchvision>=0.15.1", | |||||
| # "sqlalchemy>=2.0.21", | |||||
| # "shortuuid>=1.0.11", | |||||
| # "geatpy>=2.7.0", | |||||
| # "docker>=6.1.3", | |||||
| # "rapidfuzz>=3.4.0", | |||||
| # "torchtext>=0.16.0", | |||||
| # "sentence_transformers>=2.2.2", | |||||
| # "torch-optimizer>=0.3.0", | |||||
| # "langdetect>=1.0.9", | |||||
| # "huggingface-hub<0.18", | |||||
| # "portalocker>=2.0.0", | |||||
| "numpy>=1.20.0", | |||||
| "pandas>=0.25.1", | |||||
| "scipy>=1.0.0", | |||||
| "cvxopt>=1.3.0", | |||||
| "tqdm>=4.65.0", | |||||
| "scikit-learn>=0.22", | |||||
| "joblib>=1.2.0", | |||||
| "pyyaml>=6.0", | |||||
| "fire>=0.3.1", | |||||
| "psutil>=5.9.4", | |||||
| "sqlalchemy>=2.0.21", | |||||
| "shortuuid>=1.0.11", | |||||
| "docker>=6.1.3", | |||||
| "rapidfuzz>=3.4.0", | |||||
| "langdetect>=1.0.9", | |||||
| "huggingface-hub<0.18", | |||||
| "portalocker>=2.0.0", | |||||
| ] | ] | ||||
| if get_platform() != MACOS: | if get_platform() != MACOS: | ||||
| @@ -99,6 +91,23 @@ if __name__ == "__main__": | |||||
| long_description_content_type="text/markdown", | long_description_content_type="text/markdown", | ||||
| python_requires=REQUIRES_PYTHON, | python_requires=REQUIRES_PYTHON, | ||||
| install_requires=REQUIRED, | install_requires=REQUIRED, | ||||
| extras_require={ | |||||
| "dev": [ | |||||
| # For documentations | |||||
| "sphinx", | |||||
| "sphinx_rtd_theme", | |||||
| # CI dependencies | |||||
| "pytest>=3", | |||||
| "wheel", | |||||
| "setuptools", | |||||
| "pylint", | |||||
| # For static analysis | |||||
| "mypy<0.981", | |||||
| "flake8", | |||||
| "black==23.1.0", | |||||
| "pre-commit", | |||||
| ], | |||||
| }, | |||||
| classifiers=[ | classifiers=[ | ||||
| "Intended Audience :: Science/Research", | "Intended Audience :: Science/Research", | ||||
| "Intended Audience :: Developers", | "Intended Audience :: Developers", | ||||
| @@ -108,8 +117,10 @@ if __name__ == "__main__": | |||||
| "Operating System :: POSIX :: Linux", | "Operating System :: POSIX :: Linux", | ||||
| "Operating System :: Microsoft :: Windows", | "Operating System :: Microsoft :: Windows", | ||||
| "Operating System :: MacOS", | "Operating System :: MacOS", | ||||
| "Programming Language :: Python :: 3.6", | |||||
| "Programming Language :: Python :: 3.7", | "Programming Language :: Python :: 3.7", | ||||
| "Programming Language :: Python :: 3.8", | "Programming Language :: Python :: 3.8", | ||||
| "Programming Language :: Python :: 3.9", | |||||
| "Programming Language :: Python :: 3.10", | |||||
| "Programming Language :: Python :: 3.11", | |||||
| ], | ], | ||||
| ) | ) | ||||
| @@ -30,7 +30,7 @@ user_semantic = { | |||||
| } | } | ||||
| class TestMarket(unittest.TestCase): | |||||
| class TestWorkflow(unittest.TestCase): | |||||
| @classmethod | @classmethod | ||||
| def setUpClass(cls) -> None: | def setUpClass(cls) -> None: | ||||
| np.random.seed(2023) | np.random.seed(2023) | ||||
| @@ -226,11 +226,11 @@ class TestMarket(unittest.TestCase): | |||||
| def suite(): | def suite(): | ||||
| _suite = unittest.TestSuite() | _suite = unittest.TestSuite() | ||||
| _suite.addTest(TestMarket("test_prepare_learnware_randomly")) | |||||
| _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||||
| _suite.addTest(TestMarket("test_search_semantics")) | |||||
| _suite.addTest(TestMarket("test_stat_search")) | |||||
| _suite.addTest(TestMarket("test_learnware_reuse")) | |||||
| _suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) | |||||
| _suite.addTest(TestWorkflow("test_upload_delete_learnware")) | |||||
| _suite.addTest(TestWorkflow("test_search_semantics")) | |||||
| _suite.addTest(TestWorkflow("test_stat_search")) | |||||
| _suite.addTest(TestWorkflow("test_learnware_reuse")) | |||||
| return _suite | return _suite | ||||