| @@ -0,0 +1,12 @@ | |||
| repos: | |||
| - repo: https://github.com/psf/black | |||
| rev: 23.1.0 | |||
| hooks: | |||
| - id: black | |||
| args: ["-l 120"] | |||
| - repo: https://github.com/PyCQA/flake8 | |||
| rev: 4.0.1 | |||
| hooks: | |||
| - id: flake8 | |||
| args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"] | |||
| @@ -1,4 +1,3 @@ | |||
| import torch | |||
| from torchtext.datasets import SST2 | |||
| @@ -0,0 +1 @@ | |||
| torchtext>=0.14.1 | |||
| @@ -1,7 +1,10 @@ | |||
| __version__ = "0.1.1.99" | |||
| __version__ = "0.1.2.99" | |||
| import os | |||
| from .logger import get_module_logger | |||
| from .utils import is_torch_avaliable | |||
| logger = get_module_logger("Initialization") | |||
| def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | |||
| @@ -10,9 +13,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | |||
| C.reset() | |||
| C.update(**kwargs) | |||
| logger = get_module_logger("Initialization") | |||
| logger.info(f"init learnware market with {kwargs}") | |||
| ## make dirs | |||
| if make_dir: | |||
| os.makedirs(C.root_path, exist_ok=True) | |||
| @@ -25,3 +26,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): | |||
| ## ignore tensorflow warning | |||
| # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel | |||
| # logger.info(f"The tensorflow log level is setted to {tf_loglevel}") | |||
| if not is_torch_avaliable(verbose=False): | |||
| logger.warning("The functionality of learnware is limited due to 'torch' is not installed!") | |||
| @@ -18,7 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker | |||
| from ..logger import get_module_logger | |||
| from ..specification import Specification | |||
| from ..learnware import get_learnware_from_dirpath | |||
| from ..test import get_semantic_specification | |||
| from ..tests import get_semantic_specification | |||
| CHUNK_SIZE = 1024 * 1024 | |||
| logger = get_module_logger(module_name="LearnwareClient") | |||
| @@ -1,4 +1,4 @@ | |||
| from .anchor import AnchoredUserInfo, AnchoredOrganizer | |||
| from .anchor import AnchoredUserInfo, AnchoredSearcher, AnchoredOrganizer | |||
| from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher | |||
| from .evolve_anchor import EvolvedAnchoredOrganizer | |||
| from .evolve import EvolvedOrganizer | |||
| @@ -1,2 +1,13 @@ | |||
| from .organizer import AnchoredOrganizer | |||
| from .searcher import AnchoredUserInfo | |||
| from .user_info import AnchoredUserInfo | |||
| from ...utils import is_torch_avaliable | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("market_anchor") | |||
| if not is_torch_avaliable(verbose=False): | |||
| AnchoredSearcher = None | |||
| logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!") | |||
| else: | |||
| from .searcher import AnchoredSearcher | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import List, Dict, Tuple, Any, Union | |||
| from .user_info import AnchoredUserInfo | |||
| from ..base import BaseUserInfo | |||
| from ..easy.searcher import EasySearcher | |||
| from ...logger import get_module_logger | |||
| @@ -8,45 +9,6 @@ from ...learnware import Learnware | |||
| logger = get_module_logger("anchor_searcher") | |||
| class AnchoredUserInfo(BaseUserInfo): | |||
| """ | |||
| User Information for searching learnware (add the anchor design) | |||
| - UserInfo contains the anchor id list acquired from the market | |||
| - UserInfo can update stat_info based on anchors | |||
| """ | |||
| def __init__( | |||
| self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None | |||
| ): | |||
| super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) | |||
| self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids | |||
| def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): | |||
| """Add the anchor learnware ids acquired from the market | |||
| Parameters | |||
| ---------- | |||
| learnware_ids : Union[str, List[str]] | |||
| Anchor learnware ids | |||
| """ | |||
| if isinstance(learnware_ids, str): | |||
| learnware_ids = [learnware_ids] | |||
| self.anchor_learnware_ids += learnware_ids | |||
| def update_stat_info(self, name: str, item: Any): | |||
| """Update stat_info based on anchor learnwares | |||
| Parameters | |||
| ---------- | |||
| name : str | |||
| Name of stat_info | |||
| item : Any | |||
| Statistical information calculated on anchor learnwares | |||
| """ | |||
| self.stat_info[name] = item | |||
| class AnchoredSearcher(EasySearcher): | |||
| def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: | |||
| """Search anchor Learnwares from anchor_learnware_list based on user_info | |||
| @@ -0,0 +1,41 @@ | |||
| from typing import List, Any, Union | |||
| from ..base import BaseUserInfo | |||
| class AnchoredUserInfo(BaseUserInfo): | |||
| """ | |||
| User Information for searching learnware (add the anchor design) | |||
| - UserInfo contains the anchor id list acquired from the market | |||
| - UserInfo can update stat_info based on anchors | |||
| """ | |||
| def __init__( | |||
| self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None | |||
| ): | |||
| super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info) | |||
| self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids | |||
| def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]): | |||
| """Add the anchor learnware ids acquired from the market | |||
| Parameters | |||
| ---------- | |||
| learnware_ids : Union[str, List[str]] | |||
| Anchor learnware ids | |||
| """ | |||
| if isinstance(learnware_ids, str): | |||
| learnware_ids = [learnware_ids] | |||
| self.anchor_learnware_ids += learnware_ids | |||
| def update_stat_info(self, name: str, item: Any): | |||
| """Update stat_info based on anchor learnwares | |||
| Parameters | |||
| ---------- | |||
| name : str | |||
| Name of stat_info | |||
| item : Any | |||
| Statistical information calculated on anchor learnwares | |||
| """ | |||
| self.stat_info[name] = item | |||
| @@ -227,7 +227,7 @@ class LearnwareMarket: | |||
| def reload_learnware(self, learnware_id: str): | |||
| self.learnware_organizer.reload_learnware(learnware_id) | |||
| def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: | |||
| return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) | |||
| @@ -1,3 +1,15 @@ | |||
| from .organizer import EasyOrganizer | |||
| from .searcher import EasySearcher | |||
| from .checker import EasySemanticChecker, EasyStatChecker | |||
| from ...utils import is_torch_avaliable | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("market_easy") | |||
| if not is_torch_avaliable(verbose=False): | |||
| EasySearcher = None | |||
| EasySemanticChecker = None | |||
| EasyStatChecker = None | |||
| logger.warning("EasySeacher and EasyChecker are skipped because 'torch' is not installed!") | |||
| else: | |||
| from .searcher import EasySearcher | |||
| from .checker import EasySemanticChecker, EasyStatChecker | |||
| @@ -166,7 +166,7 @@ class DatabaseOperations(object): | |||
| return int(row[0]) | |||
| pass | |||
| pass | |||
| def load_market(self): | |||
| with self.engine.connect() as conn: | |||
| cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | |||
| @@ -387,7 +387,8 @@ class EasyOrganizer(BaseOrganizer): | |||
| self.learnware_folder_list[learnware_id] = target_folder_dir | |||
| semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) | |||
| self.learnware_list[learnware_id] = get_learnware_from_dirpath( | |||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir) | |||
| id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir | |||
| ) | |||
| self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) | |||
| pass | |||
| @@ -2,25 +2,29 @@ from .base import LearnwareMarket | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| MARKET_CONFIG = { | |||
| "easy": { | |||
| "organizer": EasyOrganizer(), | |||
| "searcher": EasySearcher(), | |||
| "checker_list": [EasySemanticChecker(), EasyStatChecker()], | |||
| }, | |||
| "hetero": { | |||
| def get_market_config(): | |||
| market_config = { | |||
| "easy": { | |||
| "organizer": EasyOrganizer(), | |||
| "searcher": EasySearcher(), | |||
| "checker_list": [EasySemanticChecker(), EasyStatChecker()], | |||
| }, | |||
| "hetero": { | |||
| "organizer": HeteroMapTableOrganizer(), | |||
| "searcher": HeteroSearcher(), | |||
| "checker_list": [] | |||
| } | |||
| } | |||
| } | |||
| return market_config | |||
| def instantiate_learnware_market(market_id="default", name="easy", **kwargs): | |||
| market_config = get_market_config() | |||
| return LearnwareMarket( | |||
| market_id=market_id, | |||
| organizer=MARKET_CONFIG[name]["organizer"], | |||
| searcher=MARKET_CONFIG[name]["searcher"], | |||
| checker_list=MARKET_CONFIG[name]["checker_list"], | |||
| organizer=market_config[name]["organizer"], | |||
| searcher=market_config[name]["searcher"], | |||
| checker_list=market_config[name]["checker_list"], | |||
| **kwargs | |||
| ) | |||
| @@ -1,5 +1,4 @@ | |||
| import numpy as np | |||
| import torch | |||
| from typing import Union | |||
| @@ -19,7 +18,7 @@ class BaseModel: | |||
| self.input_shape = input_shape | |||
| self.output_shape = output_shape | |||
| def predict(self, X: Union[np.ndarray, torch.tensor]) -> Union[np.ndarray, torch.tensor]: | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| """The prediction method for model in learnware, which will be checked when learnware is submitted into the market. | |||
| Parameters | |||
| @@ -33,10 +32,10 @@ class BaseModel: | |||
| """ | |||
| pass | |||
| def fit(self, X: Union[np.ndarray, torch.tensor], y: Union[np.ndarray, torch.tensor]): | |||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||
| pass | |||
| def finetune(self, X: Union[np.ndarray, torch.tensor], y: np.ndarray): | |||
| def finetune(self, X: np.ndarray, y: np.ndarray): | |||
| """The finetune method for continuing train the model searched by market | |||
| Parameters | |||
| @@ -1,5 +1,42 @@ | |||
| from .ensemble_pruning import EnsemblePruningReuser | |||
| from .averaging import AveragingReuser | |||
| from .job_selector import JobSelectorReuser | |||
| from ..logger import get_module_logger | |||
| from ..utils import is_torch_avaliable | |||
| from .utils import is_geatpy_avaliable, is_lightgbm_avaliable | |||
| logger = get_module_logger("reuse") | |||
| if not is_geatpy_avaliable(verbose=False): | |||
| EnsemblePruningReuser = None | |||
| logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!") | |||
| else: | |||
| from .ensemble_pruning import EnsemblePruningReuser | |||
| if not is_torch_avaliable(verbose=False): | |||
| AveragingReuser = None | |||
| logger.warning("AveragingReuser is skipped due to 'torch' is not installed!") | |||
| else: | |||
| from .averaging import AveragingReuser | |||
| if not is_lightgbm_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): | |||
| JobSelectorReuser = None | |||
| uninstall_packages = [ | |||
| value | |||
| for flag, value in zip( | |||
| [ | |||
| is_lightgbm_avaliable(verbose=False), | |||
| is_torch_avaliable(verbose=False), | |||
| ], | |||
| ["lightgbm", "torch"], | |||
| ) | |||
| if flag is False | |||
| ] | |||
| logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!") | |||
| else: | |||
| from .job_selector import JobSelectorReuser | |||
| if not is_torch_avaliable(verbose=False): | |||
| HeteroMapTableReuser = None | |||
| logger.warning("FeatureAugmentReuser is skipped due to 'torch' is not installed!") | |||
| else: | |||
| from .hetero_reuser import HeteroMapTableReuser | |||
| from .feature_augment_reuser import FeatureAugmentReuser | |||
| from .hetero_reuser import HeteroMapTableReuser | |||
| @@ -2,7 +2,6 @@ import torch | |||
| import random | |||
| import numpy as np | |||
| import geatpy as ea | |||
| from typing import List | |||
| from learnware.learnware import Learnware | |||
| @@ -0,0 +1,25 @@ | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("reuse_utils") | |||
| def is_geatpy_avaliable(verbose=False): | |||
| try: | |||
| import geatpy | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning( | |||
| "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!" | |||
| ) | |||
| return False | |||
| return True | |||
| def is_lightgbm_avaliable(verbose=False): | |||
| try: | |||
| import lightgbm | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!") | |||
| return False | |||
| return True | |||
| @@ -1,4 +1,3 @@ | |||
| from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec | |||
| from .base import Specification, BaseStatSpecification | |||
| from .regular import ( | |||
| RegularStatsSpecification, | |||
| @@ -7,4 +6,15 @@ from .regular import ( | |||
| RKMEImageSpecification, | |||
| RKMETextSpecification, | |||
| ) | |||
| from .system import HeteroSpecification | |||
| from ..utils import is_torch_avaliable | |||
| if not is_torch_avaliable(verbose=False): | |||
| generate_stat_spec = None | |||
| generate_rkme_spec = None | |||
| generate_rkme_image_spec = None | |||
| generate_rkme_text_spec = None | |||
| else: | |||
| from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec | |||
| @@ -1,4 +1,6 @@ | |||
| from .base import RegularStatsSpecification | |||
| from ...utils import is_torch_avaliable | |||
| from .text import RKMETextSpecification | |||
| from .table import RKMETableSpecification, RKMEStatSpecification | |||
| from .image import RKMEImageSpecification | |||
| from .base import RegularStatsSpecification | |||
| @@ -1 +1,29 @@ | |||
| from .rkme import RKMEImageSpecification | |||
| from .utils import is_torch_optimizer_avaliable, is_torch_vision_avaliable | |||
| from ....utils import is_torch_avaliable | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("regular_image_spec") | |||
| if ( | |||
| not is_torch_vision_avaliable(verbose=False) | |||
| or not is_torch_optimizer_avaliable(verbose=False) | |||
| or not is_torch_avaliable(verbose=False) | |||
| ): | |||
| RKMEImageSpecification = None | |||
| uninstall_packages = [ | |||
| value | |||
| for flag, value in zip( | |||
| [ | |||
| is_torch_vision_avaliable(verbose=False), | |||
| is_torch_optimizer_avaliable(verbose=False), | |||
| is_torch_avaliable(verbose=False), | |||
| ], | |||
| ["torchvision", "torch-optimizer", "torch"], | |||
| ) | |||
| if flag is False | |||
| ] | |||
| logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!") | |||
| else: | |||
| from .rkme import RKMEImageSpecification | |||
| @@ -0,0 +1,23 @@ | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("regular_image_spec_utils") | |||
| def is_torch_optimizer_avaliable(verbose=False): | |||
| try: | |||
| import torch_optimizer | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!") | |||
| return False | |||
| return True | |||
| def is_torch_vision_avaliable(verbose=False): | |||
| try: | |||
| import torchvision | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!") | |||
| return False | |||
| return True | |||
| @@ -1 +1,11 @@ | |||
| from .rkme import RKMETableSpecification, RKMEStatSpecification | |||
| from ....utils import is_torch_avaliable | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("regular_table_spec") | |||
| if not is_torch_avaliable(verbose=False): | |||
| RKMETableSpecification = None | |||
| RKMEStatSpecification = None | |||
| logger.warning("RKMETableSpecification is skipped because torch is not installed!") | |||
| else: | |||
| from .rkme import RKMETableSpecification, RKMEStatSpecification | |||
| @@ -1 +1,23 @@ | |||
| from .rkme import RKMETextSpecification | |||
| from .utils import is_sentence_transformers_avaliable | |||
| from ....utils import is_torch_avaliable | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("regular_text_spec") | |||
| if not is_sentence_transformers_avaliable(verbose=False) or not is_torch_avaliable(verbose=False): | |||
| RKMETextSpecification = None | |||
| uninstall_packages = [ | |||
| value | |||
| for flag, value in zip( | |||
| [ | |||
| is_sentence_transformers_avaliable(verbose=False), | |||
| is_torch_avaliable(verbose=False), | |||
| ], | |||
| ["sentence_transformers", "torch"], | |||
| ) | |||
| if flag is False | |||
| ] | |||
| logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!") | |||
| else: | |||
| from .rkme import RKMETextSpecification | |||
| @@ -0,0 +1,15 @@ | |||
| from ....logger import get_module_logger | |||
| logger = get_module_logger("regular_text_spec_utils") | |||
| def is_sentence_transformers_avaliable(verbose=False): | |||
| try: | |||
| import sentence_transformers | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning( | |||
| "ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!" | |||
| ) | |||
| return False | |||
| return True | |||
| @@ -1,58 +0,0 @@ | |||
| import os | |||
| import sys | |||
| import re | |||
| import yaml | |||
| import importlib | |||
| import importlib.util | |||
| from typing import Union | |||
| from types import ModuleType | |||
| import zipfile | |||
| from .logger import get_module_logger | |||
| logger = get_module_logger("utils") | |||
| def get_module_by_module_path(module_path: Union[str, ModuleType]): | |||
| if module_path is None: | |||
| raise ModuleNotFoundError("None is passed in as parameters as module_path") | |||
| if isinstance(module_path, ModuleType): | |||
| module = module_path | |||
| else: | |||
| if module_path.endswith(".py"): | |||
| module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) | |||
| module_spec = importlib.util.spec_from_file_location(module_name, module_path) | |||
| module = importlib.util.module_from_spec(module_spec) | |||
| sys.modules[module_name] = module | |||
| module_spec.loader.exec_module(module) | |||
| else: | |||
| module = importlib.import_module(module_path) | |||
| return module | |||
| def save_dict_to_yaml(dict_value: dict, save_path: str): | |||
| """save dict object into yaml file""" | |||
| with open(save_path, "w") as file: | |||
| file.write(yaml.dump(dict_value, allow_unicode=True)) | |||
| def read_yaml_to_dict(yaml_path: str): | |||
| """load yaml file into dict object""" | |||
| with open(yaml_path, "r") as file: | |||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||
| return dict_value | |||
| def zip_learnware_folder(path: str, output_name: str): | |||
| with zipfile.ZipFile(output_name, "w") as zip_ref: | |||
| for root, dirs, files in os.walk(path): | |||
| for file in files: | |||
| full_path = os.path.join(root, file) | |||
| if file.endswith(".pyc") or os.path.islink(full_path): | |||
| continue | |||
| zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) | |||
| pass | |||
| pass | |||
| pass | |||
| pass | |||
| @@ -0,0 +1,16 @@ | |||
| import os | |||
| import zipfile | |||
| from .import_utils import is_torch_avaliable | |||
| from .module import get_module_by_module_path | |||
| from .file import read_yaml_to_dict, save_dict_to_yaml | |||
| def zip_learnware_folder(path: str, output_name: str): | |||
| with zipfile.ZipFile(output_name, "w") as zip_ref: | |||
| for root, dirs, files in os.walk(path): | |||
| for file in files: | |||
| full_path = os.path.join(root, file) | |||
| if file.endswith(".pyc") or os.path.islink(full_path): | |||
| continue | |||
| zip_ref.write(full_path, arcname=os.path.relpath(full_path, path)) | |||
| @@ -0,0 +1,14 @@ | |||
| import yaml | |||
| def save_dict_to_yaml(dict_value: dict, save_path: str): | |||
| """save dict object into yaml file""" | |||
| with open(save_path, "w") as file: | |||
| file.write(yaml.dump(dict_value, allow_unicode=True)) | |||
| def read_yaml_to_dict(yaml_path: str): | |||
| """load yaml file into dict object""" | |||
| with open(yaml_path, "r") as file: | |||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||
| return dict_value | |||
| @@ -0,0 +1,13 @@ | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("import_utils") | |||
| def is_torch_avaliable(verbose=False): | |||
| try: | |||
| import torch | |||
| except ModuleNotFoundError as err: | |||
| if verbose is True: | |||
| logger.warning("ModuleNotFoundError: torch is not installed, please install pytorch!") | |||
| return False | |||
| return True | |||
| @@ -0,0 +1,24 @@ | |||
| import sys | |||
| import re | |||
| import importlib | |||
| import importlib.util | |||
| from typing import Union | |||
| from types import ModuleType | |||
| def get_module_by_module_path(module_path: Union[str, ModuleType]): | |||
| if module_path is None: | |||
| raise ModuleNotFoundError("None is passed in as parameters as module_path") | |||
| if isinstance(module_path, ModuleType): | |||
| module = module_path | |||
| else: | |||
| if module_path.endswith(".py"): | |||
| module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) | |||
| module_spec = importlib.util.spec_from_file_location(module_name, module_path) | |||
| module = importlib.util.module_from_spec(module_spec) | |||
| sys.modules[module_name] = module | |||
| module_spec.loader.exec_module(module) | |||
| else: | |||
| module = importlib.import_module(module_path) | |||
| return module | |||
| @@ -51,31 +51,23 @@ def get_platform(): | |||
| # What packages are required for this module to be executed? | |||
| # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. | |||
| REQUIRED = [ | |||
| # "numpy>=1.20.0", | |||
| # "pandas>=0.25.1", | |||
| # "scipy>=1.0.0", | |||
| # "matplotlib>=3.1.3", | |||
| # "torch>=1.11.0", | |||
| # "cvxopt>=1.3.0", | |||
| # "tqdm>=4.65.0", | |||
| # "scikit-learn>=0.22", | |||
| # "joblib>=1.2.0", | |||
| # "pyyaml>=6.0", | |||
| # "fire>=0.3.1", | |||
| # "lightgbm>=3.3.0", | |||
| # "psutil>=5.9.4", | |||
| # "torchvision>=0.15.1", | |||
| # "sqlalchemy>=2.0.21", | |||
| # "shortuuid>=1.0.11", | |||
| # "geatpy>=2.7.0", | |||
| # "docker>=6.1.3", | |||
| # "rapidfuzz>=3.4.0", | |||
| # "torchtext>=0.16.0", | |||
| # "sentence_transformers>=2.2.2", | |||
| # "torch-optimizer>=0.3.0", | |||
| # "langdetect>=1.0.9", | |||
| # "huggingface-hub<0.18", | |||
| # "portalocker>=2.0.0", | |||
| "numpy>=1.20.0", | |||
| "pandas>=0.25.1", | |||
| "scipy>=1.0.0", | |||
| "cvxopt>=1.3.0", | |||
| "tqdm>=4.65.0", | |||
| "scikit-learn>=0.22", | |||
| "joblib>=1.2.0", | |||
| "pyyaml>=6.0", | |||
| "fire>=0.3.1", | |||
| "psutil>=5.9.4", | |||
| "sqlalchemy>=2.0.21", | |||
| "shortuuid>=1.0.11", | |||
| "docker>=6.1.3", | |||
| "rapidfuzz>=3.4.0", | |||
| "langdetect>=1.0.9", | |||
| "huggingface-hub<0.18", | |||
| "portalocker>=2.0.0", | |||
| ] | |||
| if get_platform() != MACOS: | |||
| @@ -99,6 +91,23 @@ if __name__ == "__main__": | |||
| long_description_content_type="text/markdown", | |||
| python_requires=REQUIRES_PYTHON, | |||
| install_requires=REQUIRED, | |||
| extras_require={ | |||
| "dev": [ | |||
| # For documentations | |||
| "sphinx", | |||
| "sphinx_rtd_theme", | |||
| # CI dependencies | |||
| "pytest>=3", | |||
| "wheel", | |||
| "setuptools", | |||
| "pylint", | |||
| # For static analysis | |||
| "mypy<0.981", | |||
| "flake8", | |||
| "black==23.1.0", | |||
| "pre-commit", | |||
| ], | |||
| }, | |||
| classifiers=[ | |||
| "Intended Audience :: Science/Research", | |||
| "Intended Audience :: Developers", | |||
| @@ -108,8 +117,10 @@ if __name__ == "__main__": | |||
| "Operating System :: POSIX :: Linux", | |||
| "Operating System :: Microsoft :: Windows", | |||
| "Operating System :: MacOS", | |||
| "Programming Language :: Python :: 3.6", | |||
| "Programming Language :: Python :: 3.7", | |||
| "Programming Language :: Python :: 3.8", | |||
| "Programming Language :: Python :: 3.9", | |||
| "Programming Language :: Python :: 3.10", | |||
| "Programming Language :: Python :: 3.11", | |||
| ], | |||
| ) | |||
| @@ -30,7 +30,7 @@ user_semantic = { | |||
| } | |||
| class TestMarket(unittest.TestCase): | |||
| class TestWorkflow(unittest.TestCase): | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| np.random.seed(2023) | |||
| @@ -226,11 +226,11 @@ class TestMarket(unittest.TestCase): | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestMarket("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||
| _suite.addTest(TestMarket("test_search_semantics")) | |||
| _suite.addTest(TestMarket("test_stat_search")) | |||
| _suite.addTest(TestMarket("test_learnware_reuse")) | |||
| _suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestWorkflow("test_upload_delete_learnware")) | |||
| _suite.addTest(TestWorkflow("test_search_semantics")) | |||
| _suite.addTest(TestWorkflow("test_stat_search")) | |||
| _suite.addTest(TestWorkflow("test_learnware_reuse")) | |||
| return _suite | |||