diff --git a/.gitignore b/.gitignore index 92d2a4d..d22ea69 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ dist/ *.db *.json *.zip +*.bin # special software .pytest_cache/ @@ -43,3 +44,4 @@ tmp/ learnware_pool/ PFS/ data/ +learnware/market/hetergeneous/.learnware/* \ No newline at end of file diff --git a/learnware/market/hetergeneous/organizer/__init__.py b/learnware/market/hetergeneous/organizer/__init__.py index db0c05e..85c3e7f 100644 --- a/learnware/market/hetergeneous/organizer/__init__.py +++ b/learnware/market/hetergeneous/organizer/__init__.py @@ -1,16 +1,20 @@ from __future__ import annotations +import copy import multiprocessing import os +import tempfile +import zipfile from collections import defaultdict +from shutil import copyfile, rmtree from typing import List import pandas as pd -from ....learnware import Learnware +from ....learnware import Learnware, get_learnware_from_dirpath from ....logger import get_module_logger from ....specification.system import HeteroSpecification -from ...base import BaseUserInfo +from ...base import BaseChecker, BaseUserInfo from ...easy2 import EasyOrganizer from ..database_ops import DatabaseOperations from .config import C as conf @@ -68,27 +72,53 @@ class HeteroMapTableOrganizer(EasyOrganizer): self.training_args = kwargs def add_learnware( - self, zip_path: str, semantic_spec: dict, check_status: int, learnware: Learnware + self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None ) -> Tuple[str, int]: - self._update_learnware_list([learnware]) - self.learnware_list[learnware.id] = learnware + logger.info("Get new learnware from %s" % (zip_path)) + + learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id + target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) + target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) + copyfile(zip_path, target_zip_dir) + + with zipfile.ZipFile(target_zip_dir, "r") as z_file: + z_file.extractall(target_folder_dir) + logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir)) + + try: + new_learnware = get_learnware_from_dirpath( + id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir + ) + except: + logger.info("New Learnware Not Properly Added!!!") + try: + os.remove(target_zip_dir) + rmtree(target_folder_dir) + except: + pass + return None, BaseChecker.INVALID_LEARNWARE + + if new_learnware is None: + return None, BaseChecker.INVALID_LEARNWARE + + learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE + + self._update_learnware_list([new_learnware]) + self.learnware_list[learnware_id] = new_learnware + self.learnware_zip_list[learnware_id] = target_zip_dir + self.learnware_folder_list[learnware_id] = target_folder_dir + self.use_flags[learnware_id] = learnwere_status self.count += 1 if self.auto_update and self.count >= self.auto_update_limit: - train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list,)) + train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list.values(),)) train_process.start() # train_process.join() + + return learnware_id, learnwere_status - def delete_learnware(self, id: str) -> bool: - raise NotImplementedError - - def update_learnware(self, learnware: Learnware): - raise NotImplementedError - - def get_learnwares(self): - return self.learnware_list - - def train(self, learnware_list: List[Learnware]): + def train(self, learnware_list: List[Learnware] = None): + learnware_list = learnware_list or self.learnware_list.values() allset = self._learnwares_to_dataframes(learnware_list) self.market_mapping = HeteroMapping(**self.training_args) market_mapping_trainer = Trainer( @@ -115,7 +145,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): def _update_learnware_specification(self, learnware: Learnware, save_path: str) -> Learnware: specification = learnware.specification - learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"] + learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"] learnware_features = specification.get_semantic_spec()["Input"]["Description"].values() learnware_hetero_spec = self.market_mapping.hetero_mapping(learnware_rkme, learnware_features) learnware.update_stat_spec("HeteroSpecification", learnware_hetero_spec) @@ -124,7 +154,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): learnware_hetero_spec.save(save_path) def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroSpecification: - user_rkme = user_info.stat_info["RKMEStatSpecification"] + user_rkme = user_info.stat_info["RKMETableSpecification"] user_features = user_info.semantic_spec["Input"]["Description"].values() user_hetero_spec = self.market_mapping.hetero_mapping(user_rkme, user_features) return user_hetero_spec @@ -133,7 +163,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): learnware_df_dict = defaultdict(list) for learnware in learnware_list: specification = learnware.get_specification() - learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"] + learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"] learnware_features = specification.get_semantic_spec()["Input"]["Description"] learnware_df = pd.DataFrame(data=learnware_rkme.get_z(), columns=learnware_features.values()) @@ -143,7 +173,4 @@ class HeteroMapTableOrganizer(EasyOrganizer): return merged_dfs def save(self, save_path): - return NotImplementedError - - def __len__(self): - return len(self.learnware_list) + return NotImplementedError \ No newline at end of file diff --git a/learnware/market/hetergeneous/searcher.py b/learnware/market/hetergeneous/searcher.py index d70b06e..90bbefd 100644 --- a/learnware/market/hetergeneous/searcher.py +++ b/learnware/market/hetergeneous/searcher.py @@ -17,7 +17,7 @@ class HeteroMapTableSearcher(BaseSearcher): learnware_list = self.learnware_oganizer.get_learnwares() target_learnware, min_dist = None, None user_hetero_spec = self.learnware_oganizer.generate_hetero_map_spec(user_info) - for learnware in learnware_list.values(): + for learnware in learnware_list: learnware_hetero_spec = learnware.specification.get_stat_spec_by_name("HeteroSpecification") mmd_dist = learnware_hetero_spec.dist(user_hetero_spec) if target_learnware is None or mmd_dist < min_dist: diff --git a/learnware/market/module.py b/learnware/market/module.py index 43499ec..f60e06a 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -11,7 +11,7 @@ MARKET_CONFIG = { "hetero": { "organizer": HeteroMapTableOrganizer(), "searcher": HeteroMapTableSearcher(), - "checker_list": [] + "checker_list": [EasySemanticChecker(), EasyStatChecker()] } } diff --git a/learnware/specification/system/heter_table.py b/learnware/specification/system/heter_table.py index e611ca1..f721d8f 100644 --- a/learnware/specification/system/heter_table.py +++ b/learnware/specification/system/heter_table.py @@ -8,7 +8,7 @@ import os import numpy as np import torch -from ..regular.table import RKMEStatSpecification +from ..regular import RKMETableSpecification from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel from .base import SystemStatsSpecification @@ -34,7 +34,7 @@ class HeteroSpecification(SystemStatsSpecification): def get_beta(self) -> np.ndarray: return self.beta.detach().cpu().numpy - def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMEStatSpecification): + def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMETableSpecification): self.beta = rkme_spec.beta.to(self.device) self.z = torch.from_numpy(heter_embedding).double().to(self.device)