Browse Source

Merge pull request #130 from Learnware-LAMDA/fix_spec_type

[FIX] add spec type check
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
f51ec633b8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 48 additions and 31 deletions
  1. +2
    -3
      learnware/learnware/utils.py
  2. +9
    -2
      learnware/market/easy/checker.py
  3. +5
    -0
      learnware/specification/base.py
  4. +15
    -11
      learnware/specification/regular/image/rkme.py
  5. +9
    -8
      learnware/specification/regular/table/rkme.py
  6. +5
    -5
      learnware/specification/system/hetero_table.py
  7. +3
    -2
      tests/test_hetero_market/test_hetero.py

+ 2
- 3
learnware/learnware/utils.py View File

@@ -44,7 +44,6 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification:
raise TypeError(
f"Statistic specification must be type of BaseStatSpecification, not {BaseStatSpecification.__class__.__name__}"
)
if stat_spec_inst.load(stat_spec["file_name"]) is False:
raise ValueError("Load statistic specification failed!")

stat_spec_inst.load(stat_spec["file_name"])
return stat_spec_inst

+ 9
- 2
learnware/market/easy/checker.py View File

@@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker):
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
return self.INVALID_LEARNWARE, traceback.format_exc()
try:
learnware_model = learnware.get_model()
# Check input shape
learnware_model = learnware.get_model()
input_shape = learnware_model.input_shape

if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (
@@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker):
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check statistical specification
spec_type = parse_specification_type(learnware.get_specification().stat_spec)
if spec_type is None:
message = f"No valid specification is found in stat spec {spec_type}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check if statistical specification is computable in dist()
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
stat_spec.dist(stat_spec)

if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}"
@@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker):
logger.warning(message)
return self.INVALID_LEARNWARE, message
inputs = np.random.randn(10, *input_shape)

elif spec_type == "RKMETextSpecification":
inputs = EasyStatChecker._generate_random_text_list(10)

elif spec_type == "RKMEImageSpecification":
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}"
)
inputs = np.random.randint(0, 255, size=(10, *input_shape))

else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")



+ 5
- 0
learnware/specification/base.py View File

@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import numpy as np
from typing import Dict
@@ -22,6 +24,9 @@ class BaseStatSpecification:
def get_states(self):
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

def dist(self, stat_spec: BaseStatSpecification):
raise NotImplementedError("dist is not implemented")
def save(self, filepath: str):
"""Save the statistical specification into file in filepath



+ 15
- 11
learnware/specification/regular/image/rkme.py View File

@@ -1,7 +1,6 @@
from __future__ import annotations

import codecs
import copy
import functools
import json
import os
@@ -17,8 +16,11 @@ from tqdm import tqdm
from . import cnn_gp
from ..base import RegularStatSpecification
from ..table.rkme import rkme_solve_qp
from ....logger import get_module_logger
from ....utils import choose_device, allocate_cuda_idx

logger = get_module_logger("image_rkme")


class RKMEImageSpecification(RegularStatSpecification):
# INNER_PRODUCT_COUNT = 0
@@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification):
try:
from torchvision.transforms import Resize
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." )
raise ModuleNotFoundError(
f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually."
)

if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X)

@@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification):
with torch.no_grad():
x_features = self._generate_random_feature(X_train, random_models=random_models)
self._update_beta(x_features, nonnegative_beta, random_models=random_models)
try:
import torch_optimizer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.")
raise ModuleNotFoundError(
f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually."
)

optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)

for _ in tqdm(range(steps)) if verbose else range(steps):
@@ -377,18 +383,16 @@ class RKMEImageSpecification(RegularStatSpecification):
rkme_load = json.loads(obj_text)
rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"], dtype="float32"))
rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"], dtype="float64"))
for d in self.get_states():
if d in rkme_load.keys():
if d == "type" and rkme_load[d] != self.type:
raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!")
setattr(self, d, rkme_load[d])

self.beta = self.beta.to(self._device)
self.z = self.z.to(self._device)

return True
else:
return False


def _get_zca_matrix(X, reg_coef=0.1):
X_flat = X.reshape(X.shape[0], -1)


+ 9
- 8
learnware/specification/regular/table/rkme.py View File

@@ -6,7 +6,7 @@ import json
import codecs
import scipy
import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem
from qpsolvers import Problem, solve_problem
from collections import Counter
from typing import Any, Union

@@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification):
if isinstance(X, np.ndarray):
X = X.astype("float32")
X = torch.from_numpy(X)
X = X.to(self._device)
try:
from fast_pytorch_kmeans import KMeans
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." )
raise ModuleNotFoundError(
f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually."
)

kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0)
kmeans.fit(X)
self.z = kmeans.centroids.double()

@@ -454,10 +456,9 @@ class RKMETableSpecification(RegularStatSpecification):

for d in self.get_states():
if d in rkme_load.keys():
if d == "type" and rkme_load[d] != self.type:
raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!")
setattr(self, d, rkme_load[d])
return True
else:
return False


class RKMEStatSpecification(RKMETableSpecification):


+ 5
- 5
learnware/specification/system/hetero_table.py View File

@@ -1,7 +1,6 @@
from __future__ import annotations

import os
import copy
import json
import torch
import codecs
@@ -10,8 +9,11 @@ import numpy as np
from .base import SystemStatSpecification
from ..regular import RKMETableSpecification
from ..regular.table.rkme import torch_rbf_kernel
from ...logger import get_module_logger
from ...utils import choose_device, allocate_cuda_idx

logger = get_module_logger("hetero_map_table_spec")


class HeteroMapTableSpecification(SystemStatSpecification):
"""Heterogeneous Map-Table Specification"""
@@ -133,12 +135,10 @@ class HeteroMapTableSpecification(SystemStatSpecification):

for d in self.get_states():
if d in embedding_load.keys():
if d == "type" and embedding_load[d] != self.type:
raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!")
setattr(self, d, embedding_load[d])

return True
else:
return False

def save(self, filepath: str) -> bool:
"""Save the computed HeteroMapTableSpecification to a specified path in JSON format.



+ 3
- 2
tests/test_hetero_market/test_hetero.py View File

@@ -5,10 +5,10 @@ import copy
import joblib
import zipfile
import numpy as np
import multiprocessing
from sklearn.linear_model import Ridge
from sklearn.datasets import make_regression
from shutil import copyfile, rmtree
from multiprocessing import Pool
from learnware.client import LearnwareClient
from sklearn.metrics import mean_squared_error

@@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase):
dir_path = os.path.join(curr_root, "learnware_pool")

# Execute multi-process checking using Pool
with Pool() as pool:
mp_context = multiprocessing.get_context("spawn")
with mp_context.Pool() as pool:
results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)])

# Use an assert statement to ensure that all checks return True


Loading…
Cancel
Save