From 72214b8116e84b8c725b4b52a113dc3e26baf8e5 Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 4 Dec 2023 13:26:06 +0800 Subject: [PATCH 1/7] [MNT] add type check in spec load --- learnware/specification/regular/image/rkme.py | 28 +++++++++++++------ learnware/specification/regular/table/rkme.py | 21 +++++++++----- .../specification/system/hetero_table.py | 14 +++++++--- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 81b68f7..821924b 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -1,7 +1,6 @@ from __future__ import annotations import codecs -import copy import functools import json import os @@ -17,8 +16,11 @@ from tqdm import tqdm from . import cnn_gp from ..base import RegularStatSpecification from ..table.rkme import rkme_solve_qp +from ....logger import get_module_logger from ....utils import choose_device, allocate_cuda_idx +logger = get_module_logger("image_rkme") + class RKMEImageSpecification(RegularStatSpecification): # INNER_PRODUCT_COUNT = 0 @@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification): try: from torchvision.transforms import Resize except ModuleNotFoundError: - raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." ) - + raise ModuleNotFoundError( + f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." + ) + if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X) @@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification): with torch.no_grad(): x_features = self._generate_random_feature(X_train, random_models=random_models) self._update_beta(x_features, nonnegative_beta, random_models=random_models) - + try: import torch_optimizer except ModuleNotFoundError: - raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.") - + raise ModuleNotFoundError( + f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually." + ) + optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) for _ in tqdm(range(steps)) if verbose else range(steps): @@ -385,9 +391,13 @@ class RKMEImageSpecification(RegularStatSpecification): self.beta = self.beta.to(self._device) self.z = self.z.to(self._device) - return True - else: - return False + if self.type == self.__class__.__name__: + logger.error( + f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" + ) + return True + + return False def _get_zca_matrix(X, reg_coef=0.1): diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index f335f4d..af8ff20 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification): if isinstance(X, np.ndarray): X = X.astype("float32") X = torch.from_numpy(X) - + X = X.to(self._device) - + try: from fast_pytorch_kmeans import KMeans except ModuleNotFoundError: - raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." ) + raise ModuleNotFoundError( + f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." + ) - kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) + kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0) kmeans.fit(X) self.z = kmeans.centroids.double() @@ -455,9 +457,14 @@ class RKMETableSpecification(RegularStatSpecification): for d in self.get_states(): if d in rkme_load.keys(): setattr(self, d, rkme_load[d]) - return True - else: - return False + + if self.type == self.__class__.__name__: + logger.error( + f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" + ) + return True + + return False class RKMEStatSpecification(RKMETableSpecification): diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index 3987040..ad9d7a8 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import copy import json import torch import codecs @@ -10,8 +9,11 @@ import numpy as np from .base import SystemStatSpecification from ..regular import RKMETableSpecification from ..regular.table.rkme import torch_rbf_kernel +from ...logger import get_module_logger from ...utils import choose_device, allocate_cuda_idx +logger = get_module_logger("hetero_map_table_spec") + class HeteroMapTableSpecification(SystemStatSpecification): """Heterogeneous Map-Table Specification""" @@ -135,9 +137,13 @@ class HeteroMapTableSpecification(SystemStatSpecification): if d in embedding_load.keys(): setattr(self, d, embedding_load[d]) - return True - else: - return False + if self.type == self.__class__.__name__: + logger.error( + f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" + ) + return True + + return False def save(self, filepath: str) -> bool: """Save the computed HeteroMapTableSpecification to a specified path in JSON format. From a7b8f8d12fa7e3bf27525b39b67d3b4bb8e8e5e6 Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 4 Dec 2023 13:36:16 +0800 Subject: [PATCH 2/7] [FIX] add spec.dist() check --- learnware/market/easy/checker.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 7995d94..57cae22 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker): logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") return self.INVALID_LEARNWARE, traceback.format_exc() try: - learnware_model = learnware.get_model() # Check input shape + learnware_model = learnware.get_model() input_shape = learnware_model.input_shape if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( @@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker): logger.warning(message) return self.INVALID_LEARNWARE, message + # Check statistical specification spec_type = parse_specification_type(learnware.get_specification().stat_spec) if spec_type is None: message = f"No valid specification is found in stat spec {spec_type}" logger.warning(message) return self.INVALID_LEARNWARE, message + # Check if statistical specification is computable in dist() + stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) + stat_spec.dist(stat_spec) + if spec_type == "RKMETableSpecification": - stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): raise ValueError( f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" @@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker): logger.warning(message) return self.INVALID_LEARNWARE, message inputs = np.random.randn(10, *input_shape) + elif spec_type == "RKMETextSpecification": inputs = EasyStatChecker._generate_random_text_list(10) + elif spec_type == "RKMEImageSpecification": if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): raise ValueError( f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" ) inputs = np.random.randint(0, 255, size=(10, *input_shape)) + else: raise ValueError(f"not supported spec type for spec_type = {spec_type}") From d9a0ccc9588e771f1f0a7799ec688fe5e31458ec Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 4 Dec 2023 13:46:47 +0800 Subject: [PATCH 3/7] [FIX] fix details --- learnware/specification/regular/image/rkme.py | 2 +- learnware/specification/regular/table/rkme.py | 2 +- learnware/specification/system/hetero_table.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 821924b..448ae62 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -391,7 +391,7 @@ class RKMEImageSpecification(RegularStatSpecification): self.beta = self.beta.to(self._device) self.z = self.z.to(self._device) - if self.type == self.__class__.__name__: + if self.type != self.__class__.__name__: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index af8ff20..78e4996 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -458,7 +458,7 @@ class RKMETableSpecification(RegularStatSpecification): if d in rkme_load.keys(): setattr(self, d, rkme_load[d]) - if self.type == self.__class__.__name__: + if self.type != self.__class__.__name__: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index ad9d7a8..5bec254 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -137,7 +137,7 @@ class HeteroMapTableSpecification(SystemStatSpecification): if d in embedding_load.keys(): setattr(self, d, embedding_load[d]) - if self.type == self.__class__.__name__: + if self.type != self.__class__.__name__: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) From 61ed6da6c43cadd5624e1249f5d94c4a11d94ace Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 4 Dec 2023 13:57:49 +0800 Subject: [PATCH 4/7] [FIX] fix bugs --- learnware/specification/regular/image/rkme.py | 5 +++-- learnware/specification/regular/table/rkme.py | 7 ++++--- learnware/specification/system/hetero_table.py | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 448ae62..32809b7 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -391,11 +391,12 @@ class RKMEImageSpecification(RegularStatSpecification): self.beta = self.beta.to(self._device) self.z = self.z.to(self._device) - if self.type != self.__class__.__name__: + if self.type == self.__class__.__name__: + return True + else: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) - return True return False diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 78e4996..046a956 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -6,7 +6,7 @@ import json import codecs import scipy import numpy as np -from qpsolvers import solve_qp, Problem, solve_problem +from qpsolvers import Problem, solve_problem from collections import Counter from typing import Any, Union @@ -458,11 +458,12 @@ class RKMETableSpecification(RegularStatSpecification): if d in rkme_load.keys(): setattr(self, d, rkme_load[d]) - if self.type != self.__class__.__name__: + if self.type == self.__class__.__name__: + return True + else: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) - return True return False diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index 5bec254..2a8ab6b 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -137,11 +137,12 @@ class HeteroMapTableSpecification(SystemStatSpecification): if d in embedding_load.keys(): setattr(self, d, embedding_load[d]) - if self.type != self.__class__.__name__: + if self.type == self.__class__.__name__: + return True + else: logger.error( f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" ) - return True return False From 13bb47359ddb453616dc9921ba4b28cb115fd0ac Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 4 Dec 2023 14:10:18 +0800 Subject: [PATCH 5/7] [FIX] fix bugs in test_hetero --- tests/test_hetero_market/test_hetero.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index 41b4261..7b0740b 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -5,10 +5,10 @@ import copy import joblib import zipfile import numpy as np +import multiprocessing from sklearn.linear_model import Ridge from sklearn.datasets import make_regression from shutil import copyfile, rmtree -from multiprocessing import Pool from learnware.client import LearnwareClient from sklearn.metrics import mean_squared_error @@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase): dir_path = os.path.join(curr_root, "learnware_pool") # Execute multi-process checking using Pool - with Pool() as pool: + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool() as pool: results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) # Use an assert statement to ensure that all checks return True From b3c50926d42b59e82db1dc155c4b86b98b64cb92 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 4 Dec 2023 14:54:48 +0800 Subject: [PATCH 6/7] [MNT] add dist method in base stat_spec class --- learnware/specification/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 064173d..6b1c5f5 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import numpy as np from typing import Dict @@ -22,6 +24,9 @@ class BaseStatSpecification: def get_states(self): return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + def dist(self, stat_spec: BaseStatSpecification): + raise NotImplementedError("dist is not implemented") + def save(self, filepath: str): """Save the statistical specification into file in filepath From d10d87914ce68a3374b2d1a590ab6ede9ce2d413 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 4 Dec 2023 15:12:53 +0800 Subject: [PATCH 7/7] [FIX] fix bugs which assumes type of stat_spec is class_name --- learnware/learnware/utils.py | 5 ++--- learnware/specification/regular/image/rkme.py | 13 +++---------- learnware/specification/regular/table/rkme.py | 11 ++--------- learnware/specification/system/hetero_table.py | 11 ++--------- 4 files changed, 9 insertions(+), 31 deletions(-) diff --git a/learnware/learnware/utils.py b/learnware/learnware/utils.py index e677c58..ebc4243 100644 --- a/learnware/learnware/utils.py +++ b/learnware/learnware/utils.py @@ -44,7 +44,6 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: raise TypeError( f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}" ) - if stat_spec_inst.load(stat_spec["file_name"]) is False: - raise ValueError("Load statistic specification failed!") - + stat_spec_inst.load(stat_spec["file_name"]) + return stat_spec_inst diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 32809b7..84222f6 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -383,23 +383,16 @@ class RKMEImageSpecification(RegularStatSpecification): rkme_load = json.loads(obj_text) rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"], dtype="float32")) rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"], dtype="float64")) - + for d in self.get_states(): if d in rkme_load.keys(): + if d == "type" and rkme_load[d] != self.type: + raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!") setattr(self, d, rkme_load[d]) self.beta = self.beta.to(self._device) self.z = self.z.to(self._device) - if self.type == self.__class__.__name__: - return True - else: - logger.error( - f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" - ) - - return False - def _get_zca_matrix(X, reg_coef=0.1): X_flat = X.reshape(X.shape[0], -1) diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 046a956..abecf6f 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -456,17 +456,10 @@ class RKMETableSpecification(RegularStatSpecification): for d in self.get_states(): if d in rkme_load.keys(): + if d == "type" and rkme_load[d] != self.type: + raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!") setattr(self, d, rkme_load[d]) - if self.type == self.__class__.__name__: - return True - else: - logger.error( - f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" - ) - - return False - class RKMEStatSpecification(RKMETableSpecification): """nickname for RKMETableSpecification, for compatibility currently. diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index 2a8ab6b..4726f83 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -135,17 +135,10 @@ class HeteroMapTableSpecification(SystemStatSpecification): for d in self.get_states(): if d in embedding_load.keys(): + if d == "type" and embedding_load[d] != self.type: + raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!") setattr(self, d, embedding_load[d]) - if self.type == self.__class__.__name__: - return True - else: - logger.error( - f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" - ) - - return False - def save(self, filepath: str) -> bool: """Save the computed HeteroMapTableSpecification to a specified path in JSON format.