[FIX] add spec type checktags/v0.3.2
| @@ -44,7 +44,6 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: | |||
| raise TypeError( | |||
| f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}" | |||
| ) | |||
| if stat_spec_inst.load(stat_spec["file_name"]) is False: | |||
| raise ValueError("Load statistic specification failed!") | |||
| stat_spec_inst.load(stat_spec["file_name"]) | |||
| return stat_spec_inst | |||
| @@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | |||
| return self.INVALID_LEARNWARE, traceback.format_exc() | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # Check input shape | |||
| learnware_model = learnware.get_model() | |||
| input_shape = learnware_model.input_shape | |||
| if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | |||
| @@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check statistical specification | |||
| spec_type = parse_specification_type(learnware.get_specification().stat_spec) | |||
| if spec_type is None: | |||
| message = f"No valid specification is found in stat spec {spec_type}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check if statistical specification is computable in dist() | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| stat_spec.dist(stat_spec) | |||
| if spec_type == "RKMETableSpecification": | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| @@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| inputs = np.random.randn(10, *input_shape) | |||
| elif spec_type == "RKMETextSpecification": | |||
| inputs = EasyStatChecker._generate_random_text_list(10) | |||
| elif spec_type == "RKMEImageSpecification": | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| ) | |||
| inputs = np.random.randint(0, 255, size=(10, *input_shape)) | |||
| else: | |||
| raise ValueError(f"not supported spec type for spec_type = {spec_type}") | |||
| @@ -1,3 +1,5 @@ | |||
| from __future__ import annotations | |||
| import copy | |||
| import numpy as np | |||
| from typing import Dict | |||
| @@ -22,6 +24,9 @@ class BaseStatSpecification: | |||
| def get_states(self): | |||
| return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} | |||
| def dist(self, stat_spec: BaseStatSpecification): | |||
| raise NotImplementedError("dist is not implemented") | |||
| def save(self, filepath: str): | |||
| """Save the statistical specification into file in filepath | |||
| @@ -1,7 +1,6 @@ | |||
| from __future__ import annotations | |||
| import codecs | |||
| import copy | |||
| import functools | |||
| import json | |||
| import os | |||
| @@ -17,8 +16,11 @@ from tqdm import tqdm | |||
| from . import cnn_gp | |||
| from ..base import RegularStatSpecification | |||
| from ..table.rkme import rkme_solve_qp | |||
| from ....logger import get_module_logger | |||
| from ....utils import choose_device, allocate_cuda_idx | |||
| logger = get_module_logger("image_rkme") | |||
| class RKMEImageSpecification(RegularStatSpecification): | |||
| # INNER_PRODUCT_COUNT = 0 | |||
| @@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| try: | |||
| from torchvision.transforms import Resize | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." ) | |||
| raise ModuleNotFoundError( | |||
| f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." | |||
| ) | |||
| if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | |||
| X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X) | |||
| @@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| with torch.no_grad(): | |||
| x_features = self._generate_random_feature(X_train, random_models=random_models) | |||
| self._update_beta(x_features, nonnegative_beta, random_models=random_models) | |||
| try: | |||
| import torch_optimizer | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.") | |||
| raise ModuleNotFoundError( | |||
| f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually." | |||
| ) | |||
| optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) | |||
| for _ in tqdm(range(steps)) if verbose else range(steps): | |||
| @@ -377,18 +383,16 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| rkme_load = json.loads(obj_text) | |||
| rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"], dtype="float32")) | |||
| rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"], dtype="float64")) | |||
| for d in self.get_states(): | |||
| if d in rkme_load.keys(): | |||
| if d == "type" and rkme_load[d] != self.type: | |||
| raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!") | |||
| setattr(self, d, rkme_load[d]) | |||
| self.beta = self.beta.to(self._device) | |||
| self.z = self.z.to(self._device) | |||
| return True | |||
| else: | |||
| return False | |||
| def _get_zca_matrix(X, reg_coef=0.1): | |||
| X_flat = X.reshape(X.shape[0], -1) | |||
| @@ -6,7 +6,7 @@ import json | |||
| import codecs | |||
| import scipy | |||
| import numpy as np | |||
| from qpsolvers import solve_qp, Problem, solve_problem | |||
| from qpsolvers import Problem, solve_problem | |||
| from collections import Counter | |||
| from typing import Any, Union | |||
| @@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification): | |||
| if isinstance(X, np.ndarray): | |||
| X = X.astype("float32") | |||
| X = torch.from_numpy(X) | |||
| X = X.to(self._device) | |||
| try: | |||
| from fast_pytorch_kmeans import KMeans | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." ) | |||
| raise ModuleNotFoundError( | |||
| f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." | |||
| ) | |||
| kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) | |||
| kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0) | |||
| kmeans.fit(X) | |||
| self.z = kmeans.centroids.double() | |||
| @@ -454,10 +456,9 @@ class RKMETableSpecification(RegularStatSpecification): | |||
| for d in self.get_states(): | |||
| if d in rkme_load.keys(): | |||
| if d == "type" and rkme_load[d] != self.type: | |||
| raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!") | |||
| setattr(self, d, rkme_load[d]) | |||
| return True | |||
| else: | |||
| return False | |||
| class RKMEStatSpecification(RKMETableSpecification): | |||
| @@ -1,7 +1,6 @@ | |||
| from __future__ import annotations | |||
| import os | |||
| import copy | |||
| import json | |||
| import torch | |||
| import codecs | |||
| @@ -10,8 +9,11 @@ import numpy as np | |||
| from .base import SystemStatSpecification | |||
| from ..regular import RKMETableSpecification | |||
| from ..regular.table.rkme import torch_rbf_kernel | |||
| from ...logger import get_module_logger | |||
| from ...utils import choose_device, allocate_cuda_idx | |||
| logger = get_module_logger("hetero_map_table_spec") | |||
| class HeteroMapTableSpecification(SystemStatSpecification): | |||
| """Heterogeneous Map-Table Specification""" | |||
| @@ -133,12 +135,10 @@ class HeteroMapTableSpecification(SystemStatSpecification): | |||
| for d in self.get_states(): | |||
| if d in embedding_load.keys(): | |||
| if d == "type" and embedding_load[d] != self.type: | |||
| raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!") | |||
| setattr(self, d, embedding_load[d]) | |||
| return True | |||
| else: | |||
| return False | |||
| def save(self, filepath: str) -> bool: | |||
| """Save the computed HeteroMapTableSpecification to a specified path in JSON format. | |||
| @@ -5,10 +5,10 @@ import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| import multiprocessing | |||
| from sklearn.linear_model import Ridge | |||
| from sklearn.datasets import make_regression | |||
| from shutil import copyfile, rmtree | |||
| from multiprocessing import Pool | |||
| from learnware.client import LearnwareClient | |||
| from sklearn.metrics import mean_squared_error | |||
| @@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase): | |||
| dir_path = os.path.join(curr_root, "learnware_pool") | |||
| # Execute multi-process checking using Pool | |||
| with Pool() as pool: | |||
| mp_context = multiprocessing.get_context("spawn") | |||
| with mp_context.Pool() as pool: | |||
| results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) | |||
| # Use an assert statement to ensure that all checks return True | |||