Browse Source

[MNT] fix details in add_learnware and searcher

tags/v0.3.2
liuht 2 years ago
parent
commit
664463ca57
5 changed files with 56 additions and 27 deletions
  1. +2
    -0
      .gitignore
  2. +50
    -23
      learnware/market/hetergeneous/organizer/__init__.py
  3. +1
    -1
      learnware/market/hetergeneous/searcher.py
  4. +1
    -1
      learnware/market/module.py
  5. +2
    -2
      learnware/specification/system/heter_table.py

+ 2
- 0
.gitignore View File

@@ -27,6 +27,7 @@ dist/
*.db
*.json
*.zip
*.bin

# special software
.pytest_cache/
@@ -43,3 +44,4 @@ tmp/
learnware_pool/
PFS/
data/
learnware/market/hetergeneous/.learnware/*

+ 50
- 23
learnware/market/hetergeneous/organizer/__init__.py View File

@@ -1,16 +1,20 @@
from __future__ import annotations

import copy
import multiprocessing
import os
import tempfile
import zipfile
from collections import defaultdict
from shutil import copyfile, rmtree
from typing import List

import pandas as pd

from ....learnware import Learnware
from ....learnware import Learnware, get_learnware_from_dirpath
from ....logger import get_module_logger
from ....specification.system import HeteroSpecification
from ...base import BaseUserInfo
from ...base import BaseChecker, BaseUserInfo
from ...easy2 import EasyOrganizer
from ..database_ops import DatabaseOperations
from .config import C as conf
@@ -68,27 +72,53 @@ class HeteroMapTableOrganizer(EasyOrganizer):
self.training_args = kwargs

def add_learnware(
self, zip_path: str, semantic_spec: dict, check_status: int, learnware: Learnware
self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None
) -> Tuple[str, int]:
self._update_learnware_list([learnware])
self.learnware_list[learnware.id] = learnware
logger.info("Get new learnware from %s" % (zip_path))

learnware_id = "%08d" % (self.count) if learnware_id is None else learnware_id
target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id))
target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
copyfile(zip_path, target_zip_dir)

with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)
logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir))

try:
new_learnware = get_learnware_from_dirpath(
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
except:
logger.info("New Learnware Not Properly Added!!!")
try:
os.remove(target_zip_dir)
rmtree(target_folder_dir)
except:
pass
return None, BaseChecker.INVALID_LEARNWARE
if new_learnware is None:
return None, BaseChecker.INVALID_LEARNWARE

learnwere_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

self._update_learnware_list([new_learnware])
self.learnware_list[learnware_id] = new_learnware
self.learnware_zip_list[learnware_id] = target_zip_dir
self.learnware_folder_list[learnware_id] = target_folder_dir
self.use_flags[learnware_id] = learnwere_status
self.count += 1

if self.auto_update and self.count >= self.auto_update_limit:
train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list,))
train_process = multiprocessing.Process(target=self.train, args=(self.learnware_list.values(),))
train_process.start()
# train_process.join()
return learnware_id, learnwere_status

def delete_learnware(self, id: str) -> bool:
raise NotImplementedError

def update_learnware(self, learnware: Learnware):
raise NotImplementedError

def get_learnwares(self):
return self.learnware_list

def train(self, learnware_list: List[Learnware]):
def train(self, learnware_list: List[Learnware] = None):
learnware_list = learnware_list or self.learnware_list.values()
allset = self._learnwares_to_dataframes(learnware_list)
self.market_mapping = HeteroMapping(**self.training_args)
market_mapping_trainer = Trainer(
@@ -115,7 +145,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):

def _update_learnware_specification(self, learnware: Learnware, save_path: str) -> Learnware:
specification = learnware.specification
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"]
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"]
learnware_features = specification.get_semantic_spec()["Input"]["Description"].values()
learnware_hetero_spec = self.market_mapping.hetero_mapping(learnware_rkme, learnware_features)
learnware.update_stat_spec("HeteroSpecification", learnware_hetero_spec)
@@ -124,7 +154,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
learnware_hetero_spec.save(save_path)

def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroSpecification:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
user_rkme = user_info.stat_info["RKMETableSpecification"]
user_features = user_info.semantic_spec["Input"]["Description"].values()
user_hetero_spec = self.market_mapping.hetero_mapping(user_rkme, user_features)
return user_hetero_spec
@@ -133,7 +163,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
learnware_df_dict = defaultdict(list)
for learnware in learnware_list:
specification = learnware.get_specification()
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"]
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"]
learnware_features = specification.get_semantic_spec()["Input"]["Description"]
learnware_df = pd.DataFrame(data=learnware_rkme.get_z(), columns=learnware_features.values())

@@ -143,7 +173,4 @@ class HeteroMapTableOrganizer(EasyOrganizer):
return merged_dfs

def save(self, save_path):
return NotImplementedError

def __len__(self):
return len(self.learnware_list)
return NotImplementedError

+ 1
- 1
learnware/market/hetergeneous/searcher.py View File

@@ -17,7 +17,7 @@ class HeteroMapTableSearcher(BaseSearcher):
learnware_list = self.learnware_oganizer.get_learnwares()
target_learnware, min_dist = None, None
user_hetero_spec = self.learnware_oganizer.generate_hetero_map_spec(user_info)
for learnware in learnware_list.values():
for learnware in learnware_list:
learnware_hetero_spec = learnware.specification.get_stat_spec_by_name("HeteroSpecification")
mmd_dist = learnware_hetero_spec.dist(user_hetero_spec)
if target_learnware is None or mmd_dist < min_dist:


+ 1
- 1
learnware/market/module.py View File

@@ -11,7 +11,7 @@ MARKET_CONFIG = {
"hetero": {
"organizer": HeteroMapTableOrganizer(),
"searcher": HeteroMapTableSearcher(),
"checker_list": []
"checker_list": [EasySemanticChecker(), EasyStatChecker()]
}
}



+ 2
- 2
learnware/specification/system/heter_table.py View File

@@ -8,7 +8,7 @@ import os
import numpy as np
import torch

from ..regular.table import RKMEStatSpecification
from ..regular import RKMETableSpecification
from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel
from .base import SystemStatsSpecification

@@ -34,7 +34,7 @@ class HeteroSpecification(SystemStatsSpecification):
def get_beta(self) -> np.ndarray:
return self.beta.detach().cpu().numpy

def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMEStatSpecification):
def generate_stat_spec_from_system(self, heter_embedding: np.ndarray, rkme_spec: RKMETableSpecification):
self.beta = rkme_spec.beta.to(self.device)
self.z = torch.from_numpy(heter_embedding).double().to(self.device)



Loading…
Cancel
Save