| @@ -374,13 +374,8 @@ class EasyOrganizer(BaseOrganizer): | |||||
| return [self.learnware_list[idx] for idx in learnware_ids] | return [self.learnware_list[idx] for idx in learnware_ids] | ||||
| def reload_learnware(self, learnware_id: str): | def reload_learnware(self, learnware_id: str): | ||||
| current_learnware = self.learnware_list.get(learnware_id) | |||||
| if current_learnware is None: | |||||
| # add learnware | |||||
| if learnware_id not in self.learnware_list: | |||||
| self.count += 1 | self.count += 1 | ||||
| else: | |||||
| pass | |||||
| target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) | target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) | ||||
| target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) | target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) | ||||
| @@ -36,13 +36,15 @@ class HeteroMapTableOrganizer(EasyOrganizer): | |||||
| if not rebuild: | if not rebuild: | ||||
| if os.path.exists(self.hetero_specs_path): | if os.path.exists(self.hetero_specs_path): | ||||
| for hetero_json_path in os.listdir(self.hetero_specs_path): | for hetero_json_path in os.listdir(self.hetero_specs_path): | ||||
| if not hetero_json_path.endswith(".json"): | |||||
| continue | |||||
| try: | try: | ||||
| idx = hetero_json_path.split(".")[0] | idx = hetero_json_path.split(".")[0] | ||||
| hetero_spec = HeteroMapTableSpecification() | hetero_spec = HeteroMapTableSpecification() | ||||
| hetero_spec.load(os.path.join(self.hetero_specs_path, f"{idx}.json")) | |||||
| hetero_spec.load(os.path.join(self.hetero_specs_path, hetero_json_path)) | |||||
| self.learnware_list[idx].update_stat_spec(hetero_spec.type, hetero_spec) | self.learnware_list[idx].update_stat_spec(hetero_spec.type, hetero_spec) | ||||
| except: | |||||
| logger.warning(f"Learnware {idx} hetero spec loaded failed!") | |||||
| except Exception as err: | |||||
| logger.warning(f"Learnware in {hetero_json_path} hetero spec loaded failed! due to {err}.") | |||||
| else: | else: | ||||
| logger.info("No HeteroMapTableSpecification to reload. Use loaded market mapping to regenerate.") | logger.info("No HeteroMapTableSpecification to reload. Use loaded market mapping to regenerate.") | ||||
| self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) | self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) | ||||
| @@ -240,8 +242,8 @@ class HeteroMapTableOrganizer(EasyOrganizer): | |||||
| semantic_spec, rkme = spec.get_semantic_spec(), spec.get_stat_spec().get("RKMETableSpecification", None) | semantic_spec, rkme = spec.get_semantic_spec(), spec.get_stat_spec().get("RKMETableSpecification", None) | ||||
| if isinstance(rkme, RKMETableSpecification) and isinstance(semantic_spec["Input"], dict): | if isinstance(rkme, RKMETableSpecification) and isinstance(semantic_spec["Input"], dict): | ||||
| ret.append(idx) | ret.append(idx) | ||||
| except: | |||||
| continue | |||||
| except Exception: | |||||
| pass | |||||
| return ret | return ret | ||||
| def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroMapTableSpecification: | def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroMapTableSpecification: | ||||
| @@ -1,15 +1,13 @@ | |||||
| import os | |||||
| import numpy as np | import numpy as np | ||||
| import pandas as pd | import pandas as pd | ||||
| from typing import List, Optional, Union, Callable | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch import Tensor, nn | from torch import Tensor, nn | ||||
| from loguru import logger | |||||
| from typing import List, Optional | |||||
| from .....specification import HeteroMapTableSpecification, RKMETableSpecification | from .....specification import HeteroMapTableSpecification, RKMETableSpecification | ||||
| from .feature_extractor import * | |||||
| from .trainer import Trainer, TransTabCollatorForCL | |||||
| from .feature_extractor import FeatureTokenizer, FeatureProcessor, CLSToken | |||||
| from .trainer import TransTabCollatorForCL, Trainer | |||||
| class HeteroMap(nn.Module): | class HeteroMap(nn.Module): | ||||
| @@ -119,7 +119,7 @@ class FeatureTokenizer: | |||||
| disable_tokenizer_parallel : bool, optional | disable_tokenizer_parallel : bool, optional | ||||
| Whether to disable tokenizer parallelism, by default True. | Whether to disable tokenizer parallelism, by default True. | ||||
| """ | """ | ||||
| cache_dir = conf["cache_path"] | |||||
| cache_dir = conf.cache_path | |||||
| os.makedirs(cache_dir, exist_ok=True) | os.makedirs(cache_dir, exist_ok=True) | ||||
| self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) | ||||
| self.tokenizer.__dict__["model_max_length"] = 512 | self.tokenizer.__dict__["model_max_length"] = 512 | ||||
| @@ -18,7 +18,6 @@ class AlignLearnware(Learnware): | |||||
| specification=learnware.get_specification(), | specification=learnware.get_specification(), | ||||
| learnware_dirpath=learnware.get_dirpath(), | learnware_dirpath=learnware.get_dirpath(), | ||||
| ) | ) | ||||
| self.learnware = learnware | |||||
| def align(self): | def align(self): | ||||
| """Align the learnware with specification or data""" | """Align the learnware with specification or data""" | ||||
| @@ -7,10 +7,10 @@ from tqdm import trange | |||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from ..align import AlignLearnware | from ..align import AlignLearnware | ||||
| from ...utils import choose_device | |||||
| from ...logger import get_module_logger | from ...logger import get_module_logger | ||||
| from ...learnware import Learnware | from ...learnware import Learnware | ||||
| from ...specification import RKMETableSpecification | from ...specification import RKMETableSpecification | ||||
| from ...specification.regular.table.rkme import choose_device | |||||
| logger = get_module_logger("feature_align") | logger = get_module_logger("feature_align") | ||||
| @@ -60,7 +60,7 @@ class FeatureAlignLearnware(AlignLearnware): | |||||
| user_rkme : RKMETableSpecification | user_rkme : RKMETableSpecification | ||||
| The RKME specification from the user dataset. | The RKME specification from the user dataset. | ||||
| """ | """ | ||||
| target_rkme = self.learnware.specification.get_stat_spec()["RKMETableSpecification"] | |||||
| target_rkme = self.specification.get_stat_spec()["RKMETableSpecification"] | |||||
| trainer = FeatureAlignTrainer( | trainer = FeatureAlignTrainer( | ||||
| target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments | target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments | ||||
| ) | ) | ||||
| @@ -86,7 +86,7 @@ class FeatureAlignLearnware(AlignLearnware): | |||||
| transformed_user_data = ( | transformed_user_data = ( | ||||
| self.align_model(torch.tensor(user_data, device=self.device).float()).detach().cpu().numpy() | self.align_model(torch.tensor(user_data, device=self.device).float()).detach().cpu().numpy() | ||||
| ) | ) | ||||
| y_pred = self.learnware.predict(transformed_user_data) | |||||
| y_pred = super(FeatureAlignLearnware, self).predict(transformed_user_data) | |||||
| return y_pred | return y_pred | ||||
| def _fill_data(self, X: np.ndarray): | def _fill_data(self, X: np.ndarray): | ||||
| @@ -63,7 +63,7 @@ class HeteroMapAlignLearnware(AlignLearnware): | |||||
| Training data labels. | Training data labels. | ||||
| """ | """ | ||||
| self.feature_align_learnware = FeatureAlignLearnware( | self.feature_align_learnware = FeatureAlignLearnware( | ||||
| learnware=self.learnware, cuda_idx=self.cuda_idx, **self.align_arguments | |||||
| learnware=self, cuda_idx=self.cuda_idx, **self.align_arguments | |||||
| ) | ) | ||||
| self.feature_align_learnware.align(user_rkme) | self.feature_align_learnware.align(user_rkme) | ||||
| @@ -18,7 +18,8 @@ from tqdm import tqdm | |||||
| from . import cnn_gp | from . import cnn_gp | ||||
| from ..base import RegularStatSpecification | from ..base import RegularStatSpecification | ||||
| from ..table.rkme import rkme_solve_qp, choose_device, setup_seed | |||||
| from ..table.rkme import rkme_solve_qp | |||||
| from ....utils import choose_device, setup_seed | |||||
| class RKMEImageSpecification(RegularStatSpecification): | class RKMEImageSpecification(RegularStatSpecification): | ||||
| @@ -15,6 +15,7 @@ from sklearn.cluster import MiniBatchKMeans | |||||
| from ..base import RegularStatSpecification | from ..base import RegularStatSpecification | ||||
| from ....logger import get_module_logger | from ....logger import get_module_logger | ||||
| from ....utils import setup_seed, choose_device | |||||
| logger = get_module_logger("rkme") | logger = get_module_logger("rkme") | ||||
| @@ -461,46 +462,6 @@ class RKMEStatSpecification(RKMETableSpecification): | |||||
| super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__) | super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__) | ||||
def setup_seed(seed):
    """Seed every random number generator used here so runs are repeatable.

    Parameters
    ----------
    seed : int
        Value passed to the seeding functions of ``random``, ``numpy``,
        ``torch`` and ``torch.cuda``.
    """
    # The individual seeding calls are independent of one another.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Switch cuDNN to its deterministic mode.
    torch.backends.cudnn.deterministic = True
def choose_device(cuda_idx=-1):
    """Pick the torch device (CPU or a GPU) matching the requested index.

    Parameters
    ----------
    cuda_idx : int, optional
        GPU index, by default -1 which selects the CPU.

    Returns
    -------
    torch.device
        ``cpu`` when ``cuda_idx`` is -1 or CUDA is unavailable; otherwise
        ``cuda:<cuda_idx>`` if the index is within the visible device count,
        and ``cuda:0`` for any other index.
    """
    requested = int(cuda_idx)

    # CPU was asked for explicitly, or there is no usable CUDA runtime.
    if requested == -1 or not torch.cuda.is_available():
        return torch.device("cpu")

    # A valid index selects that GPU; anything else falls back to the first.
    if 0 <= requested < torch.cuda.device_count():
        return torch.device(f"cuda:{requested}")
    return torch.device("cuda:0")
| def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: | def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: | ||||
| """Use pytorch to compute rbf_kernel function at faster speed. | """Use pytorch to compute rbf_kernel function at faster speed. | ||||
| @@ -7,9 +7,10 @@ import torch | |||||
| import codecs | import codecs | ||||
| import numpy as np | import numpy as np | ||||
| from ..regular import RKMETableSpecification | |||||
| from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel | |||||
| from .base import SystemStatSpecification | from .base import SystemStatSpecification | ||||
| from ..regular import RKMETableSpecification | |||||
| from ..regular.table.rkme import torch_rbf_kernel | |||||
| from ...utils import choose_device, setup_seed | |||||
| class HeteroMapTableSpecification(SystemStatSpecification): | class HeteroMapTableSpecification(SystemStatSpecification): | ||||
| @@ -4,6 +4,7 @@ import zipfile | |||||
| from .import_utils import is_torch_avaliable | from .import_utils import is_torch_avaliable | ||||
| from .module import get_module_by_module_path | from .module import get_module_by_module_path | ||||
| from .file import read_yaml_to_dict, save_dict_to_yaml | from .file import read_yaml_to_dict, save_dict_to_yaml | ||||
| from .gpu import setup_seed, choose_device | |||||
| def zip_learnware_folder(path: str, output_name: str): | def zip_learnware_folder(path: str, output_name: str): | ||||
| @@ -0,0 +1,46 @@ | |||||
| import random | |||||
| import numpy as np | |||||
def setup_seed(seed):
    """Fix a random seed for addressing reproducibility issues.

    Parameters
    ----------
    seed : int
        Random seed for torch, torch.cuda, numpy, random and cudnn libraries.
    """
    # torch is imported inside the function rather than at module level.
    # The import must come AFTER the docstring: a string literal that is not
    # the first statement of the body is not stored as ``__doc__``.
    import torch

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Switch cuDNN to its deterministic mode.
    torch.backends.cudnn.deterministic = True
def choose_device(cuda_idx=-1):
    """Let users choose computational device between CPU or GPU.

    Parameters
    ----------
    cuda_idx : int, optional
        GPU index, by default -1 which stands for using CPU instead.

    Returns
    -------
    torch.device
        A torch.device object: ``cpu`` when ``cuda_idx`` is -1 or CUDA is not
        available, ``cuda:<cuda_idx>`` when the index is within the visible
        device count, and ``cuda:0`` for any other index.
    """
    # torch is imported inside the function rather than at module level.
    # The import must come AFTER the docstring: a string literal that is not
    # the first statement of the body is not stored as ``__doc__``.
    import torch

    cuda_idx = int(cuda_idx)
    if cuda_idx == -1 or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device_count = torch.cuda.device_count()
        if 0 <= cuda_idx < device_count:
            device = torch.device(f"cuda:{cuda_idx}")
        else:
            # Out-of-range index: fall back to the first GPU.
            device = torch.device("cuda:0")
    return device