diff --git a/learnware/market/easy/organizer.py b/learnware/market/easy/organizer.py index 3165fd6..7e877c7 100644 --- a/learnware/market/easy/organizer.py +++ b/learnware/market/easy/organizer.py @@ -374,13 +374,8 @@ class EasyOrganizer(BaseOrganizer): return [self.learnware_list[idx] for idx in learnware_ids] def reload_learnware(self, learnware_id: str): - current_learnware = self.learnware_list.get(learnware_id) - - if current_learnware is None: - # add learnware + if learnware_id not in self.learnware_list: self.count += 1 - else: - pass target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (learnware_id)) target_folder_dir = os.path.join(self.learnware_folder_pool_path, learnware_id) diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index 16d66fb..89d6386 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -36,13 +36,15 @@ class HeteroMapTableOrganizer(EasyOrganizer): if not rebuild: if os.path.exists(self.hetero_specs_path): for hetero_json_path in os.listdir(self.hetero_specs_path): + if not hetero_json_path.endswith(".json"): + continue try: idx = hetero_json_path.split(".")[0] hetero_spec = HeteroMapTableSpecification() - hetero_spec.load(os.path.join(self.hetero_specs_path, f"{idx}.json")) + hetero_spec.load(os.path.join(self.hetero_specs_path, hetero_json_path)) self.learnware_list[idx].update_stat_spec(hetero_spec.type, hetero_spec) - except: - logger.warning(f"Learnware {idx} hetero spec loaded failed!") + except Exception as err: + logger.warning(f"Learnware in {hetero_json_path} hetero spec loaded failed! due to {err}.") else: logger.info("No HeteroMapTableSpecification to reload. Use loaded market mapping to regenerate.") self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) @@ -240,8 +242,8 @@ class HeteroMapTableOrganizer(EasyOrganizer): semantic_spec, rkme = spec.get_semantic_spec(), spec.get_stat_spec().get("RKMETableSpecification", None) if isinstance(rkme, RKMETableSpecification) and isinstance(semantic_spec["Input"], dict): ret.append(idx) - except: - continue + except Exception: + pass return ret def generate_hetero_map_spec(self, user_info: BaseUserInfo) -> HeteroMapTableSpecification: diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index e16f46a..a2e5636 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -1,15 +1,13 @@ -import os import numpy as np import pandas as pd -from typing import List, Optional, Union, Callable import torch import torch.nn.functional as F from torch import Tensor, nn -from loguru import logger +from typing import List, Optional from .....specification import HeteroMapTableSpecification, RKMETableSpecification -from .feature_extractor import * -from .trainer import Trainer, TransTabCollatorForCL +from .feature_extractor import FeatureTokenizer, FeatureProcessor, CLSToken +from .trainer import TransTabCollatorForCL, Trainer class HeteroMap(nn.Module): diff --git a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py index d424a72..2e7a003 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py @@ -119,7 +119,7 @@ class FeatureTokenizer: disable_tokenizer_parallel : bool, optional Whether to disable tokenizer parallelism, by default True. """ - cache_dir = conf["cache_path"] + cache_dir = conf.cache_path os.makedirs(cache_dir, exist_ok=True) self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) self.tokenizer.__dict__["model_max_length"] = 512 diff --git a/learnware/reuse/align.py b/learnware/reuse/align.py index 04d29d4..47a47d5 100644 --- a/learnware/reuse/align.py +++ b/learnware/reuse/align.py @@ -18,7 +18,6 @@ class AlignLearnware(Learnware): specification=learnware.get_specification(), learnware_dirpath=learnware.get_dirpath(), ) - self.learnware = learnware def align(self): """Align the learnware with specification or data""" diff --git a/learnware/reuse/hetero/feature_align.py b/learnware/reuse/hetero/feature_align.py index 71e3d29..b43ed57 100644 --- a/learnware/reuse/hetero/feature_align.py +++ b/learnware/reuse/hetero/feature_align.py @@ -7,10 +7,10 @@ from tqdm import trange import torch.nn.functional as F from ..align import AlignLearnware +from ...utils import choose_device from ...logger import get_module_logger from ...learnware import Learnware from ...specification import RKMETableSpecification -from ...specification.regular.table.rkme import choose_device logger = get_module_logger("feature_align") @@ -60,7 +60,7 @@ class FeatureAlignLearnware(AlignLearnware): user_rkme : RKMETableSpecification The RKME specification from the user dataset. """ - target_rkme = self.learnware.specification.get_stat_spec()["RKMETableSpecification"] + target_rkme = self.specification.get_stat_spec()["RKMETableSpecification"] trainer = FeatureAlignTrainer( target_rkme=target_rkme, user_rkme=user_rkme, cuda_idx=self.cuda_idx, **self.align_arguments ) @@ -86,7 +86,7 @@ class FeatureAlignLearnware(AlignLearnware): transformed_user_data = ( self.align_model(torch.tensor(user_data, device=self.device).float()).detach().cpu().numpy() ) - y_pred = self.learnware.predict(transformed_user_data) + y_pred = super(FeatureAlignLearnware, self).predict(transformed_user_data) return y_pred def _fill_data(self, X: np.ndarray): diff --git a/learnware/reuse/hetero/hetero_map.py b/learnware/reuse/hetero/hetero_map.py index 76f9ce0..c41095a 100644 --- a/learnware/reuse/hetero/hetero_map.py +++ b/learnware/reuse/hetero/hetero_map.py @@ -63,7 +63,7 @@ class HeteroMapAlignLearnware(AlignLearnware): Training data labels. """ self.feature_align_learnware = FeatureAlignLearnware( - learnware=self.learnware, cuda_idx=self.cuda_idx, **self.align_arguments + learnware=self, cuda_idx=self.cuda_idx, **self.align_arguments ) self.feature_align_learnware.align(user_rkme) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index c09be8b..50e367a 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -18,7 +18,8 @@ from tqdm import tqdm from . import cnn_gp from ..base import RegularStatSpecification -from ..table.rkme import rkme_solve_qp, choose_device, setup_seed +from ..table.rkme import rkme_solve_qp +from ....utils import choose_device, setup_seed class RKMEImageSpecification(RegularStatSpecification): diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 996b9a9..8b97632 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -15,6 +15,7 @@ from sklearn.cluster import MiniBatchKMeans from ..base import RegularStatSpecification from ....logger import get_module_logger +from ....utils import setup_seed, choose_device logger = get_module_logger("rkme") @@ -461,46 +462,6 @@ class RKMEStatSpecification(RKMETableSpecification): super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__) -def setup_seed(seed): - """Fix a random seed for addressing reproducibility issues. - - Parameters - ---------- - seed : int - Random seed for torch, torch.cuda, numpy, random and cudnn libraries. - """ - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.deterministic = True - - -def choose_device(cuda_idx=-1): - """Let users choose compuational device between CPU or GPU. - - Parameters - ---------- - cuda_idx : int, optional - GPU index, by default -1 which stands for using CPU instead. - - Returns - ------- - torch.device - A torch.device object - """ - cuda_idx = int(cuda_idx) - if cuda_idx == -1 or not torch.cuda.is_available(): - device = torch.device("cpu") - else: - device_count = torch.cuda.device_count() - if cuda_idx >= 0 and cuda_idx < device_count: - device = torch.device(f"cuda:{cuda_idx}") - else: - device = torch.device("cuda:0") - return device - - def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor: """Use pytorch to compute rbf_kernel function at faster speed. diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index 918ee11..4e89f2d 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -7,9 +7,10 @@ import torch import codecs import numpy as np -from ..regular import RKMETableSpecification -from ..regular.table.rkme import choose_device, setup_seed, torch_rbf_kernel from .base import SystemStatSpecification +from ..regular import RKMETableSpecification +from ..regular.table.rkme import torch_rbf_kernel +from ...utils import choose_device, setup_seed class HeteroMapTableSpecification(SystemStatSpecification): diff --git a/learnware/utils/__init__.py b/learnware/utils/__init__.py index 60f2b46..f37bc03 100644 --- a/learnware/utils/__init__.py +++ b/learnware/utils/__init__.py @@ -4,6 +4,7 @@ import zipfile from .import_utils import is_torch_avaliable from .module import get_module_by_module_path from .file import read_yaml_to_dict, save_dict_to_yaml +from .gpu import setup_seed, choose_device def zip_learnware_folder(path: str, output_name: str): diff --git a/learnware/utils/gpu.py b/learnware/utils/gpu.py new file mode 100644 index 0000000..95fbfe1 --- /dev/null +++ b/learnware/utils/gpu.py @@ -0,0 +1,46 @@ +import random +import numpy as np + + +def setup_seed(seed): + import torch + + """Fix a random seed for addressing reproducibility issues. + + Parameters + ---------- + seed : int + Random seed for torch, torch.cuda, numpy, random and cudnn libraries. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +def choose_device(cuda_idx=-1): + import torch + + """Let users choose compuational device between CPU or GPU. + + Parameters + ---------- + cuda_idx : int, optional + GPU index, by default -1 which stands for using CPU instead. + + Returns + ------- + torch.device + A torch.device object + """ + cuda_idx = int(cuda_idx) + if cuda_idx == -1 or not torch.cuda.is_available(): + device = torch.device("cpu") + else: + device_count = torch.cuda.device_count() + if cuda_idx >= 0 and cuda_idx < device_count: + device = torch.device(f"cuda:{cuda_idx}") + else: + device = torch.device("cuda:0") + return device