|
|
|
@@ -1,16 +1,20 @@ |
|
|
|
from __future__ import annotations |
|
|
|
|
|
|
|
import copy |
|
|
|
import multiprocessing |
|
|
|
import os |
|
|
|
import tempfile |
|
|
|
import zipfile |
|
|
|
from collections import defaultdict |
|
|
|
from shutil import copyfile, rmtree |
|
|
|
from typing import List |
|
|
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
from ....learnware import Learnware |
|
|
|
from ....learnware import Learnware, get_learnware_from_dirpath |
|
|
|
from ....logger import get_module_logger |
|
|
|
from ....specification.system import HeteroSpecification |
|
|
|
from ...base import BaseUserInfo |
|
|
|
from ...base import BaseChecker, BaseUserInfo |
|
|
|
from ...easy2 import EasyOrganizer |
|
|
|
from ..database_ops import DatabaseOperations |
|
|
|
from .config import C as conf |
|
|
|
@@ -68,27 +72,53 @@ class HeteroMapTableOrganizer(EasyOrganizer): |
|
|
|
self.training_args = kwargs |
|
|
|
|
|
|
|
def add_learnware(
    self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str | None = None
) -> tuple[str | None, int]:
    """Copy a learnware zip into the pool, load it, and register it.

    Parameters
    ----------
    zip_path : str
        Path of the learnware zip file to add.
    semantic_spec : dict
        Semantic specification forwarded to ``get_learnware_from_dirpath``.
    check_status : int
        Status stored in ``self.use_flags`` for the new learnware. ``None``
        falls back to ``BaseChecker.NONUSABLE_LEARNWARE``.
    learnware_id : str, optional
        Id to register the learnware under. When ``None`` the id is the
        current ``self.count`` rendered as a zero-padded 8-digit string.

    Returns
    -------
    tuple
        ``(learnware_id, status)`` on success, or
        ``(None, BaseChecker.INVALID_LEARNWARE)`` when the learnware could
        not be loaded from the extracted folder.
    """
    logger.info("Get new learnware from %s", zip_path)

    if learnware_id is None:
        learnware_id = "%08d" % self.count
    target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % learnware_id)
    target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id)
    copyfile(zip_path, target_zip_dir)

    with zipfile.ZipFile(target_zip_dir, "r") as z_file:
        z_file.extractall(target_folder_dir)
    logger.info("Learnware move to %s, and unzip to %s", target_zip_dir, target_folder_dir)

    new_learnware = None
    try:
        new_learnware = get_learnware_from_dirpath(
            id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
        )
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate; keep the traceback in the log for diagnosis.
        logger.info("New Learnware Not Properly Added!!!", exc_info=True)

    if new_learnware is None:
        # Best-effort cleanup for both failure modes (loader raised or
        # returned None) so rejected learnwares do not linger in the pool.
        try:
            os.remove(target_zip_dir)
        except OSError:
            logger.info("Could not remove %s", target_zip_dir)
        rmtree(target_folder_dir, ignore_errors=True)
        return None, BaseChecker.INVALID_LEARNWARE

    learnware_status = check_status if check_status is not None else BaseChecker.NONUSABLE_LEARNWARE

    self._update_learnware_list([new_learnware])
    self.learnware_list[learnware_id] = new_learnware
    self.learnware_zip_list[learnware_id] = target_zip_dir
    self.learnware_folder_list[learnware_id] = target_folder_dir
    self.use_flags[learnware_id] = learnware_status
    self.count += 1

    if self.auto_update and self.count >= self.auto_update_limit:
        # Pass a list snapshot: a dict_values view cannot be pickled, which
        # breaks Process.start() under the "spawn" start method.
        train_process = multiprocessing.Process(
            target=self.train, args=(list(self.learnware_list.values()),)
        )
        train_process.start()
        # The process is intentionally not joined; training runs in the background.

    return learnware_id, learnware_status
|
|
|
|
|
|
|
def delete_learnware(self, id: str) -> bool:
    """Delete a learnware by id (not implemented for this organizer).

    Parameters
    ----------
    id : str
        Id of the learnware to delete.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
|
|
|
|
|
|
|
def update_learnware(self, learnware: Learnware):
    """Update a stored learnware (not implemented for this organizer).

    Parameters
    ----------
    learnware : Learnware
        The learnware that would replace the stored one.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
|
|
|
|
|
|
|
def get_learnwares(self):
    """Return ``self.learnware_list`` itself (not a copy).

    Callers that mutate the returned object mutate the organizer's state.
    """
    return self.learnware_list
|
|
|
|
|
|
|
def train(self, learnware_list: List[Learnware]): |
|
|
|
def train(self, learnware_list: List[Learnware] = None): |
|
|
|
learnware_list = learnware_list or self.learnware_list.values() |
|
|
|
allset = self._learnwares_to_dataframes(learnware_list) |
|
|
|
self.market_mapping = HeteroMapping(**self.training_args) |
|
|
|
market_mapping_trainer = Trainer( |
|
|
|
@@ -115,7 +145,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): |
|
|
|
|
|
|
|
def _update_learnware_specification(self, learnware: Learnware, save_path: str) -> Learnware: |
|
|
|
specification = learnware.specification |
|
|
|
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"] |
|
|
|
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"] |
|
|
|
learnware_features = specification.get_semantic_spec()["Input"]["Description"].values() |
|
|
|
learnware_hetero_spec = self.market_mapping.hetero_mapping(learnware_rkme, learnware_features) |
|
|
|
learnware.update_stat_spec("HeteroSpecification", learnware_hetero_spec) |
|
|
|
@@ -124,7 +154,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): |
|
|
|
learnware_hetero_spec.save(save_path) |
|
|
|
|
|
|
|
def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroSpecification:
    """Build the user's hetero-mapping specification from their table RKME.

    Parameters
    ----------
    user_info : BaseUserInfo
        Must provide ``stat_info["RKMETableSpecification"]`` and
        ``semantic_spec["Input"]["Description"]`` (a mapping whose values
        are the feature descriptions).

    Returns
    -------
    HeteroSpecification
        Whatever ``self.market_mapping.hetero_mapping`` returns for the
        user's RKME and feature descriptions.

    Raises
    ------
    KeyError
        If the statistical or semantic entries above are missing.
    """
    # Only the table RKME key is read; the earlier "RKMEStatSpecification"
    # lookup was stale and raised KeyError for users supplying the table spec.
    user_rkme = user_info.stat_info["RKMETableSpecification"]
    user_features = user_info.semantic_spec["Input"]["Description"].values()
    return self.market_mapping.hetero_mapping(user_rkme, user_features)
|
|
|
@@ -133,7 +163,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): |
|
|
|
learnware_df_dict = defaultdict(list) |
|
|
|
for learnware in learnware_list: |
|
|
|
specification = learnware.get_specification() |
|
|
|
learnware_rkme = specification.get_stat_spec()["RKMEStatSpecification"] |
|
|
|
learnware_rkme = specification.get_stat_spec()["RKMETableSpecification"] |
|
|
|
learnware_features = specification.get_semantic_spec()["Input"]["Description"] |
|
|
|
learnware_df = pd.DataFrame(data=learnware_rkme.get_z(), columns=learnware_features.values()) |
|
|
|
|
|
|
|
@@ -143,7 +173,4 @@ class HeteroMapTableOrganizer(EasyOrganizer): |
|
|
|
return merged_dfs |
|
|
|
|
|
|
|
def save(self, save_path):
    """Persist the organizer to ``save_path`` (not implemented).

    Parameters
    ----------
    save_path : str
        Destination path.

    Raises
    ------
    NotImplementedError
        Always. The exception is raised rather than returned, so callers
        cannot mistake the exception class for a successful result; this
        matches ``delete_learnware`` and ``update_learnware``.
    """
    raise NotImplementedError
|
|
|
|
|
|
|
def __len__(self):
    """Return the number of entries in ``self.learnware_list``."""
    # The trailing `return NotImplementedError` was unreachable and is dropped.
    return len(self.learnware_list)