Browse Source

[MNT] fix details in add_learnware and searcher

tags/v0.3.2
liuht 2 years ago
parent
commit
664463ca57
5 changed files with 56 additions and 27 deletions
  1. +2
    -0
      .gitignore
  2. +50
    -23
      learnware/market/hetergeneous/organizer/__init__.py
  3. +1
    -1
      learnware/market/hetergeneous/searcher.py
  4. +1
    -1
      learnware/market/module.py
  5. +2
    -2
      learnware/specification/system/heter_table.py

+ 2
- 0
.gitignore View File

@@ -27,6 +27,7 @@ dist/
*.db *.db
*.json *.json
*.zip *.zip
*.bin


# special software # special software
.pytest_cache/ .pytest_cache/
@@ -43,3 +44,4 @@ tmp/
learnware_pool/ learnware_pool/
PFS/ PFS/
data/ data/
learnware/market/hetergeneous/.learnware/*

+ 50
- 23
learnware/market/hetergeneous/organizer/__init__.py View File

@@ -1,16 +1,20 @@
from __future__ import annotations from __future__ import annotations


import copy
import multiprocessing import multiprocessing
import os import os
import tempfile
import zipfile
from collections import defaultdict from collections import defaultdict
from shutil import copyfile, rmtree
from typing import List from typing import List


import pandas as pd import pandas as pd


from ....learnware import Learnware
from ....learnware import Learnware, get_learnware_from_dirpath
from ....logger import get_module_logger from ....logger import get_module_logger
from ....specification.system import HeteroSpecification from ....specification.system import HeteroSpecification
from ...base import BaseUserInfo
from ...base import BaseChecker, BaseUserInfo
from ...easy2 import EasyOrganizer from ...easy2 import EasyOrganizer
from ..database_ops import DatabaseOperations from ..database_ops import DatabaseOperations
from .config import C as conf from .config import C as conf
@@ -68,27 +72,53 @@ class HeteroMapTableOrganizer(EasyOrganizer):
self.training_args = kwargs self.training_args = kwargs


def add_learnware( def add_learnware(
self, zip_path: str, semantic_spec: dict, check_status: int, learnware: Learnware
self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None
) -> Tuple[str, int]: ) -> Tuple[str, int]:
self._update_learnware_list([learnware])
self.learnware_list[learnware.id] = learnware
logger.info("Get new learnware from %s" % (zip_path))

learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id
target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id))
target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
copyfile(zip_path, target_zip_dir)

with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)
logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir))

try:
new_learnware = get_learnware_from_dirpath(
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
except:
logger.info("New Learnware Not Properly Added!!!")
try:
os.remove(target_zip_dir)
rmtree(target_folder_dir)
except:
pass
return None, BaseChecker.INVALID_LEARNWARE
if new_learnware is None:
return None, BaseChecker.INVALID_LEARNWARE

learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

self._update_learnware_list([new_learnware])
self.learnware_list[learnware_id] = new_learnware
self.learnware_zip_list[learnware_id] = target_zip_dir
self.learnware_folder_list[learnware_id] = target_folder_dir
self.use_flags[learnware_id] = learnwere_status
self.count += 1 self.count += 1


if self.auto_update and self.count >= self.auto_update_limit: if self.auto_update and self.count >= self.auto_update_limit:
train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list,))
train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list.values(),))
train_process.start() train_process.start()
# train_process.join() # train_process.join()
return learnware_id, learnwere_status


def delete_learnware(self, id: str) -> bool:
raise NotImplementedError

def update_learnware(self, learnware: Learnware):
raise NotImplementedError

def get_learnwares(self):
return self.learnware_list

def train(self, learnware_list: List[Learnware]):
def train(self, learnware_list: List[Learnware] = None):
learnware_list = learnware_list or self.learnware_list.values()
allset = self._learnwares_to_dataframes(learnware_list) allset = self._learnwares_to_dataframes(learnware_list)
self.market_mapping = HeteroMapping(**self.training_args) self.market_mapping = HeteroMapping(**self.training_args)
market_mapping_trainer = Trainer( market_mapping_trainer = Trainer(
@@ -115,7 +145,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):


def _update_learnware_specification(self, learnware: Learnware, save_path: str) -> Learnware: def _update_learnware_specification(self, learnware: Learnware, save_path: str) -> Learnware:
specification = learnware.specification specification = learnware.specification
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"]
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"]
learnware_features = specification.get_semantic_spec()["Input"]["Description"].values() learnware_features = specification.get_semantic_spec()["Input"]["Description"].values()
learnware_hetero_spec = self.market_mapping.hetero_mapping(learnware_rkme, learnware_features) learnware_hetero_spec = self.market_mapping.hetero_mapping(learnware_rkme, learnware_features)
learnware.update_stat_spec("HeteroSpecification", learnware_hetero_spec) learnware.update_stat_spec("HeteroSpecification", learnware_hetero_spec)
@@ -124,7 +154,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
learnware_hetero_spec.save(save_path) learnware_hetero_spec.save(save_path)


def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroSpecification: def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroSpecification:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
user_rkme = user_info.stat_info["RKMETableSpecification"]
user_features = user_info.semantic_spec["Input"]["Description"].values() user_features = user_info.semantic_spec["Input"]["Description"].values()
user_hetero_spec = self.market_mapping.hetero_mapping(user_rkme, user_features) user_hetero_spec = self.market_mapping.hetero_mapping(user_rkme, user_features)
return user_hetero_spec return user_hetero_spec
@@ -133,7 +163,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
learnware_df_dict = defaultdict(list) learnware_df_dict = defaultdict(list)
for learnware in learnware_list: for learnware in learnware_list:
specification = learnware.get_specification() specification = learnware.get_specification()
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"]
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"]
learnware_features = specification.get_semantic_spec()["Input"]["Description"] learnware_features = specification.get_semantic_spec()["Input"]["Description"]
learnware_df = pd.DataFrame(data=learnware_rkme.get_z(), columns=learnware_features.values()) learnware_df = pd.DataFrame(data=learnware_rkme.get_z(), columns=learnware_features.values())


@@ -143,7 +173,4 @@ class HeteroMapTableOrganizer(EasyOrganizer):
return merged_dfs return merged_dfs


def save(self, save_path): def save(self, save_path):
return NotImplementedError

def __len__(self):
return len(self.learnware_list)
return NotImplementedError

+ 1
- 1
learnware/market/hetergeneous/searcher.py View File

@@ -17,7 +17,7 @@ class HeteroMapTableSearcher(BaseSearcher):
learnware_list = self.learnware_oganizer.get_learnwares() learnware_list = self.learnware_oganizer.get_learnwares()
target_learnware, min_dist = None, None target_learnware, min_dist = None, None
user_hetero_spec = self.learnware_oganizer.generate_hetero_map_spec(user_info) user_hetero_spec = self.learnware_oganizer.generate_hetero_map_spec(user_info)
for learnware in learnware_list.values():
for learnware in learnware_list:
learnware_hetero_spec = learnware.specification.get_stat_spec_by_name("HeteroSpecification") learnware_hetero_spec = learnware.specification.get_stat_spec_by_name("HeteroSpecification")
mmd_dist = learnware_hetero_spec.dist(user_hetero_spec) mmd_dist = learnware_hetero_spec.dist(user_hetero_spec)
if target_learnware is None or mmd_dist < min_dist: if target_learnware is None or mmd_dist < min_dist:


+ 1
- 1
learnware/market/module.py View File

@@ -11,7 +11,7 @@ MARKET_CONFIG = {
"hetero": { "hetero": {
"organizer": HeteroMapTableOrganizer(), "organizer": HeteroMapTableOrganizer(),
"searcher": HeteroMapTableSearcher(), "searcher": HeteroMapTableSearcher(),
"checker_list": []
"checker_list": [EasySemanticChecker(), EasyStatChecker()]
} }
} }




+ 2
- 2
learnware/specification/system/heter_table.py View File

@@ -8,7 +8,7 @@ import os
import numpy as np import numpy as np
import torch import torch


from ..regular.table import RKMEStatSpecification
from ..regular import RKMETableSpecification
from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel
from .base import SystemStatsSpecification from .base import SystemStatsSpecification


@@ -34,7 +34,7 @@ class HeteroSpecification(SystemStatsSpecification):
def get_beta(self) -> np.ndarray: def get_beta(self) -> np.ndarray:
return self.beta.detach().cpu().numpy return self.beta.detach().cpu().numpy


def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMEStatSpecification):
def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMETableSpecification):
self.beta = rkme_spec.beta.to(self.device) self.beta = rkme_spec.beta.to(self.device)
self.z = torch.from_numpy(heter_embedding).double().to(self.device) self.z = torch.from_numpy(heter_embedding).double().to(self.device)




Loading…
Cancel
Save