From 002f22a8a2742aa7d8685cf45ff7cfa383dfc213 Mon Sep 17 00:00:00 2001 From: shihy Date: Wed, 25 Oct 2023 16:41:27 +0800 Subject: [PATCH 01/24] [ENH] Adding RKME Image Specification designed for image data, version 1 --- .../example_files/example_yaml.yaml | 4 +- examples/dataset_image_workflow/main.py | 48 +- learnware/specification/image.py | 457 ++++++++++++++++++ 3 files changed, 484 insertions(+), 25 deletions(-) create mode 100644 learnware/specification/image.py diff --git a/examples/dataset_image_workflow/example_files/example_yaml.yaml b/examples/dataset_image_workflow/example_files/example_yaml.yaml index 6ca01c9..2f2b4cd 100644 --- a/examples/dataset_image_workflow/example_files/example_yaml.yaml +++ b/examples/dataset_image_workflow/example_files/example_yaml.yaml @@ -2,7 +2,7 @@ model: class_name: Model kwargs: {} stat_specifications: - - module_path: learnware.specification - class_name: RKMEStatSpecification + - module_path: learnware.specification.image + class_name: RKMEImageStatSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 26e639a..1bb41ab 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -3,9 +3,11 @@ import torch from get_data import * import os import random + +from learnware.specification.image import RKMEImageStatSpecification +from learnware.reuse.averaging import AveragingReuser from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction from learnware.learnware import Learnware -from learnware.reuse import JobSelectorReuser, AveragingReuser import time from learnware.market import EasyMarket, BaseUserInfo @@ -23,8 +25,8 @@ processed_data_root = "./data/processed_data" tmp_dir = "./data/tmp" learnware_pool_dir = "./data/learnware_pool" dataset = "cifar10" -n_uploaders = 50 -n_users = 20 +n_uploaders = 3 +n_users = 3 n_classes = 10 data_root = os.path.join(origin_data_root, dataset) data_save_root = os.path.join(processed_data_root, dataset) @@ -45,6 +47,7 @@ semantic_specs = [ "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, + "Output": {"Dimension": 10} } ] @@ -88,9 +91,15 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo tmp_init_path = os.path.join(save_root, "__init__.py") tmp_model_file_path = os.path.join(save_root, "model.py") mmodel_file_path = "./example_files/model.py" + + # Computing the specification from the whole dataset is too costly. X = np.load(data_path) + indices = np.random.choice(len(X), size=2000, replace=False) + X_sampled = X[indices] + st = time.time() - user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + user_spec = RKMEImageStatSpecification(cuda_idx=0) + user_spec.generate_stat_spec_from_data(X=X_sampled) ed = time.time() logger.info("Stat spec generated in %.3f s" % (ed - st)) user_spec.save(tmp_spec_path) @@ -153,35 +162,33 @@ def test_search(gamma=0.1, load_market=True): user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) user_data = np.load(user_data_path) user_label = np.load(user_label_path) - user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0) + user_stat_spec = RKMEImageStatSpecification(cuda_idx=0) + user_stat_spec.generate_stat_spec_from_data(X=user_data, steps=5, resize=False) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % (i)) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( user_info ) - l = len(sorted_score_list) acc_list = [] - for idx in range(l): - learnware = single_learnware_list[idx] - score = sorted_score_list[idx] + for idx, (score, learnware) in enumerate(zip(sorted_score_list[:5], single_learnware_list[:5])): pred_y = learnware.predict(user_data) acc = eval_prediction(pred_y, user_label) acc_list.append(acc) logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) # test reuse (job selector) - reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) - reuse_predict = reuse_baseline.predict(user_data=user_data) - reuse_score = eval_prediction(reuse_predict, user_label) - job_selector_score_list.append(reuse_score) - print(f"mixture reuse loss: {reuse_score}") + # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) + # reuse_predict = reuse_baseline.predict(user_data=user_data) + # reuse_score = eval_prediction(reuse_predict, user_label) + # job_selector_score_list.append(reuse_score) + # print(f"mixture reuse loss: {reuse_score}") # test reuse (ensemble) - reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") + reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) ensemble_score = eval_prediction(ensemble_predict_y, user_label) ensemble_score_list.append(ensemble_score) - print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n") + print(f"reuse accuracy (vote_by_prob): {ensemble_score}\n") select_list.append(acc_list[0]) avg_list.append(np.mean(acc_list)) @@ -191,17 +198,12 @@ def test_search(gamma=0.1, load_market=True): "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f" % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) ) - logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) - logger.info( - "Average Job Selector Reuse Performance: %.3f +/- %.3f" - % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) - ) logger.info( "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) ) if __name__ == "__main__": - prepare_data() - prepare_model() + # prepare_data() + # prepare_model() test_search(load_market=False) diff --git a/learnware/specification/image.py b/learnware/specification/image.py new file mode 100644 index 0000000..55cc580 --- /dev/null +++ b/learnware/specification/image.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +import codecs +import copy +import functools +import json +import os + +from typing import Any, Union + +import numpy as np +import torch +import torch_optimizer +from torch import nn +from torch.func import jacrev, functional_call +from torch.utils.data import TensorDataset, DataLoader +from torchvision.transforms import Resize +from tqdm import tqdm + +from .base import BaseStatSpecification +from .rkme import solve_qp, choose_device, setup_seed + + +class RKMEImageStatSpecification(BaseStatSpecification): + inner_prod_buffer = dict() + INNER_PRODUCT_COUNT = 0 + IMAGE_WIDTH = 32 + + def __init__(self, cuda_idx: int = -1, buffering: bool=True, **kwargs): + """Initializing RKME Image specification's parameters. + + Parameters + ---------- + cuda_idx : int + A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. + buffering: bool + When buffering is True, the result of inner_prod will be buffered according to id(object), avoiding duplicate kernel function calculations, by default True. + """ + self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. + # torch.cuda.empty_cache() + + self.z = None + self.beta = None + self.cuda_idx = cuda_idx + self.device = choose_device(cuda_idx=cuda_idx) + self.buffering = buffering + + self.n_models = kwargs["n_models"] if "n_models" in kwargs else 16 + self.model_config = { + "k": 2, "mu": 0, "sigma": None, 'chopped_head': True, + "net_width": 128, "net_depth": 3, "net_act": "relu" + } if "model_config" not in kwargs else kwargs["model_config"] + + setup_seed(0) + + def _generate_models(self, n_models: int, channel: int=3, fixed_seed=None): + model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) + + def __builder(i): + if fixed_seed is not None: + torch.manual_seed(fixed_seed[i]) + return model_class().to(self.device) + + return (__builder(m) for m in range(n_models)) + + def generate_stat_spec_from_data( + self, + X: np.ndarray, + K: int = 50, + step_size: float = 0.01, + steps: int=100, + resize: bool = False, + nonnegative_beta: bool = True, + reduce: bool = True + ): + """Construct reduced set from raw dataset using iterative optimization. + + Parameters + ---------- + X : np.ndarray or torch.tensor + Raw data in np.ndarray format. + K : int + Size of the construced reduced set. + step_size : float + Step size for gradient descent in the iterative optimization. + steps : int + Total rounds in the iterative optimization. + resize : bool + Whether to scale the image to the requested size, by default True. + nonnegative_beta : bool, optional + True if weights for the reduced set are intended to be kept non-negative, by default False. + reduce : bool, optional + Whether shrink original data to a smaller set, by default True + + Returns + ------- + + """ + if (X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or + X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH) and not resize: + raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. " + "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}." + .format(RKMEImageStatSpecification.IMAGE_WIDTH)) + + if not torch.is_tensor(X): + X = torch.from_numpy(X) + X = X.to(self.device) + + X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan + if torch.any(torch.isnan(X)): + for i, img in enumerate(X): + is_nan = torch.isnan(img) + if torch.any(is_nan): + if torch.all(is_nan): + raise ValueError(f"All values in image {i} are exceptional, e.g., NaN and Inf.") + img_mean = torch.nanmean(img) + X[i] = torch.where(is_nan, img_mean, img) + + if (X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or + X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH): + X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, + RKMEImageStatSpecification.IMAGE_WIDTH))(X) + + num_points = X.shape[0] + X_shape = X.shape + Z_shape = tuple([K] + list(X_shape)[1:]) + + X_train = (X - torch.mean(X, [0, 2, 3], keepdim=True)) / (torch.std(X, [0, 2, 3], keepdim=True)) + if X_train.shape[1] > 1: + whitening = _get_zca_matrix(X_train) + X_train = X_train.reshape(num_points, -1) @ whitening + X_train = X_train.view(*X_shape) + + if not reduce: + self.beta = 1 / num_points * np.ones(num_points) + self.z = torch.to(self.device) + self.beta = torch.from_numpy(self.beta).to(self.device) + return + + random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) + self.z = torch.zeros(Z_shape).to(self.device).float().normal_(0, 1) + with torch.no_grad(): + x_features = self._generate_random_feature(X_train, random_models=random_models) + self._update_beta(x_features, nonnegative_beta, random_models=random_models) + + optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], + lr=step_size, eps=1e-16) + + for i in tqdm(range(steps), total=steps): + # Regenerate Random Models + random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) + + with torch.no_grad(): + x_features = self._generate_random_feature(X_train, random_models=random_models) + self._update_z(x_features, optimizer, random_models=random_models) + self._update_beta(x_features, nonnegative_beta, random_models=random_models) + + @torch.no_grad() + def _update_beta(self, x_features: Any, nonnegative_beta: bool = True, random_models=None): + Z = self.z + if not torch.is_tensor(Z): + Z = torch.from_numpy(Z) + Z = Z.to(self.device) + + if not torch.is_tensor(x_features): + x_features = torch.from_numpy(x_features) + x_features = x_features.to(self.device) + + z_features = self._generate_random_feature(Z, random_models=random_models) + K = self._calc_ntk_from_feature(z_features, z_features).to(self.device) + C = self._calc_ntk_from_feature(z_features, x_features).to(self.device) + C = torch.sum(C, dim=1) / x_features.shape[0] + + if nonnegative_beta: + beta = solve_qp(K.double(), C.double()).to(self.device) + else: + beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C + + self.beta = beta + + def _update_z(self, x_features: Any, optimizer, random_models=None): + Z = self.z + beta = self.beta + + if not torch.is_tensor(Z): + Z = torch.from_numpy(Z) + Z = Z.to(self.device).float() + + if not torch.is_tensor(beta): + beta = torch.from_numpy(beta) + beta = beta.to(self.device) + + if not torch.is_tensor(x_features): + x_features = torch.from_numpy(x_features) + x_features = x_features.to(self.device).float() + + with torch.no_grad(): + beta = beta.unsqueeze(0) + + for i in range(3): + z_features = self._generate_random_feature(Z, random_models=random_models) + K_z = self._calc_ntk_from_feature(z_features, z_features) + K_zx = self._calc_ntk_from_feature(x_features, z_features) + term_1 = torch.sum(K_z * (beta.T @ beta)) + term_2 = torch.sum(K_zx * beta / x_features.shape[0]) + loss = term_1 - 2 * term_2 + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def _generate_random_feature(self, data_X, batch_size=4096, random_models=None) -> torch.Tensor: + X_features_list = [] + if not torch.is_tensor(data_X): + data_X = torch.from_numpy(data_X) + data_X = data_X.to(self.device) + + dataset = TensorDataset(data_X) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + for m, model in enumerate(random_models if random_models else + self._generate_models(n_models=self.n_models, channel=data_X.shape[1])): + model.eval() + curr_features_list = [] + for i, (X,) in enumerate(dataloader): + out = model(X) + curr_features_list.append(out) + curr_features = torch.cat(curr_features_list, 0) + X_features_list.append(curr_features) + X_features = torch.cat(X_features_list, 1) + X_features = X_features / torch.sqrt(torch.asarray(X_features.shape[1], device=self.device)) + + return X_features + + def inner_prod(self, Phi2: RKMEImageStatSpecification) -> float: + """Compute the inner product between two RKME Image specifications + + Parameters + ---------- + Phi2 : RKMEImageStatSpecification + The other RKME Image specification. + + Returns + ------- + float + The inner product between two RKME Image specifications. + """ + + if self.buffering and Phi2.buffering: + if (id(self), id(Phi2)) in RKMEImageStatSpecification.inner_prod_buffer: + return RKMEImageStatSpecification.inner_prod_buffer[(id(self), id(Phi2))] + + v = self._inner_prod_ntk(Phi2) + if self.buffering and Phi2.buffering: + RKMEImageStatSpecification.inner_prod_buffer[(id(self), id(Phi2))] = v + RKMEImageStatSpecification.inner_prod_buffer[(id(Phi2), id(self))] = v + return v + + def _inner_prod_ntk(self, Phi2: RKMEImageStatSpecification) -> float: + beta_1 = self.beta.reshape(1, -1).detach() + beta_2 = Phi2.beta.reshape(1, -1).detach() + + Z1 = self.z.to(self.device) + Z2 = Phi2.z.to(self.device) + + # Use the old way + assert Z1.shape[1] == Z2.shape[1] + random_models = list(self._generate_models(n_models=self.n_models * 4, channel=Z1.shape[1])) + z1_features = self._generate_random_feature(Z1, random_models=random_models) + z2_features = self._generate_random_feature(Z2, random_models=random_models) + K_zz = self._calc_ntk_from_feature(z1_features, z2_features) + + v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() + + RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 + return v + + def dist(self, Phi2: RKMEImageStatSpecification, omit_term1: bool = False) -> float: + """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications + + Parameters + ---------- + Phi2 : RKMEImageStatSpecification + The other RKME specification. + omit_term1 : bool, optional + True if the inner product of self with itself can be omitted, by default False. + """ + + with torch.no_grad(): + if omit_term1: + term1 = 0 + else: + term1 = self.inner_prod(self) + term2 = self.inner_prod(Phi2) + term3 = Phi2.inner_prod(Phi2) + + return float(term1 - 2 * term2 + term3) + + @staticmethod + def _calc_ntk_from_feature(x1_feature: torch.Tensor, x2_feature: torch.Tensor): + K_12 = x1_feature @ x2_feature.T + 0.01 + return K_12 + + def _calc_ntk_empirical(self, x1: torch.Tensor, x2: torch.Tensor): + if x1.shape[1] != x2.shape[1]: + raise ValueError("The channel of two rkme image specification should be equal (e.g. 3 or 1).") + + results = [] + for m, model in enumerate(self._generate_models(n_models=self.n_models, channel=x1.shape[1])): + # Compute J(x1) + # jac1 = vamp(lambda x: jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x))(x1) + jac1 = jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x1) + jac1 = [j.flatten(2) for j in jac1] + + # Compute J(x2) + jac2 = functional_call(model, model.parameters(), x2) + jac2 = [j.flatten(2) for j in jac2] + + result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]) + results.append(result.sum(0)) + + results = torch.stack(results) + return results.mean(0) + + def herding(self, T: int) -> np.ndarray: + raise NotImplementedError( + "The function herding hasn't been supported in Image RKME Specification.") + + def _sampling_candidates(self, N: int) -> np.ndarray: + raise NotImplementedError() + + def get_beta(self) -> np.ndarray: + return self.beta.detach().cpu().numpy() + + def get_z(self) -> np.ndarray: + return self.z.detach().cpu().numpy() + + def save(self, filepath: str): + """Save the computed RKME Image specification to a specified path in JSON format. + + Parameters + ---------- + filepath : str + The specified saving path. + """ + save_path = filepath + rkme_to_save = copy.deepcopy(self.__dict__) + if torch.is_tensor(rkme_to_save["z"]): + rkme_to_save["z"] = rkme_to_save["z"].detach().cpu().numpy() + rkme_to_save["z"] = rkme_to_save["z"].tolist() + if torch.is_tensor(rkme_to_save["beta"]): + rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() + rkme_to_save["beta"] = rkme_to_save["beta"].tolist() + rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + + json.dump( + rkme_to_save, + codecs.open(save_path, "w", encoding="utf-8"), + separators=(",", ":"), + ) + + def load(self, filepath: str) -> bool: + """Load a RKME Image specification file in JSON format from the specified path. + + Parameters + ---------- + filepath : str + The specified loading path. + + Returns + ------- + bool + True if the RKME is loaded successfully. + """ + # Load JSON file: + load_path = filepath + if os.path.exists(load_path): + with codecs.open(load_path, "r", encoding="utf-8") as fin: + obj_text = fin.read() + rkme_load = json.loads(obj_text) + rkme_load["device"] = choose_device(rkme_load["cuda_idx"]) + rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"])).float() + rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"])) + + for d in self.__dir__(): + if d in rkme_load.keys(): + setattr(self, d, rkme_load[d]) + + self.beta = self.beta.to(self.device) + self.z = self.z.to(self.device) + + return True + else: + return False + + +def _get_zca_matrix(X, reg_coef=0.1): + X_flat = X.reshape(X.shape[0], -1) + cov = (X_flat.T @ X_flat) / X_flat.shape[0] + reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] + u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda()) + inv_sqrt_zca_eigs = s ** (-0.5) + whitening_transform = torch.einsum( + 'ij,j,kj->ik', u, inv_sqrt_zca_eigs, u) + + return whitening_transform + + +class _ConvNet_wide(nn.Module): + def __init__(self, channel, mu=None, sigma=None, k=4, net_width=128, net_depth=3, + net_act='relu', net_norm='none', net_pooling='avgpooling', im_size=(32, 32), chopped_head=False): + self.k = k + # print('Building Conv Model') + super().__init__() + + # net_depth = 1 + self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, + net_act, net_pooling, im_size, mu, sigma) + # print(shape_feat) + self.chopped_head = chopped_head + + def forward(self, x): + out = self.features(x) + # print(out.size()) + out = out.reshape(out.size(0), -1) + # print(out.size()) + return out + + def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size, mu, sigma): + k = self.k + + layers = [] + in_channels = channel + shape_feat = [in_channels, im_size[0], im_size[1]] + for d in range(net_depth): + layers += [build_conv2d_gaussian(in_channels, int(k * net_width), 3, + 1, mean=mu, std=sigma)] + shape_feat[0] = int(k * net_width) + + layers += [nn.ReLU(inplace=True)] + in_channels = int(k * net_width) + + layers += [nn.AvgPool2d(kernel_size=2, stride=2)] + shape_feat[1] //= 2 + shape_feat[2] //= 2 + + return nn.Sequential(*layers), shape_feat + +def build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean=None, std=None): + layer = nn.Conv2d(in_channels, out_channels, kernel, padding=padding) + if mean is None: + mean = 0 + if std is None: + std = np.sqrt(2)/np.sqrt(layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]) + # print('Initializing Conv. Mean=%.2f, std=%.2f'%(mean, std)) + torch.nn.init.normal_(layer.weight, mean, std) + torch.nn.init.normal_(layer.bias, 0, .1) + return layer \ No newline at end of file From 2bda163d5f544a6d287e3285c829dc913775a406 Mon Sep 17 00:00:00 2001 From: shihy Date: Wed, 25 Oct 2023 20:35:51 +0800 Subject: [PATCH 02/24] [ENH] Optimized graphics memory consumption for calculating dist. --- examples/dataset_image_workflow/main.py | 25 ++++--- learnware/specification/image.py | 88 ++++++++++++++----------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 1bb41ab..9495498 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -1,5 +1,7 @@ import numpy as np import torch +from tqdm import tqdm + from get_data import * import os import random @@ -25,8 +27,8 @@ processed_data_root = "./data/processed_data" tmp_dir = "./data/tmp" learnware_pool_dir = "./data/learnware_pool" dataset = "cifar10" -n_uploaders = 3 -n_users = 3 +n_uploaders = 30 +n_users = 20 n_classes = 10 data_root = os.path.join(origin_data_root, dataset) data_save_root = os.path.join(processed_data_root, dataset) @@ -126,7 +128,7 @@ def prepare_market(): except: pass os.makedirs(learnware_pool_dir, exist_ok=True) - for i in range(n_uploaders): + for i in tqdm(range(n_uploaders), total=n_uploaders, desc="Preparing..."): data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i)) model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) init_file_path = "./example_files/example_init.py" @@ -157,15 +159,15 @@ def test_search(gamma=0.1, load_market=True): improve_list = [] job_selector_score_list = [] ensemble_score_list = [] - for i in range(n_users): + for i in tqdm(range(n_users), total=n_users, desc="Searching..."): user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i)) user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) user_data = np.load(user_data_path) user_label = np.load(user_label_path) user_stat_spec = RKMEImageStatSpecification(cuda_idx=0) - user_stat_spec.generate_stat_spec_from_data(X=user_data, steps=5, resize=False) + user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) - logger.info("Searching Market for user: %d" % (i)) + logger.info("Searching Market for user: %d" % i) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( user_info ) @@ -174,7 +176,7 @@ def test_search(gamma=0.1, load_market=True): pred_y = learnware.predict(user_data) acc = eval_prediction(pred_y, user_label) acc_list.append(acc) - logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) + logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) # test reuse (job selector) # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) @@ -204,6 +206,11 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": - # prepare_data() - # prepare_model() + logger.info("=" * 40) + logger.info(f"n_uploaders:\t{n_uploaders}") + logger.info(f"n_users:\t{n_users}") + logger.info("=" * 40) + + prepare_data() + prepare_model() test_search(load_market=False) diff --git a/learnware/specification/image.py b/learnware/specification/image.py index 55cc580..88cd116 100644 --- a/learnware/specification/image.py +++ b/learnware/specification/image.py @@ -15,7 +15,6 @@ from torch import nn from torch.func import jacrev, functional_call from torch.utils.data import TensorDataset, DataLoader from torchvision.transforms import Resize -from tqdm import tqdm from .base import BaseStatSpecification from .rkme import solve_qp, choose_device, setup_seed @@ -146,7 +145,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) - for i in tqdm(range(steps), total=steps): + for i in range(steps): # Regenerate Random Models random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) @@ -209,27 +208,43 @@ class RKMEImageStatSpecification(BaseStatSpecification): loss.backward() optimizer.step() - def _generate_random_feature(self, data_X, batch_size=4096, random_models=None) -> torch.Tensor: - X_features_list = [] - if not torch.is_tensor(data_X): - data_X = torch.from_numpy(data_X) - data_X = data_X.to(self.device) + def _generate_random_feature(self, data_X, data_Y=None, batch_size=4096, random_models=None): + X_features_list, Y_features_list = [], [] + + dataset_X, dataset_Y = TensorDataset(data_X), None + dataloader_X, dataloader_Y = DataLoader(dataset_X, batch_size=batch_size, shuffle=True), None + if data_Y is not None: + dataset_Y = TensorDataset(data_Y) + dataloader_Y = DataLoader(dataset_Y, batch_size=batch_size, shuffle=True) + assert data_X.shape[1] == data_Y.shape[1] - dataset = TensorDataset(data_X) - dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) for m, model in enumerate(random_models if random_models else self._generate_models(n_models=self.n_models, channel=data_X.shape[1])): model.eval() + curr_features_list = [] - for i, (X,) in enumerate(dataloader): + for i, (X,) in enumerate(dataloader_X): out = model(X) curr_features_list.append(out) curr_features = torch.cat(curr_features_list, 0) X_features_list.append(curr_features) + + if data_Y is not None: + curr_features_list = [] + for i, (Y,) in enumerate(dataloader_Y): + out = model(Y) + curr_features_list.append(out) + curr_features = torch.cat(curr_features_list, 0) + Y_features_list.append(curr_features) + X_features = torch.cat(X_features_list, 1) X_features = X_features / torch.sqrt(torch.asarray(X_features.shape[1], device=self.device)) - - return X_features + if data_Y is None: + return X_features + else: + Y_features = torch.cat(Y_features_list, 1) + Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) + return X_features, Y_features def inner_prod(self, Phi2: RKMEImageStatSpecification) -> float: """Compute the inner product between two RKME Image specifications @@ -264,9 +279,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): # Use the old way assert Z1.shape[1] == Z2.shape[1] - random_models = list(self._generate_models(n_models=self.n_models * 4, channel=Z1.shape[1])) - z1_features = self._generate_random_feature(Z1, random_models=random_models) - z2_features = self._generate_random_feature(Z2, random_models=random_models) + random_models = self._generate_models(n_models=self.n_models * 4, channel=Z1.shape[1]) + z1_features, z2_features = self._generate_random_feature(data_X=Z1, data_Y=Z2, random_models=random_models) K_zz = self._calc_ntk_from_feature(z1_features, z2_features) v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() @@ -300,26 +314,26 @@ class RKMEImageStatSpecification(BaseStatSpecification): K_12 = x1_feature @ x2_feature.T + 0.01 return K_12 - def _calc_ntk_empirical(self, x1: torch.Tensor, x2: torch.Tensor): - if x1.shape[1] != x2.shape[1]: - raise ValueError("The channel of two rkme image specification should be equal (e.g. 3 or 1).") - - results = [] - for m, model in enumerate(self._generate_models(n_models=self.n_models, channel=x1.shape[1])): - # Compute J(x1) - # jac1 = vamp(lambda x: jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x))(x1) - jac1 = jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x1) - jac1 = [j.flatten(2) for j in jac1] - - # Compute J(x2) - jac2 = functional_call(model, model.parameters(), x2) - jac2 = [j.flatten(2) for j in jac2] - - result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]) - results.append(result.sum(0)) - - results = torch.stack(results) - return results.mean(0) + # def _calc_ntk_empirical(self, x1: torch.Tensor, x2: torch.Tensor): + # if x1.shape[1] != x2.shape[1]: + # raise ValueError("The channel of two rkme image specification should be equal (e.g. 3 or 1).") + # + # results = [] + # for m, model in enumerate(self._generate_models(n_models=self.n_models, channel=x1.shape[1])): + # # Compute J(x1) + # # jac1 = vamp(lambda x: jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x))(x1) + # jac1 = jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x1) + # jac1 = [j.flatten(2) for j in jac1] + # + # # Compute J(x2) + # jac2 = functional_call(model, model.parameters(), x2) + # jac2 = [j.flatten(2) for j in jac2] + # + # result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]) + # results.append(result.sum(0)) + # + # results = torch.stack(results) + # return results.mean(0) def herding(self, T: int) -> np.ndarray: raise NotImplementedError( @@ -432,7 +446,7 @@ class _ConvNet_wide(nn.Module): in_channels = channel shape_feat = [in_channels, im_size[0], im_size[1]] for d in range(net_depth): - layers += [build_conv2d_gaussian(in_channels, int(k * net_width), 3, + layers += [_build_conv2d_gaussian(in_channels, int(k * net_width), 3, 1, mean=mu, std=sigma)] shape_feat[0] = int(k * net_width) @@ -445,7 +459,7 @@ class _ConvNet_wide(nn.Module): return nn.Sequential(*layers), shape_feat -def build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean=None, std=None): +def _build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean=None, std=None): layer = nn.Conv2d(in_channels, out_channels, kernel, padding=padding) if mean is None: mean = 0 From bccc180d293073c3c4f263c4830a194f40ef0c1b Mon Sep 17 00:00:00 2001 From: shihy Date: Sat, 28 Oct 2023 19:20:10 +0800 Subject: [PATCH 03/24] [ENH] Introducing packages for nngp calculations --- learnware/specification/cnn_gp/__init__.py | 6 + learnware/specification/cnn_gp/data.py | 196 ++++++++++++ .../specification/cnn_gp/kernel_patch.py | 89 ++++++ .../specification/cnn_gp/kernel_save_tools.py | 58 ++++ learnware/specification/cnn_gp/kernels.py | 295 ++++++++++++++++++ 5 files changed, 644 insertions(+) create mode 100644 learnware/specification/cnn_gp/__init__.py create mode 100644 learnware/specification/cnn_gp/data.py create mode 100644 learnware/specification/cnn_gp/kernel_patch.py create mode 100644 learnware/specification/cnn_gp/kernel_save_tools.py create mode 100644 learnware/specification/cnn_gp/kernels.py diff --git a/learnware/specification/cnn_gp/__init__.py b/learnware/specification/cnn_gp/__init__.py new file mode 100644 index 0000000..044ef69 --- /dev/null +++ b/learnware/specification/cnn_gp/__init__.py @@ -0,0 +1,6 @@ +from . import kernels, data, kernel_save_tools +from .kernels import * +from .data import * +from .kernel_save_tools import * + +__all__ = kernels.__all__ + data.__all__ + kernel_save_tools.__all__ diff --git a/learnware/specification/cnn_gp/data.py b/learnware/specification/cnn_gp/data.py new file mode 100644 index 0000000..1f8b446 --- /dev/null +++ b/learnware/specification/cnn_gp/data.py @@ -0,0 +1,196 @@ +import torchvision +from torch.utils.data import ConcatDataset, DataLoader, Subset +import os +import numpy as np +import itertools + +__all__ = ('DatasetFromConfig', 'ProductIterator', 'DiagIterator', + 'print_timings') + + +def _this_worker_batch(N_batches, worker_rank, n_workers): + batches_per_worker = np.zeros([n_workers], dtype=np.int) + batches_per_worker[:] = N_batches // n_workers + batches_per_worker[:N_batches % n_workers] += 1 + + start_batch = np.sum(batches_per_worker[:worker_rank]) + batches_this_worker = batches_per_worker[worker_rank] + + return int(start_batch), int(batches_this_worker) + + +def _product_generator(N_batches_X, N_batches_X2, same): + for i in range(N_batches_X): + if same: + # Yield only upper triangle + yield (True, i, i) + for j in range(i+1 if same else 0, + N_batches_X2): + yield (False, i, j) + + +def _round_up_div(a, b): + return (a+b-1)//b + + +class ProductIterator(object): + """ + Returns an iterator for loading data from both X and X2. It divides the + load equally among `n_workers`, returning only the one that belongs to + `worker_rank`. + """ + def __init__(self, batch_size, X, X2=None, worker_rank=0, n_workers=1): + N_batches_X = _round_up_div(len(X), batch_size) + if X2 is None: + same = True + X2 = X + N_batches_X2 = N_batches_X + N_batches = max(1, N_batches_X * (N_batches_X+1) // 2) + else: + same = False + N_batches_X2 = _round_up_div(len(X2), batch_size) + N_batches = N_batches_X * N_batches_X2 + + start_batch, self.batches_this_worker = _this_worker_batch( + N_batches, worker_rank, n_workers) + + self.idx_iter = itertools.islice( + _product_generator(N_batches_X, N_batches_X2, same), + start_batch, + start_batch + self.batches_this_worker) + + self.worker_rank = worker_rank + self.prev_j = -2 # this + 1 = -1, which is not a valid j + self.X_loader = None + self.X2_loader = None + self.x_batch = None + self.X = X + self.X2 = X2 + self.same = same + self.batch_size = batch_size + + def __len__(self): + return self.batches_this_worker + + def __iter__(self): + return self + + def dataloader_beginning_at(self, i, dataset): + return iter(DataLoader( + Subset(dataset, range(i*self.batch_size, len(dataset))), + batch_size=self.batch_size)) + + def __next__(self): + same, i, j = next(self.idx_iter) + + if self.X_loader is None: + self.X_loader = self.dataloader_beginning_at(i, self.X) + + if j != self.prev_j+1: + self.X2_loader = self.dataloader_beginning_at(j, self.X2) + self.x_batch = next(self.X_loader) + self.prev_j = j + + return (same, + (i*self.batch_size, self.x_batch), + (j*self.batch_size, next(self.X2_loader))) + + +class DiagIterator(object): + def __init__(self, batch_size, X, X2=None): + self.batch_size = batch_size + dl = DataLoader(X, batch_size=batch_size) + if X2 is None: + self.same = True + self.it = iter(enumerate(dl)) + self.length = len(dl) + else: + dl2 = DataLoader(X2, batch_size=batch_size) + self.same = False + self.it = iter(enumerate(zip(dl, dl2))) + self.length = min(len(dl), len(dl2)) + + def __iter__(self): + return self + + def __len__(self): + return self.length + + def __next__(self): + if self.same: + i, xy = next(self.it) + xy2 = xy + else: + i, xy, xy2 = next(self.it) + ib = i*self.batch_size + return (self.same, (ib, xy), (ib, xy2)) + + +class DatasetFromConfig(object): + """ + A dataset that contains train, validation and test, and is created from a + config file. + """ + def __init__(self, datasets_path, config): + """ + Requires: + config.dataset_name (e.g. "MNIST") + config.train_range + config.test_range + """ + self.config = config + + trans = torchvision.transforms.ToTensor() + if len(config.transforms) > 0: + trans = torchvision.transforms.Compose([trans] + config.transforms) + + # Full datasets + datasets_path = os.path.join(datasets_path, config.dataset_name) + train_full = config.dataset(datasets_path, train=True, download=True, + transform=trans) + test_full = config.dataset(datasets_path, train=False, transform=trans) + self.data_full = ConcatDataset([train_full, test_full]) + + # Our training/test split + # (could omit some data, or include validation in test) + self.train = Subset(self.data_full, config.train_range) + self.validation = Subset(self.data_full, config.validation_range) + self.test = Subset(self.data_full, config.test_range) + + @staticmethod + def load_full(dataset): + return next(iter(DataLoader(dataset, batch_size=len(dataset)))) + + +def _hhmmss(s): + m, s = divmod(int(s), 60) + h, m = divmod(m, 60) + if h == 0.0: + return f"{m:02d}:{s:02d}" + else: + return f"{h:02d}:{m:02d}:{s:02d}" + + +def print_timings(iterator, desc="time", print_interval=2.): + """ + Prints the current total number of iterations, speed of iteration, and + elapsed time. + + Meant as a rudimentary replacement for `tqdm` that prints a new line at + each iteration, and thus can be used in multiple parallel processes in the + same terminal. + """ + import time + start_time = time.perf_counter() + total = len(iterator) + last_printed = -print_interval + for i, value in enumerate(iterator): + yield value + cur_time = time.perf_counter() + elapsed = cur_time - start_time + it_s = (i+1)/elapsed + total_s = total/it_s + if elapsed > last_printed + print_interval: + print(f"{desc}: {i+1}/{total} it, {it_s:.02f} it/s," + f"[{_hhmmss(elapsed)}<{_hhmmss(total_s)}]") + last_printed = elapsed diff --git a/learnware/specification/cnn_gp/kernel_patch.py b/learnware/specification/cnn_gp/kernel_patch.py new file mode 100644 index 0000000..53fbce3 --- /dev/null +++ b/learnware/specification/cnn_gp/kernel_patch.py @@ -0,0 +1,89 @@ +__all__ = ('ConvKP', 'NonlinKP') + + +class KernelPatch: + """ + Represents a block of the kernel matrix. + Critically, we need the variances of the rows and columns, even if the + diagonal isn't part of the block, and this introduces considerable + complexity. + In particular, we also need to know whether the + rows and columns of the matrix correspond, in which case, we need to do + something different when we add IID noise. + """ + def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None): + if isinstance(same_or_kp, KernelPatch): + same = same_or_kp.same + diag = same_or_kp.diag + xy = same_or_kp.xy + xx = same_or_kp.xx + yy = same_or_kp.yy + else: + same = same_or_kp + + self.Nx = xx.size(0) + self.Ny = yy.size(0) + self.W = xy.size(-2) + self.H = xy.size(-1) + + self.init(same, diag, xy, xx, yy) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __add__(self, other): + return self._do_elementwise(other, '__add__') + + def __mul__(self, other): + return self._do_elementwise(other, '__mul__') + + def _do_elementwise(self, other, op): + KP = type(self) + if isinstance(other, KernelPatch): + other = KP(other) + assert self.same == other.same + assert self.diag == other.diag + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other.xy), + getattr(self.xx, op)(other.xx), + getattr(self.yy, op)(other.yy) + ) + else: + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other), + getattr(self.xx, op)(other), + getattr(self.yy, op)(other) + ) + + +class ConvKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx*self.Ny, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + + +class NonlinKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx, self.Ny, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view( self.Ny, self.W, self.H) diff --git a/learnware/specification/cnn_gp/kernel_save_tools.py b/learnware/specification/cnn_gp/kernel_save_tools.py new file mode 100644 index 0000000..b0952af --- /dev/null +++ b/learnware/specification/cnn_gp/kernel_save_tools.py @@ -0,0 +1,58 @@ +import numpy as np +from .data import ProductIterator, DiagIterator, print_timings + +__all__ = ('create_h5py_dataset', 'save_K') + + +def create_h5py_dataset(f, batch_size, name, diag, N, N2): + """ + Creates a dataset named `name` on `f`, with chunks of size `batch_size`. + The chunks have leading dimension 1, so as to accommodate future resizing + of the leading dimension of the dataset (which starts at 1). + """ + if diag: + chunk_shape = (1, min(batch_size, N)) + shape = (1, N) + maxshape = (None, N) + else: + chunk_shape = (1, min(batch_size, N), min(batch_size, N2)) + shape = (1, N, N2) + maxshape = (None, N, N2) + return f.create_dataset(name, shape=shape, dtype=np.float32, + fillvalue=np.nan, chunks=chunk_shape, + maxshape=maxshape) + + +def save_K(f, kern, name, X, X2, diag, batch_size, worker_rank=0, n_workers=1, + print_interval=2.): + """ + Saves a kernel to the h5py file `f`. Creates its dataset with name `name` + if necessary. + """ + if name in f.keys(): + print("Skipping {} (group exists)".format(name)) + return + else: + N = len(X) + N2 = N if X2 is None else len(X2) + out = create_h5py_dataset(f, batch_size, name, diag, N, N2) + + if diag: + # Don't split the load for diagonals, they are cheap + it = DiagIterator(batch_size, X, X2) + else: + it = ProductIterator(batch_size, X, X2, worker_rank=worker_rank, + n_workers=n_workers) + it = print_timings(it, desc=f"{name} (worker {worker_rank}/{n_workers})", + print_interval=print_interval) + + for same, (i, (x, _y)), (j, (x2, _y2)) in it: + k = kern(x, x2, same, diag) + if np.any(np.isinf(k)) or np.any(np.isnan(k)): + print(f"About to write a nan or inf for {i},{j}") + import ipdb; ipdb.set_trace() + + if diag: + out[0, i:i+len(x)] = k + else: + out[0, i:i+len(x), j:j+len(x2)] = k diff --git a/learnware/specification/cnn_gp/kernels.py b/learnware/specification/cnn_gp/kernels.py new file mode 100644 index 0000000..0e7af1c --- /dev/null +++ b/learnware/specification/cnn_gp/kernels.py @@ -0,0 +1,295 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .kernel_patch import ConvKP, NonlinKP +import math + + +__all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "Mixture", + "MixtureModule", "Sum", "SumModule", "resnet_block") + +class NNGPKernel(nn.Module): + """ + Transforms one kernel matrix into another. + [N1, N2, W, H] -> [N1, N2, W, H] + """ + def forward(self, x, y=None, same=None, diag=False): + """ + Either takes one minibatch (x), or takes two minibatches (x and y), and + a boolean indicating whether they're the same. + """ + if y is None: + assert same is None + y = x + same = True + + assert not diag or len(x) == len(y), ( + "diagonal kernels must operate with data of equal length") + + assert 4==len(x.size()) + assert 4==len(y.size()) + assert x.size(1) == y.size(1) + assert x.size(2) == y.size(2) + assert x.size(3) == y.size(3) + + N1 = x.size(0) + N2 = y.size(0) + C = x.size(1) + W = x.size(2) + H = x.size(3) + + # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H] + if diag: + xy = (x*y).mean(1, keepdim=True) + else: + xy = (x.unsqueeze(1)*y).mean(2).view(N1*N2, 1, W, H) + xx = (x**2).mean(1, keepdim=True) + yy = (y**2).mean(1, keepdim=True) + + initial_kp = ConvKP(same, diag, xy, xx, yy) + final_kp = self.propagate(initial_kp) + r = NonlinKP(final_kp).xy + if diag: + return r.view(N1) + else: + return r.view(N1, N2) + + +class Conv2d(NNGPKernel): + def __init__(self, kernel_size, stride=1, padding="same", dilation=1, + var_weight=1., var_bias=0., in_channel_multiplier=1, + out_channel_multiplier=1): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.var_weight = var_weight + self.var_bias = var_bias + self.kernel_has_row_of_zeros = False + if padding == "same": + self.padding = dilation*(kernel_size//2) + if kernel_size % 2 == 0: + self.kernel_has_row_of_zeros = True + else: + self.padding = padding + + if self.kernel_has_row_of_zeros: + # We need to pad one side larger than the other. We just make a + # kernel that is slightly too large and make its last column and + # row zeros. + kernel = t.ones(1, 1, self.kernel_size+1, self.kernel_size+1) + kernel[:, :, 0, :] = 0. + kernel[:, :, :, 0] = 0. + else: + kernel = t.ones(1, 1, self.kernel_size, self.kernel_size) + self.register_buffer('kernel', kernel + * (self.var_weight / self.kernel_size**2)) + self.in_channel_multiplier, self.out_channel_multiplier = ( + in_channel_multiplier, out_channel_multiplier) + + def propagate(self, kp): + kp = ConvKP(kp) + def f(patch): + return (F.conv2d(patch, self.kernel, stride=self.stride, + padding=self.padding, dilation=self.dilation) + + self.var_bias) + return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy)) + + def nn(self, channels, in_channels=None, out_channels=None): + if in_channels is None: + in_channels = channels + if out_channels is None: + out_channels = channels + conv2d = nn.Conv2d( + in_channels=in_channels * self.in_channel_multiplier, + out_channels=out_channels * self.out_channel_multiplier, + kernel_size=self.kernel_size + ( + 1 if self.kernel_has_row_of_zeros else 0), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=(self.var_bias > 0.), + ) + conv2d.weight.data.normal_(0, math.sqrt( + self.var_weight / conv2d.in_channels) / self.kernel_size) + if self.kernel_has_row_of_zeros: + conv2d.weight.data[:, :, 0, :] = 0 + conv2d.weight.data[:, :, :, 0] = 0 + if self.var_bias > 0.: + conv2d.bias.data.normal_(0, math.sqrt(self.var_bias)) + return conv2d + + def layers(self): + return 1 + + +class ReLU(NNGPKernel): + """ + A ReLU nonlinearity, the covariance is numerically stabilised by clamping + values. + """ + f32_tiny = np.finfo(np.float32).tiny + def propagate(self, kp): + kp = NonlinKP(kp) + """ + We need to calculate (xy, xx, yy == c, v₁, v₂): + ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤ + √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂) + + which is equivalent to: + 1/2π ( √(v₁v₂ - c²) + (π - θ)c ) + + # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2) + """ + xx_yy = kp.xx * kp.yy + self.f32_tiny + + # Clamp these so the outputs are not NaN + cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1) + sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0)) + theta = t.acos(cos_theta) + xy = (sin_theta + (math.pi - theta)*kp.xy) / (2*math.pi) + + xx = kp.xx/2. + if kp.same: + yy = xx + if kp.diag: + xy = xx + else: + # Make sure the diagonal agrees with `xx` + eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device) + xy = (1-eye)*xy + eye*xx + else: + yy = kp.yy/2. + return NonlinKP(kp.same, kp.diag, xy, xx, yy) + + def nn(self, channels, in_channels=None, out_channels=None): + assert in_channels is None + assert out_channels is None + return nn.ReLU() + + def layers(self): + return 0 + + +#### Combination classes + +class Sequential(NNGPKernel): + def __init__(self, *mods): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + def propagate(self, kp): + for mod in self.mods: + kp = mod.propagate(kp) + return kp + def nn(self, channels, in_channels=None, out_channels=None): + if len(self.mods) == 0: + return nn.Sequential() + elif len(self.mods) == 1: + return self.mods[0].nn(channels, in_channels=in_channels, out_channels=out_channels) + else: + return nn.Sequential( + self.mods[0].nn(channels, in_channels=in_channels), + *[mod.nn(channels) for mod in self.mods[1:-1]], + self.mods[-1].nn(channels, out_channels=out_channels) + ) + def layers(self): + return sum(mod.layers() for mod in self.mods) + + +class Mixture(NNGPKernel): + """ + Applys multiple modules to the input, and sums the result + (e.g. for the implementation of a ResNet). + + Parameterised by proportion of each module (proportions add + up to one, such that, if each model has average variance 1, + then the output will also have average variance 1. + """ + def __init__(self, mods, logit_proportions=None): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + if logit_proportions is None: + logit_proportions = t.zeros(len(mods)) + self.logit = nn.Parameter(logit_proportions) + def propagate(self, kp): + proportions = F.softmax(self.logit, dim=0) + total = self.mods[0].propagate(kp) * proportions[0] + for i in range(1, len(self.mods)): + total = total + (self.mods[i].propagate(kp) * proportions[i]) + return total + def nn(self, channels, in_channels=None, out_channels=None): + return MixtureModule([mod.nn(channels, in_channels=in_channels, out_channels=out_channels) for mod in self.mods], self.logit) + def layers(self): + return max(mod.layers() for mod in self.mods) + +class MixtureModule(nn.Module): + def __init__(self, mods, logit_parameter): + super().__init__() + self.mods = mods + self.logit = t.tensor(logit_parameter) + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + def forward(self, input): + sqrt_proportions = F.softmax(self.logit, dim=0).sqrt() + total = self.mods[0](input)*sqrt_proportions[0] + for i in range(1, len(self.mods)): + total = total + self.mods[i](input) # *sqrt_proportions[i] + return total + + +class Sum(NNGPKernel): + def __init__(self, mods): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + def propagate(self, kp): + # This adds 0 to the first kp, hopefully that's a noop + return sum(m.propagate(kp) for m in self.mods) + def nn(self, channels, in_channels=None, out_channels=None): + return SumModule([ + mod.nn(channels, in_channels=in_channels, out_channels=out_channels) + for mod in self.mods]) + def layers(self): + return max(mod.layers() for mod in self.mods) + + +class SumModule(nn.Module): + def __init__(self, mods): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + def forward(self, input): + # This adds 0 to the first value, hopefully that's a noop + return sum(m(input) for m in self.mods) + + +def resnet_block(stride=1, projection_shortcut=False, multiplier=1): + if stride == 1 and not projection_shortcut: + return Sum([ + Sequential(), + Sequential( + ReLU(), + Conv2d(3, stride=stride, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), + ReLU(), + Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), + ) + ]) + else: + return Sequential( + ReLU(), + Sum([ + Conv2d(1, stride=stride, in_channel_multiplier=multiplier//stride, out_channel_multiplier=multiplier), + Sequential( + Conv2d(3, stride=stride, in_channel_multiplier=multiplier//stride, out_channel_multiplier=multiplier), + ReLU(), + Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), + ) + ]), + ) From 27d9afce65df912be6d9341cbd2acb0f6b298dab Mon Sep 17 00:00:00 2001 From: shihy Date: Sat, 28 Oct 2023 19:21:55 +0800 Subject: [PATCH 04/24] [ENH] Supports NNGP inner product computation --- learnware/specification/image.py | 119 +++++++++++++------------------ 1 file changed, 50 insertions(+), 69 deletions(-) diff --git a/learnware/specification/image.py b/learnware/specification/image.py index 88cd116..4dd4647 100644 --- a/learnware/specification/image.py +++ b/learnware/specification/image.py @@ -16,24 +16,22 @@ from torch.func import jacrev, functional_call from torch.utils.data import TensorDataset, DataLoader from torchvision.transforms import Resize +from . import cnn_gp from .base import BaseStatSpecification from .rkme import solve_qp, choose_device, setup_seed class RKMEImageStatSpecification(BaseStatSpecification): - inner_prod_buffer = dict() INNER_PRODUCT_COUNT = 0 IMAGE_WIDTH = 32 - def __init__(self, cuda_idx: int = -1, buffering: bool=True, **kwargs): + def __init__(self, cuda_idx: int = -1, **kwargs): """Initializing RKME Image specification's parameters. Parameters ---------- cuda_idx : int A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. - buffering: bool - When buffering is True, the result of inner_prod will be buffered according to id(object), avoiding duplicate kernel function calculations, by default True. """ self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. # torch.cuda.empty_cache() @@ -42,12 +40,11 @@ class RKMEImageStatSpecification(BaseStatSpecification): self.beta = None self.cuda_idx = cuda_idx self.device = choose_device(cuda_idx=cuda_idx) - self.buffering = buffering + self.cache = False self.n_models = kwargs["n_models"] if "n_models" in kwargs else 16 self.model_config = { - "k": 2, "mu": 0, "sigma": None, 'chopped_head': True, - "net_width": 128, "net_depth": 3, "net_act": "relu" + "k": 2, "mu": 0, "sigma": None, "net_width": 128, "net_depth": 3 } if "model_config" not in kwargs else kwargs["model_config"] setup_seed(0) @@ -70,7 +67,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): steps: int=100, resize: bool = False, nonnegative_beta: bool = True, - reduce: bool = True + reduce: bool = True, + **kwargs ): """Construct reduced set from raw dataset using iterative optimization. @@ -125,7 +123,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): Z_shape = tuple([K] + list(X_shape)[1:]) X_train = (X - torch.mean(X, [0, 2, 3], keepdim=True)) / (torch.std(X, [0, 2, 3], keepdim=True)) - if X_train.shape[1] > 1: + + if X_train.shape[1] > 1 and ("whitening" not in kwargs or kwargs["whitening"]): whitening = _get_zca_matrix(X_train) X_train = X_train.reshape(num_points, -1) @ whitening X_train = X_train.view(*X_shape) @@ -259,30 +258,21 @@ class RKMEImageStatSpecification(BaseStatSpecification): float The inner product between two RKME Image specifications. """ - - if self.buffering and Phi2.buffering: - if (id(self), id(Phi2)) in RKMEImageStatSpecification.inner_prod_buffer: - return RKMEImageStatSpecification.inner_prod_buffer[(id(self), id(Phi2))] - v = self._inner_prod_ntk(Phi2) - if self.buffering and Phi2.buffering: - RKMEImageStatSpecification.inner_prod_buffer[(id(self), id(Phi2))] = v - RKMEImageStatSpecification.inner_prod_buffer[(id(Phi2), id(self))] = v return v def _inner_prod_ntk(self, Phi2: RKMEImageStatSpecification) -> float: - beta_1 = self.beta.reshape(1, -1).detach() - beta_2 = Phi2.beta.reshape(1, -1).detach() + beta_1 = self.beta.reshape(1, -1).detach().to(self.device) + beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) Z1 = self.z.to(self.device) Z2 = Phi2.z.to(self.device) - # Use the old way - assert Z1.shape[1] == Z2.shape[1] - random_models = self._generate_models(n_models=self.n_models * 4, channel=Z1.shape[1]) - z1_features, z2_features = self._generate_random_feature(data_X=Z1, data_Y=Z2, random_models=random_models) - K_zz = self._calc_ntk_from_feature(z1_features, z2_features) - + kernel_fn = _build_ConvNet_NNGP(channel=Z1.shape[1], **self.model_config).to(self.device) + if id(self) == id(Phi2): + K_zz = kernel_fn(Z1) + else: + K_zz = kernel_fn(Z1, Z2) v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 @@ -299,42 +289,22 @@ class RKMEImageStatSpecification(BaseStatSpecification): True if the inner product of self with itself can be omitted, by default False. """ - with torch.no_grad(): - if omit_term1: - term1 = 0 - else: - term1 = self.inner_prod(self) - term2 = self.inner_prod(Phi2) - term3 = Phi2.inner_prod(Phi2) + if omit_term1: + term1 = 0 + else: + term1 = self.inner_prod(self) + term2 = self.inner_prod(Phi2) + term3 = Phi2.inner_prod(Phi2) + + v = float(term1 - 2 * term2 + term3) - return float(term1 - 2 * term2 + term3) + return v @staticmethod def _calc_ntk_from_feature(x1_feature: torch.Tensor, x2_feature: torch.Tensor): K_12 = x1_feature @ x2_feature.T + 0.01 return K_12 - # def _calc_ntk_empirical(self, x1: torch.Tensor, x2: torch.Tensor): - # if x1.shape[1] != x2.shape[1]: - # raise ValueError("The channel of two rkme image specification should be equal (e.g. 3 or 1).") - # - # results = [] - # for m, model in enumerate(self._generate_models(n_models=self.n_models, channel=x1.shape[1])): - # # Compute J(x1) - # # jac1 = vamp(lambda x: jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x))(x1) - # jac1 = jacrev(lambda p, i: functional_call(model, p, i), argnums=0)(dict(model.named_parameters()), x1) - # jac1 = [j.flatten(2) for j in jac1] - # - # # Compute J(x2) - # jac2 = functional_call(model, model.parameters(), x2) - # jac2 = [j.flatten(2) for j in jac2] - # - # result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)]) - # results.append(result.sum(0)) - # - # results = torch.stack(results) - # return results.mean(0) - def herding(self, T: int) -> np.ndarray: raise NotImplementedError( "The function herding hasn't been supported in Image RKME Specification.") @@ -392,8 +362,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): obj_text = fin.read() rkme_load = json.loads(obj_text) rkme_load["device"] = choose_device(rkme_load["cuda_idx"]) - rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"])).float() - rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"])) + rkme_load["z"] = torch.from_numpy(np.array(rkme_load["z"], dtype="float32")) + rkme_load["beta"] = torch.from_numpy(np.array(rkme_load["beta"], dtype="float64")) for d in self.__dir__(): if d in rkme_load.keys(): @@ -420,26 +390,21 @@ def _get_zca_matrix(X, reg_coef=0.1): class _ConvNet_wide(nn.Module): - def __init__(self, channel, mu=None, sigma=None, k=4, net_width=128, net_depth=3, - net_act='relu', net_norm='none', net_pooling='avgpooling', im_size=(32, 32), chopped_head=False): + def __init__(self, channel, mu=None, sigma=None, k=2, net_width=128, + net_depth=3, im_size=(32, 32)): self.k = k - # print('Building Conv Model') super().__init__() - - # net_depth = 1 - self.features, shape_feat = self._make_layers(channel, net_width, net_depth, net_norm, - net_act, net_pooling, im_size, mu, sigma) - # print(shape_feat) - self.chopped_head = chopped_head + self.features, shape_feat = self._make_layers(channel, net_width, net_depth, + im_size, mu, sigma) + # self.aggregation = nn.AvgPool2d(kernel_size=shape_feat[1]) def forward(self, x): out = self.features(x) - # print(out.size()) out = out.reshape(out.size(0), -1) - # print(out.size()) + # out = self.aggregation(out).reshape(out.size(0), -1) return out - def _make_layers(self, channel, net_width, net_depth, net_norm, net_act, net_pooling, im_size, mu, sigma): + def _make_layers(self, channel, net_width, net_depth, im_size, mu, sigma): k = self.k layers = [] @@ -468,4 +433,20 @@ def _build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean= # print('Initializing Conv. Mean=%.2f, std=%.2f'%(mean, std)) torch.nn.init.normal_(layer.weight, mean, std) torch.nn.init.normal_(layer.bias, 0, .1) - return layer \ No newline at end of file + return layer + +def _build_ConvNet_NNGP(channel, k=2, net_width=128, + net_depth=3, kernel_size=3, im_size=(32, 32), **kwargs): + layers = [] + for d in range(net_depth): + layers += [cnn_gp.Conv2d(kernel_size=kernel_size, padding="same", var_bias=0.1, + var_weight=np.sqrt(2))] + # /np.sqrt(kernel_size * kernel_size * channel) + layers += [cnn_gp.ReLU()] + # AvgPooling + layers += [cnn_gp.Conv2d(kernel_size=2, padding=0, stride=2)] + + assert im_size[0] % (2 ** net_depth) == 0 + layers.append(cnn_gp.Conv2d(kernel_size=im_size[0] // (2 ** net_depth), padding=0)) + + return cnn_gp.Sequential(*layers) From 106ffc2a0a45016938b329b7faf5c15b9b0d8ec4 Mon Sep 17 00:00:00 2001 From: shihy Date: Sat, 28 Oct 2023 19:25:15 +0800 Subject: [PATCH 05/24] [ENH] Adaptation of the new RKME Image --- learnware/specification/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 556aefb..e14d7e5 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,4 @@ from .utils import generate_stat_spec from .base import Specification, BaseStatSpecification from .rkme import RKMEStatSpecification +from .image import RKMEImageStatSpecification From f49be25a3997f185f812acfb7b82eba4dca46be5 Mon Sep 17 00:00:00 2001 From: shihy Date: Sat, 28 Oct 2023 19:36:04 +0800 Subject: [PATCH 06/24] [ENH] Organize the project structure --- learnware/specification/image/__init__.py | 1 + learnware/specification/{ => image}/cnn_gp/__init__.py | 0 learnware/specification/{ => image}/cnn_gp/data.py | 0 learnware/specification/{ => image}/cnn_gp/kernel_patch.py | 0 .../specification/{ => image}/cnn_gp/kernel_save_tools.py | 0 learnware/specification/{ => image}/cnn_gp/kernels.py | 0 learnware/specification/{image.py => image/rkme.py} | 7 +++---- 7 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 learnware/specification/image/__init__.py rename learnware/specification/{ => image}/cnn_gp/__init__.py (100%) rename learnware/specification/{ => image}/cnn_gp/data.py (100%) rename learnware/specification/{ => image}/cnn_gp/kernel_patch.py (100%) rename learnware/specification/{ => image}/cnn_gp/kernel_save_tools.py (100%) rename learnware/specification/{ => image}/cnn_gp/kernels.py (100%) rename learnware/specification/{image.py => image/rkme.py} (98%) diff --git a/learnware/specification/image/__init__.py b/learnware/specification/image/__init__.py new file mode 100644 index 0000000..b8bd2d2 --- /dev/null +++ b/learnware/specification/image/__init__.py @@ -0,0 +1 @@ +from .rkme import RKMEImageStatSpecification \ No newline at end of file diff --git a/learnware/specification/cnn_gp/__init__.py b/learnware/specification/image/cnn_gp/__init__.py similarity index 100% rename from learnware/specification/cnn_gp/__init__.py rename to learnware/specification/image/cnn_gp/__init__.py diff --git a/learnware/specification/cnn_gp/data.py b/learnware/specification/image/cnn_gp/data.py similarity index 100% rename from learnware/specification/cnn_gp/data.py rename to learnware/specification/image/cnn_gp/data.py diff --git a/learnware/specification/cnn_gp/kernel_patch.py b/learnware/specification/image/cnn_gp/kernel_patch.py similarity index 100% rename from learnware/specification/cnn_gp/kernel_patch.py rename to learnware/specification/image/cnn_gp/kernel_patch.py diff --git a/learnware/specification/cnn_gp/kernel_save_tools.py b/learnware/specification/image/cnn_gp/kernel_save_tools.py similarity index 100% rename from learnware/specification/cnn_gp/kernel_save_tools.py rename to learnware/specification/image/cnn_gp/kernel_save_tools.py diff --git a/learnware/specification/cnn_gp/kernels.py b/learnware/specification/image/cnn_gp/kernels.py similarity index 100% rename from learnware/specification/cnn_gp/kernels.py rename to learnware/specification/image/cnn_gp/kernels.py diff --git a/learnware/specification/image.py b/learnware/specification/image/rkme.py similarity index 98% rename from learnware/specification/image.py rename to learnware/specification/image/rkme.py index 4dd4647..103821a 100644 --- a/learnware/specification/image.py +++ b/learnware/specification/image/rkme.py @@ -6,19 +6,18 @@ import functools import json import os -from typing import Any, Union +from typing import Any import numpy as np import torch import torch_optimizer from torch import nn -from torch.func import jacrev, functional_call from torch.utils.data import TensorDataset, DataLoader from torchvision.transforms import Resize from . import cnn_gp -from .base import BaseStatSpecification -from .rkme import solve_qp, choose_device, setup_seed +from ..base import BaseStatSpecification +from ..rkme import solve_qp, choose_device, setup_seed class RKMEImageStatSpecification(BaseStatSpecification): From e26aec5e008ffac2d7b4dc961d04d5b96fb10892 Mon Sep 17 00:00:00 2001 From: shihy Date: Sun, 29 Oct 2023 11:20:15 +0800 Subject: [PATCH 07/24] [MNT] Remove redundant code --- learnware/specification/image/rkme.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/learnware/specification/image/rkme.py b/learnware/specification/image/rkme.py index 103821a..7cf2086 100644 --- a/learnware/specification/image/rkme.py +++ b/learnware/specification/image/rkme.py @@ -21,7 +21,7 @@ from ..rkme import solve_qp, choose_device, setup_seed class RKMEImageStatSpecification(BaseStatSpecification): - INNER_PRODUCT_COUNT = 0 + # INNER_PRODUCT_COUNT = 0 IMAGE_WIDTH = 32 def __init__(self, cuda_idx: int = -1, **kwargs): @@ -164,8 +164,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): x_features = x_features.to(self.device) z_features = self._generate_random_feature(Z, random_models=random_models) - K = self._calc_ntk_from_feature(z_features, z_features).to(self.device) - C = self._calc_ntk_from_feature(z_features, x_features).to(self.device) + K = self._calc_nngp_from_feature(z_features, z_features).to(self.device) + C = self._calc_nngp_from_feature(z_features, x_features).to(self.device) C = torch.sum(C, dim=1) / x_features.shape[0] if nonnegative_beta: @@ -196,8 +196,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): for i in range(3): z_features = self._generate_random_feature(Z, random_models=random_models) - K_z = self._calc_ntk_from_feature(z_features, z_features) - K_zx = self._calc_ntk_from_feature(x_features, z_features) + K_z = self._calc_nngp_from_feature(z_features, z_features) + K_zx = self._calc_nngp_from_feature(x_features, z_features) term_1 = torch.sum(K_z * (beta.T @ beta)) term_2 = torch.sum(K_zx * beta / x_features.shape[0]) loss = term_1 - 2 * term_2 @@ -257,10 +257,10 @@ class RKMEImageStatSpecification(BaseStatSpecification): float The inner product between two RKME Image specifications. """ - v = self._inner_prod_ntk(Phi2) + v = self._inner_prod_nngp(Phi2) return v - def _inner_prod_ntk(self, Phi2: RKMEImageStatSpecification) -> float: + def _inner_prod_nngp(self, Phi2: RKMEImageStatSpecification) -> float: beta_1 = self.beta.reshape(1, -1).detach().to(self.device) beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) @@ -274,7 +274,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): K_zz = kernel_fn(Z1, Z2) v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() - RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 + # RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 return v def dist(self, Phi2: RKMEImageStatSpecification, omit_term1: bool = False) -> float: @@ -300,7 +300,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): return v @staticmethod - def _calc_ntk_from_feature(x1_feature: torch.Tensor, x2_feature: torch.Tensor): + def _calc_nngp_from_feature(x1_feature: torch.Tensor, x2_feature: torch.Tensor): K_12 = x1_feature @ x2_feature.T + 0.01 return K_12 From 0f0442ff5ce5b3b99d236dab0325213c0c4b2986 Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 15:06:14 +0800 Subject: [PATCH 08/24] [MNT] Merging and removing redundant code --- learnware/specification/image/cnn_gp.py | 303 ++++++++++++++++++ .../specification/image/cnn_gp/__init__.py | 6 - learnware/specification/image/cnn_gp/data.py | 196 ----------- .../image/cnn_gp/kernel_patch.py | 89 ----- .../image/cnn_gp/kernel_save_tools.py | 58 ---- .../specification/image/cnn_gp/kernels.py | 295 ----------------- 6 files changed, 303 insertions(+), 644 deletions(-) create mode 100644 learnware/specification/image/cnn_gp.py delete mode 100644 learnware/specification/image/cnn_gp/__init__.py delete mode 100644 learnware/specification/image/cnn_gp/data.py delete mode 100644 learnware/specification/image/cnn_gp/kernel_patch.py delete mode 100644 learnware/specification/image/cnn_gp/kernel_save_tools.py delete mode 100644 learnware/specification/image/cnn_gp/kernels.py diff --git a/learnware/specification/image/cnn_gp.py b/learnware/specification/image/cnn_gp.py new file mode 100644 index 0000000..2bea13d --- /dev/null +++ b/learnware/specification/image/cnn_gp.py @@ -0,0 +1,303 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + + +__all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "ConvKP", "NonlinKP") +""" +References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen, ‘Deep Convolutional Networks as shallow Gaussian Processes’, in International Conference on Learning Representations, 2019. +""" + + +class NNGPKernel(nn.Module): + """ + Transforms one kernel matrix into another. + [N1, N2, W, H] -> [N1, N2, W, H] + """ + + def forward(self, x, y=None, same=None, diag=False): + """ + Either takes one minibatch (x), or takes two minibatches (x and y), and + a boolean indicating whether they're the same. + """ + if y is None: + assert same is None + y = x + same = True + + assert not diag or len(x) == len(y), "diagonal kernels must operate with data of equal length" + + assert 4 == len(x.size()) + assert 4 == len(y.size()) + assert x.size(1) == y.size(1) + assert x.size(2) == y.size(2) + assert x.size(3) == y.size(3) + + N1 = x.size(0) + N2 = y.size(0) + C = x.size(1) + W = x.size(2) + H = x.size(3) + + # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H] + if diag: + xy = (x * y).mean(1, keepdim=True) + else: + xy = (x.unsqueeze(1) * y).mean(2).view(N1 * N2, 1, W, H) + xx = (x**2).mean(1, keepdim=True) + yy = (y**2).mean(1, keepdim=True) + + initial_kp = ConvKP(same, diag, xy, xx, yy) + final_kp = self.propagate(initial_kp) + r = NonlinKP(final_kp).xy + if diag: + return r.view(N1) + else: + return r.view(N1, N2) + + +class Conv2d(NNGPKernel): + def __init__( + self, + kernel_size, + stride=1, + padding="same", + dilation=1, + var_weight=1.0, + var_bias=0.0, + in_channel_multiplier=1, + out_channel_multiplier=1, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.var_weight = var_weight + self.var_bias = var_bias + self.kernel_has_row_of_zeros = False + if padding == "same": + self.padding = dilation * (kernel_size // 2) + if kernel_size % 2 == 0: + self.kernel_has_row_of_zeros = True + else: + self.padding = padding + + if self.kernel_has_row_of_zeros: + # We need to pad one side larger than the other. We just make a + # kernel that is slightly too large and make its last column and + # row zeros. + kernel = t.ones(1, 1, self.kernel_size + 1, self.kernel_size + 1) + kernel[:, :, 0, :] = 0.0 + kernel[:, :, :, 0] = 0.0 + else: + kernel = t.ones(1, 1, self.kernel_size, self.kernel_size) + self.register_buffer("kernel", kernel * (self.var_weight / self.kernel_size**2)) + self.in_channel_multiplier, self.out_channel_multiplier = (in_channel_multiplier, out_channel_multiplier) + + def propagate(self, kp): + kp = ConvKP(kp) + + def f(patch): + return ( + F.conv2d(patch, self.kernel, stride=self.stride, padding=self.padding, dilation=self.dilation) + + self.var_bias + ) + + return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy)) + + def nn(self, channels, in_channels=None, out_channels=None): + if in_channels is None: + in_channels = channels + if out_channels is None: + out_channels = channels + conv2d = nn.Conv2d( + in_channels=in_channels * self.in_channel_multiplier, + out_channels=out_channels * self.out_channel_multiplier, + kernel_size=self.kernel_size + (1 if self.kernel_has_row_of_zeros else 0), + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=(self.var_bias > 0.0), + ) + conv2d.weight.data.normal_(0, math.sqrt(self.var_weight / conv2d.in_channels) / self.kernel_size) + if self.kernel_has_row_of_zeros: + conv2d.weight.data[:, :, 0, :] = 0 + conv2d.weight.data[:, :, :, 0] = 0 + if self.var_bias > 0.0: + conv2d.bias.data.normal_(0, math.sqrt(self.var_bias)) + return conv2d + + def layers(self): + return 1 + + +class ReLU(NNGPKernel): + """ + A ReLU nonlinearity, the covariance is numerically stabilised by clamping + values. + """ + + f32_tiny = np.finfo(np.float32).tiny + + def propagate(self, kp): + kp = NonlinKP(kp) + """ + We need to calculate (xy, xx, yy == c, v₁, v₂): + ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤ + √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂) + + which is equivalent to: + 1/2π ( √(v₁v₂ - c²) + (π - θ)c ) + + # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2) + """ + xx_yy = kp.xx * kp.yy + self.f32_tiny + + # Clamp these so the outputs are not NaN + cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1) + sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0)) + theta = t.acos(cos_theta) + xy = (sin_theta + (math.pi - theta) * kp.xy) / (2 * math.pi) + + xx = kp.xx / 2.0 + if kp.same: + yy = xx + if kp.diag: + xy = xx + else: + # Make sure the diagonal agrees with `xx` + eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device) + xy = (1 - eye) * xy + eye * xx + else: + yy = kp.yy / 2.0 + return NonlinKP(kp.same, kp.diag, xy, xx, yy) + + def nn(self, channels, in_channels=None, out_channels=None): + assert in_channels is None + assert out_channels is None + return nn.ReLU() + + def layers(self): + return 0 + + +#### Combination classes + + +class Sequential(NNGPKernel): + def __init__(self, *mods): + super().__init__() + self.mods = mods + for idx, mod in enumerate(mods): + self.add_module(str(idx), mod) + + def propagate(self, kp): + for mod in self.mods: + kp = mod.propagate(kp) + return kp + + def nn(self, channels, in_channels=None, out_channels=None): + if len(self.mods) == 0: + return nn.Sequential() + elif len(self.mods) == 1: + return self.mods[0].nn(channels, in_channels=in_channels, out_channels=out_channels) + else: + return nn.Sequential( + self.mods[0].nn(channels, in_channels=in_channels), + *[mod.nn(channels) for mod in self.mods[1:-1]], + self.mods[-1].nn(channels, out_channels=out_channels) + ) + + def layers(self): + return sum(mod.layers() for mod in self.mods) + + +class KernelPatch: + """ + Represents a block of the kernel matrix. + Critically, we need the variances of the rows and columns, even if the + diagonal isn't part of the block, and this introduces considerable + complexity. + In particular, we also need to know whether the + rows and columns of the matrix correspond, in which case, we need to do + something different when we add IID noise. + """ + + def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None): + if isinstance(same_or_kp, KernelPatch): + same = same_or_kp.same + diag = same_or_kp.diag + xy = same_or_kp.xy + xx = same_or_kp.xx + yy = same_or_kp.yy + else: + same = same_or_kp + + self.Nx = xx.size(0) + self.Ny = yy.size(0) + self.W = xy.size(-2) + self.H = xy.size(-1) + + self.init(same, diag, xy, xx, yy) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __add__(self, other): + return self._do_elementwise(other, "__add__") + + def __mul__(self, other): + return self._do_elementwise(other, "__mul__") + + def _do_elementwise(self, other, op): + KP = type(self) + if isinstance(other, KernelPatch): + other = KP(other) + assert self.same == other.same + assert self.diag == other.diag + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other.xy), + getattr(self.xx, op)(other.xx), + getattr(self.yy, op)(other.yy), + ) + else: + return KP( + self.same, + self.diag, + getattr(self.xy, op)(other), + getattr(self.xx, op)(other), + getattr(self.yy, op)(other), + ) + + +class ConvKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx * self.Ny, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + + +class NonlinKP(KernelPatch): + def init(self, same, diag, xy, xx, yy): + self.same = same + self.diag = diag + if diag: + self.xy = xy.view(self.Nx, 1, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, 1, self.W, self.H) + else: + self.xy = xy.view(self.Nx, self.Ny, self.W, self.H) + self.xx = xx.view(self.Nx, 1, self.W, self.H) + self.yy = yy.view(self.Ny, self.W, self.H) diff --git a/learnware/specification/image/cnn_gp/__init__.py b/learnware/specification/image/cnn_gp/__init__.py deleted file mode 100644 index 044ef69..0000000 --- a/learnware/specification/image/cnn_gp/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import kernels, data, kernel_save_tools -from .kernels import * -from .data import * -from .kernel_save_tools import * - -__all__ = kernels.__all__ + data.__all__ + kernel_save_tools.__all__ diff --git a/learnware/specification/image/cnn_gp/data.py b/learnware/specification/image/cnn_gp/data.py deleted file mode 100644 index 1f8b446..0000000 --- a/learnware/specification/image/cnn_gp/data.py +++ /dev/null @@ -1,196 +0,0 @@ -import torchvision -from torch.utils.data import ConcatDataset, DataLoader, Subset -import os -import numpy as np -import itertools - -__all__ = ('DatasetFromConfig', 'ProductIterator', 'DiagIterator', - 'print_timings') - - -def _this_worker_batch(N_batches, worker_rank, n_workers): - batches_per_worker = np.zeros([n_workers], dtype=np.int) - batches_per_worker[:] = N_batches // n_workers - batches_per_worker[:N_batches % n_workers] += 1 - - start_batch = np.sum(batches_per_worker[:worker_rank]) - batches_this_worker = batches_per_worker[worker_rank] - - return int(start_batch), int(batches_this_worker) - - -def _product_generator(N_batches_X, N_batches_X2, same): - for i in range(N_batches_X): - if same: - # Yield only upper triangle - yield (True, i, i) - for j in range(i+1 if same else 0, - N_batches_X2): - yield (False, i, j) - - -def _round_up_div(a, b): - return (a+b-1)//b - - -class ProductIterator(object): - """ - Returns an iterator for loading data from both X and X2. It divides the - load equally among `n_workers`, returning only the one that belongs to - `worker_rank`. - """ - def __init__(self, batch_size, X, X2=None, worker_rank=0, n_workers=1): - N_batches_X = _round_up_div(len(X), batch_size) - if X2 is None: - same = True - X2 = X - N_batches_X2 = N_batches_X - N_batches = max(1, N_batches_X * (N_batches_X+1) // 2) - else: - same = False - N_batches_X2 = _round_up_div(len(X2), batch_size) - N_batches = N_batches_X * N_batches_X2 - - start_batch, self.batches_this_worker = _this_worker_batch( - N_batches, worker_rank, n_workers) - - self.idx_iter = itertools.islice( - _product_generator(N_batches_X, N_batches_X2, same), - start_batch, - start_batch + self.batches_this_worker) - - self.worker_rank = worker_rank - self.prev_j = -2 # this + 1 = -1, which is not a valid j - self.X_loader = None - self.X2_loader = None - self.x_batch = None - self.X = X - self.X2 = X2 - self.same = same - self.batch_size = batch_size - - def __len__(self): - return self.batches_this_worker - - def __iter__(self): - return self - - def dataloader_beginning_at(self, i, dataset): - return iter(DataLoader( - Subset(dataset, range(i*self.batch_size, len(dataset))), - batch_size=self.batch_size)) - - def __next__(self): - same, i, j = next(self.idx_iter) - - if self.X_loader is None: - self.X_loader = self.dataloader_beginning_at(i, self.X) - - if j != self.prev_j+1: - self.X2_loader = self.dataloader_beginning_at(j, self.X2) - self.x_batch = next(self.X_loader) - self.prev_j = j - - return (same, - (i*self.batch_size, self.x_batch), - (j*self.batch_size, next(self.X2_loader))) - - -class DiagIterator(object): - def __init__(self, batch_size, X, X2=None): - self.batch_size = batch_size - dl = DataLoader(X, batch_size=batch_size) - if X2 is None: - self.same = True - self.it = iter(enumerate(dl)) - self.length = len(dl) - else: - dl2 = DataLoader(X2, batch_size=batch_size) - self.same = False - self.it = iter(enumerate(zip(dl, dl2))) - self.length = min(len(dl), len(dl2)) - - def __iter__(self): - return self - - def __len__(self): - return self.length - - def __next__(self): - if self.same: - i, xy = next(self.it) - xy2 = xy - else: - i, xy, xy2 = next(self.it) - ib = i*self.batch_size - return (self.same, (ib, xy), (ib, xy2)) - - -class DatasetFromConfig(object): - """ - A dataset that contains train, validation and test, and is created from a - config file. - """ - def __init__(self, datasets_path, config): - """ - Requires: - config.dataset_name (e.g. "MNIST") - config.train_range - config.test_range - """ - self.config = config - - trans = torchvision.transforms.ToTensor() - if len(config.transforms) > 0: - trans = torchvision.transforms.Compose([trans] + config.transforms) - - # Full datasets - datasets_path = os.path.join(datasets_path, config.dataset_name) - train_full = config.dataset(datasets_path, train=True, download=True, - transform=trans) - test_full = config.dataset(datasets_path, train=False, transform=trans) - self.data_full = ConcatDataset([train_full, test_full]) - - # Our training/test split - # (could omit some data, or include validation in test) - self.train = Subset(self.data_full, config.train_range) - self.validation = Subset(self.data_full, config.validation_range) - self.test = Subset(self.data_full, config.test_range) - - @staticmethod - def load_full(dataset): - return next(iter(DataLoader(dataset, batch_size=len(dataset)))) - - -def _hhmmss(s): - m, s = divmod(int(s), 60) - h, m = divmod(m, 60) - if h == 0.0: - return f"{m:02d}:{s:02d}" - else: - return f"{h:02d}:{m:02d}:{s:02d}" - - -def print_timings(iterator, desc="time", print_interval=2.): - """ - Prints the current total number of iterations, speed of iteration, and - elapsed time. - - Meant as a rudimentary replacement for `tqdm` that prints a new line at - each iteration, and thus can be used in multiple parallel processes in the - same terminal. - """ - import time - start_time = time.perf_counter() - total = len(iterator) - last_printed = -print_interval - for i, value in enumerate(iterator): - yield value - cur_time = time.perf_counter() - elapsed = cur_time - start_time - it_s = (i+1)/elapsed - total_s = total/it_s - if elapsed > last_printed + print_interval: - print(f"{desc}: {i+1}/{total} it, {it_s:.02f} it/s," - f"[{_hhmmss(elapsed)}<{_hhmmss(total_s)}]") - last_printed = elapsed diff --git a/learnware/specification/image/cnn_gp/kernel_patch.py b/learnware/specification/image/cnn_gp/kernel_patch.py deleted file mode 100644 index 53fbce3..0000000 --- a/learnware/specification/image/cnn_gp/kernel_patch.py +++ /dev/null @@ -1,89 +0,0 @@ -__all__ = ('ConvKP', 'NonlinKP') - - -class KernelPatch: - """ - Represents a block of the kernel matrix. - Critically, we need the variances of the rows and columns, even if the - diagonal isn't part of the block, and this introduces considerable - complexity. - In particular, we also need to know whether the - rows and columns of the matrix correspond, in which case, we need to do - something different when we add IID noise. - """ - def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None): - if isinstance(same_or_kp, KernelPatch): - same = same_or_kp.same - diag = same_or_kp.diag - xy = same_or_kp.xy - xx = same_or_kp.xx - yy = same_or_kp.yy - else: - same = same_or_kp - - self.Nx = xx.size(0) - self.Ny = yy.size(0) - self.W = xy.size(-2) - self.H = xy.size(-1) - - self.init(same, diag, xy, xx, yy) - - def __radd__(self, other): - return self.__add__(other) - - def __rmul__(self, other): - return self.__mul__(other) - - def __add__(self, other): - return self._do_elementwise(other, '__add__') - - def __mul__(self, other): - return self._do_elementwise(other, '__mul__') - - def _do_elementwise(self, other, op): - KP = type(self) - if isinstance(other, KernelPatch): - other = KP(other) - assert self.same == other.same - assert self.diag == other.diag - return KP( - self.same, - self.diag, - getattr(self.xy, op)(other.xy), - getattr(self.xx, op)(other.xx), - getattr(self.yy, op)(other.yy) - ) - else: - return KP( - self.same, - self.diag, - getattr(self.xy, op)(other), - getattr(self.xx, op)(other), - getattr(self.yy, op)(other) - ) - - -class ConvKP(KernelPatch): - def init(self, same, diag, xy, xx, yy): - self.same = same - self.diag = diag - if diag: - self.xy = xy.view(self.Nx, 1, self.W, self.H) - else: - self.xy = xy.view(self.Nx*self.Ny, 1, self.W, self.H) - self.xx = xx.view(self.Nx, 1, self.W, self.H) - self.yy = yy.view(self.Ny, 1, self.W, self.H) - - -class NonlinKP(KernelPatch): - def init(self, same, diag, xy, xx, yy): - self.same = same - self.diag = diag - if diag: - self.xy = xy.view(self.Nx, 1, self.W, self.H) - self.xx = xx.view(self.Nx, 1, self.W, self.H) - self.yy = yy.view(self.Ny, 1, self.W, self.H) - else: - self.xy = xy.view(self.Nx, self.Ny, self.W, self.H) - self.xx = xx.view(self.Nx, 1, self.W, self.H) - self.yy = yy.view( self.Ny, self.W, self.H) diff --git a/learnware/specification/image/cnn_gp/kernel_save_tools.py b/learnware/specification/image/cnn_gp/kernel_save_tools.py deleted file mode 100644 index b0952af..0000000 --- a/learnware/specification/image/cnn_gp/kernel_save_tools.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -from .data import ProductIterator, DiagIterator, print_timings - -__all__ = ('create_h5py_dataset', 'save_K') - - -def create_h5py_dataset(f, batch_size, name, diag, N, N2): - """ - Creates a dataset named `name` on `f`, with chunks of size `batch_size`. - The chunks have leading dimension 1, so as to accommodate future resizing - of the leading dimension of the dataset (which starts at 1). - """ - if diag: - chunk_shape = (1, min(batch_size, N)) - shape = (1, N) - maxshape = (None, N) - else: - chunk_shape = (1, min(batch_size, N), min(batch_size, N2)) - shape = (1, N, N2) - maxshape = (None, N, N2) - return f.create_dataset(name, shape=shape, dtype=np.float32, - fillvalue=np.nan, chunks=chunk_shape, - maxshape=maxshape) - - -def save_K(f, kern, name, X, X2, diag, batch_size, worker_rank=0, n_workers=1, - print_interval=2.): - """ - Saves a kernel to the h5py file `f`. Creates its dataset with name `name` - if necessary. - """ - if name in f.keys(): - print("Skipping {} (group exists)".format(name)) - return - else: - N = len(X) - N2 = N if X2 is None else len(X2) - out = create_h5py_dataset(f, batch_size, name, diag, N, N2) - - if diag: - # Don't split the load for diagonals, they are cheap - it = DiagIterator(batch_size, X, X2) - else: - it = ProductIterator(batch_size, X, X2, worker_rank=worker_rank, - n_workers=n_workers) - it = print_timings(it, desc=f"{name} (worker {worker_rank}/{n_workers})", - print_interval=print_interval) - - for same, (i, (x, _y)), (j, (x2, _y2)) in it: - k = kern(x, x2, same, diag) - if np.any(np.isinf(k)) or np.any(np.isnan(k)): - print(f"About to write a nan or inf for {i},{j}") - import ipdb; ipdb.set_trace() - - if diag: - out[0, i:i+len(x)] = k - else: - out[0, i:i+len(x), j:j+len(x2)] = k diff --git a/learnware/specification/image/cnn_gp/kernels.py b/learnware/specification/image/cnn_gp/kernels.py deleted file mode 100644 index 0e7af1c..0000000 --- a/learnware/specification/image/cnn_gp/kernels.py +++ /dev/null @@ -1,295 +0,0 @@ -import torch as t -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from .kernel_patch import ConvKP, NonlinKP -import math - - -__all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "Mixture", - "MixtureModule", "Sum", "SumModule", "resnet_block") - -class NNGPKernel(nn.Module): - """ - Transforms one kernel matrix into another. - [N1, N2, W, H] -> [N1, N2, W, H] - """ - def forward(self, x, y=None, same=None, diag=False): - """ - Either takes one minibatch (x), or takes two minibatches (x and y), and - a boolean indicating whether they're the same. - """ - if y is None: - assert same is None - y = x - same = True - - assert not diag or len(x) == len(y), ( - "diagonal kernels must operate with data of equal length") - - assert 4==len(x.size()) - assert 4==len(y.size()) - assert x.size(1) == y.size(1) - assert x.size(2) == y.size(2) - assert x.size(3) == y.size(3) - - N1 = x.size(0) - N2 = y.size(0) - C = x.size(1) - W = x.size(2) - H = x.size(3) - - # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H] - if diag: - xy = (x*y).mean(1, keepdim=True) - else: - xy = (x.unsqueeze(1)*y).mean(2).view(N1*N2, 1, W, H) - xx = (x**2).mean(1, keepdim=True) - yy = (y**2).mean(1, keepdim=True) - - initial_kp = ConvKP(same, diag, xy, xx, yy) - final_kp = self.propagate(initial_kp) - r = NonlinKP(final_kp).xy - if diag: - return r.view(N1) - else: - return r.view(N1, N2) - - -class Conv2d(NNGPKernel): - def __init__(self, kernel_size, stride=1, padding="same", dilation=1, - var_weight=1., var_bias=0., in_channel_multiplier=1, - out_channel_multiplier=1): - super().__init__() - self.kernel_size = kernel_size - self.stride = stride - self.dilation = dilation - self.var_weight = var_weight - self.var_bias = var_bias - self.kernel_has_row_of_zeros = False - if padding == "same": - self.padding = dilation*(kernel_size//2) - if kernel_size % 2 == 0: - self.kernel_has_row_of_zeros = True - else: - self.padding = padding - - if self.kernel_has_row_of_zeros: - # We need to pad one side larger than the other. We just make a - # kernel that is slightly too large and make its last column and - # row zeros. - kernel = t.ones(1, 1, self.kernel_size+1, self.kernel_size+1) - kernel[:, :, 0, :] = 0. - kernel[:, :, :, 0] = 0. - else: - kernel = t.ones(1, 1, self.kernel_size, self.kernel_size) - self.register_buffer('kernel', kernel - * (self.var_weight / self.kernel_size**2)) - self.in_channel_multiplier, self.out_channel_multiplier = ( - in_channel_multiplier, out_channel_multiplier) - - def propagate(self, kp): - kp = ConvKP(kp) - def f(patch): - return (F.conv2d(patch, self.kernel, stride=self.stride, - padding=self.padding, dilation=self.dilation) - + self.var_bias) - return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy)) - - def nn(self, channels, in_channels=None, out_channels=None): - if in_channels is None: - in_channels = channels - if out_channels is None: - out_channels = channels - conv2d = nn.Conv2d( - in_channels=in_channels * self.in_channel_multiplier, - out_channels=out_channels * self.out_channel_multiplier, - kernel_size=self.kernel_size + ( - 1 if self.kernel_has_row_of_zeros else 0), - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - bias=(self.var_bias > 0.), - ) - conv2d.weight.data.normal_(0, math.sqrt( - self.var_weight / conv2d.in_channels) / self.kernel_size) - if self.kernel_has_row_of_zeros: - conv2d.weight.data[:, :, 0, :] = 0 - conv2d.weight.data[:, :, :, 0] = 0 - if self.var_bias > 0.: - conv2d.bias.data.normal_(0, math.sqrt(self.var_bias)) - return conv2d - - def layers(self): - return 1 - - -class ReLU(NNGPKernel): - """ - A ReLU nonlinearity, the covariance is numerically stabilised by clamping - values. - """ - f32_tiny = np.finfo(np.float32).tiny - def propagate(self, kp): - kp = NonlinKP(kp) - """ - We need to calculate (xy, xx, yy == c, v₁, v₂): - ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤ - √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂) - - which is equivalent to: - 1/2π ( √(v₁v₂ - c²) + (π - θ)c ) - - # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2) - """ - xx_yy = kp.xx * kp.yy + self.f32_tiny - - # Clamp these so the outputs are not NaN - cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1) - sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0)) - theta = t.acos(cos_theta) - xy = (sin_theta + (math.pi - theta)*kp.xy) / (2*math.pi) - - xx = kp.xx/2. - if kp.same: - yy = xx - if kp.diag: - xy = xx - else: - # Make sure the diagonal agrees with `xx` - eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device) - xy = (1-eye)*xy + eye*xx - else: - yy = kp.yy/2. - return NonlinKP(kp.same, kp.diag, xy, xx, yy) - - def nn(self, channels, in_channels=None, out_channels=None): - assert in_channels is None - assert out_channels is None - return nn.ReLU() - - def layers(self): - return 0 - - -#### Combination classes - -class Sequential(NNGPKernel): - def __init__(self, *mods): - super().__init__() - self.mods = mods - for idx, mod in enumerate(mods): - self.add_module(str(idx), mod) - def propagate(self, kp): - for mod in self.mods: - kp = mod.propagate(kp) - return kp - def nn(self, channels, in_channels=None, out_channels=None): - if len(self.mods) == 0: - return nn.Sequential() - elif len(self.mods) == 1: - return self.mods[0].nn(channels, in_channels=in_channels, out_channels=out_channels) - else: - return nn.Sequential( - self.mods[0].nn(channels, in_channels=in_channels), - *[mod.nn(channels) for mod in self.mods[1:-1]], - self.mods[-1].nn(channels, out_channels=out_channels) - ) - def layers(self): - return sum(mod.layers() for mod in self.mods) - - -class Mixture(NNGPKernel): - """ - Applys multiple modules to the input, and sums the result - (e.g. for the implementation of a ResNet). - - Parameterised by proportion of each module (proportions add - up to one, such that, if each model has average variance 1, - then the output will also have average variance 1. - """ - def __init__(self, mods, logit_proportions=None): - super().__init__() - self.mods = mods - for idx, mod in enumerate(mods): - self.add_module(str(idx), mod) - if logit_proportions is None: - logit_proportions = t.zeros(len(mods)) - self.logit = nn.Parameter(logit_proportions) - def propagate(self, kp): - proportions = F.softmax(self.logit, dim=0) - total = self.mods[0].propagate(kp) * proportions[0] - for i in range(1, len(self.mods)): - total = total + (self.mods[i].propagate(kp) * proportions[i]) - return total - def nn(self, channels, in_channels=None, out_channels=None): - return MixtureModule([mod.nn(channels, in_channels=in_channels, out_channels=out_channels) for mod in self.mods], self.logit) - def layers(self): - return max(mod.layers() for mod in self.mods) - -class MixtureModule(nn.Module): - def __init__(self, mods, logit_parameter): - super().__init__() - self.mods = mods - self.logit = t.tensor(logit_parameter) - for idx, mod in enumerate(mods): - self.add_module(str(idx), mod) - def forward(self, input): - sqrt_proportions = F.softmax(self.logit, dim=0).sqrt() - total = self.mods[0](input)*sqrt_proportions[0] - for i in range(1, len(self.mods)): - total = total + self.mods[i](input) # *sqrt_proportions[i] - return total - - -class Sum(NNGPKernel): - def __init__(self, mods): - super().__init__() - self.mods = mods - for idx, mod in enumerate(mods): - self.add_module(str(idx), mod) - def propagate(self, kp): - # This adds 0 to the first kp, hopefully that's a noop - return sum(m.propagate(kp) for m in self.mods) - def nn(self, channels, in_channels=None, out_channels=None): - return SumModule([ - mod.nn(channels, in_channels=in_channels, out_channels=out_channels) - for mod in self.mods]) - def layers(self): - return max(mod.layers() for mod in self.mods) - - -class SumModule(nn.Module): - def __init__(self, mods): - super().__init__() - self.mods = mods - for idx, mod in enumerate(mods): - self.add_module(str(idx), mod) - def forward(self, input): - # This adds 0 to the first value, hopefully that's a noop - return sum(m(input) for m in self.mods) - - -def resnet_block(stride=1, projection_shortcut=False, multiplier=1): - if stride == 1 and not projection_shortcut: - return Sum([ - Sequential(), - Sequential( - ReLU(), - Conv2d(3, stride=stride, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), - ReLU(), - Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), - ) - ]) - else: - return Sequential( - ReLU(), - Sum([ - Conv2d(1, stride=stride, in_channel_multiplier=multiplier//stride, out_channel_multiplier=multiplier), - Sequential( - Conv2d(3, stride=stride, in_channel_multiplier=multiplier//stride, out_channel_multiplier=multiplier), - ReLU(), - Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier), - ) - ]), - ) From fb9b57a317c6a2f80e9ac35a12f575672a4d27db Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 15:06:33 +0800 Subject: [PATCH 09/24] [MNT] format code by black --- examples/dataset_image_workflow/main.py | 2 +- .../pfs/pfs_cross_transfer.py | 4 +- learnware/specification/image/__init__.py | 2 +- learnware/specification/image/rkme.py | 75 +++++++++---------- 4 files changed, 40 insertions(+), 43 deletions(-) diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 9495498..3f14e05 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -49,7 +49,7 @@ semantic_specs = [ "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, "Name": {"Values": "learnware_1", "Type": "String"}, - "Output": {"Dimension": 10} + "Output": {"Dimension": 10}, } ] diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 5f69127..93a3fa3 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,9 +85,7 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[ - split:, - ], + train_xs[split:,], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/specification/image/__init__.py b/learnware/specification/image/__init__.py index b8bd2d2..fa46969 100644 --- a/learnware/specification/image/__init__.py +++ b/learnware/specification/image/__init__.py @@ -1 +1 @@ -from .rkme import RKMEImageStatSpecification \ No newline at end of file +from .rkme import RKMEImageStatSpecification diff --git a/learnware/specification/image/rkme.py b/learnware/specification/image/rkme.py index 7cf2086..429a8eb 100644 --- a/learnware/specification/image/rkme.py +++ b/learnware/specification/image/rkme.py @@ -32,7 +32,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): cuda_idx : int A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. """ - self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. + self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. # torch.cuda.empty_cache() self.z = None @@ -42,13 +42,15 @@ class RKMEImageStatSpecification(BaseStatSpecification): self.cache = False self.n_models = kwargs["n_models"] if "n_models" in kwargs else 16 - self.model_config = { - "k": 2, "mu": 0, "sigma": None, "net_width": 128, "net_depth": 3 - } if "model_config" not in kwargs else kwargs["model_config"] + self.model_config = ( + {"k": 2, "mu": 0, "sigma": None, "net_width": 128, "net_depth": 3} + if "model_config" not in kwargs + else kwargs["model_config"] + ) setup_seed(0) - def _generate_models(self, n_models: int, channel: int=3, fixed_seed=None): + def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) def __builder(i): @@ -63,11 +65,11 @@ class RKMEImageStatSpecification(BaseStatSpecification): X: np.ndarray, K: int = 50, step_size: float = 0.01, - steps: int=100, + steps: int = 100, resize: bool = False, nonnegative_beta: bool = True, reduce: bool = True, - **kwargs + **kwargs, ): """Construct reduced set from raw dataset using iterative optimization. @@ -92,11 +94,15 @@ class RKMEImageStatSpecification(BaseStatSpecification): ------- """ - if (X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or - X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH) and not resize: - raise ValueError("X should be in shape of [N, C, {0:d}, {0:d}]. " - "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}." - .format(RKMEImageStatSpecification.IMAGE_WIDTH)) + if ( + X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH + ) and not resize: + raise ValueError( + "X should be in shape of [N, C, {0:d}, {0:d}]. " + "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format( + RKMEImageStatSpecification.IMAGE_WIDTH + ) + ) if not torch.is_tensor(X): X = torch.from_numpy(X) @@ -112,10 +118,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): img_mean = torch.nanmean(img) X[i] = torch.where(is_nan, img_mean, img) - if (X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or - X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH): - X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, - RKMEImageStatSpecification.IMAGE_WIDTH))(X) + if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: + X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH))(X) num_points = X.shape[0] X_shape = X.shape @@ -140,8 +144,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): x_features = self._generate_random_feature(X_train, random_models=random_models) self._update_beta(x_features, nonnegative_beta, random_models=random_models) - optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], - lr=step_size, eps=1e-16) + optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) for i in range(steps): # Regenerate Random Models @@ -216,8 +219,9 @@ class RKMEImageStatSpecification(BaseStatSpecification): dataloader_Y = DataLoader(dataset_Y, batch_size=batch_size, shuffle=True) assert data_X.shape[1] == data_Y.shape[1] - for m, model in enumerate(random_models if random_models else - self._generate_models(n_models=self.n_models, channel=data_X.shape[1])): + for m, model in enumerate( + random_models if random_models else self._generate_models(n_models=self.n_models, channel=data_X.shape[1]) + ): model.eval() curr_features_list = [] @@ -305,8 +309,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): return K_12 def herding(self, T: int) -> np.ndarray: - raise NotImplementedError( - "The function herding hasn't been supported in Image RKME Specification.") + raise NotImplementedError("The function herding hasn't been supported in Image RKME Specification.") def _sampling_candidates(self, N: int) -> np.ndarray: raise NotImplementedError() @@ -382,19 +385,16 @@ def _get_zca_matrix(X, reg_coef=0.1): reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda()) inv_sqrt_zca_eigs = s ** (-0.5) - whitening_transform = torch.einsum( - 'ij,j,kj->ik', u, inv_sqrt_zca_eigs, u) + whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u) return whitening_transform class _ConvNet_wide(nn.Module): - def __init__(self, channel, mu=None, sigma=None, k=2, net_width=128, - net_depth=3, im_size=(32, 32)): + def __init__(self, channel, mu=None, sigma=None, k=2, net_width=128, net_depth=3, im_size=(32, 32)): self.k = k super().__init__() - self.features, shape_feat = self._make_layers(channel, net_width, net_depth, - im_size, mu, sigma) + self.features, shape_feat = self._make_layers(channel, net_width, net_depth, im_size, mu, sigma) # self.aggregation = nn.AvgPool2d(kernel_size=shape_feat[1]) def forward(self, x): @@ -410,8 +410,7 @@ class _ConvNet_wide(nn.Module): in_channels = channel shape_feat = [in_channels, im_size[0], im_size[1]] for d in range(net_depth): - layers += [_build_conv2d_gaussian(in_channels, int(k * net_width), 3, - 1, mean=mu, std=sigma)] + layers += [_build_conv2d_gaussian(in_channels, int(k * net_width), 3, 1, mean=mu, std=sigma)] shape_feat[0] = int(k * net_width) layers += [nn.ReLU(inplace=True)] @@ -423,29 +422,29 @@ class _ConvNet_wide(nn.Module): return nn.Sequential(*layers), shape_feat + def _build_conv2d_gaussian(in_channels, out_channels, kernel=3, padding=1, mean=None, std=None): layer = nn.Conv2d(in_channels, out_channels, kernel, padding=padding) if mean is None: mean = 0 if std is None: - std = np.sqrt(2)/np.sqrt(layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]) + std = np.sqrt(2) / np.sqrt(layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]) # print('Initializing Conv. Mean=%.2f, std=%.2f'%(mean, std)) torch.nn.init.normal_(layer.weight, mean, std) - torch.nn.init.normal_(layer.bias, 0, .1) + torch.nn.init.normal_(layer.bias, 0, 0.1) return layer -def _build_ConvNet_NNGP(channel, k=2, net_width=128, - net_depth=3, kernel_size=3, im_size=(32, 32), **kwargs): + +def _build_ConvNet_NNGP(channel, k=2, net_width=128, net_depth=3, kernel_size=3, im_size=(32, 32), **kwargs): layers = [] for d in range(net_depth): - layers += [cnn_gp.Conv2d(kernel_size=kernel_size, padding="same", var_bias=0.1, - var_weight=np.sqrt(2))] + layers += [cnn_gp.Conv2d(kernel_size=kernel_size, padding="same", var_bias=0.1, var_weight=np.sqrt(2))] # /np.sqrt(kernel_size * kernel_size * channel) layers += [cnn_gp.ReLU()] # AvgPooling layers += [cnn_gp.Conv2d(kernel_size=2, padding=0, stride=2)] - assert im_size[0] % (2 ** net_depth) == 0 - layers.append(cnn_gp.Conv2d(kernel_size=im_size[0] // (2 ** net_depth), padding=0)) + assert im_size[0] % (2**net_depth) == 0 + layers.append(cnn_gp.Conv2d(kernel_size=im_size[0] // (2**net_depth), padding=0)) return cnn_gp.Sequential(*layers) From e0ac546aeb0da03c54618636b7a588804c143921 Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 15:50:34 +0800 Subject: [PATCH 10/24] [MNT] Add attribute type and dependencies --- learnware/specification/image/rkme.py | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/learnware/specification/image/rkme.py b/learnware/specification/image/rkme.py index 429a8eb..7511f06 100644 --- a/learnware/specification/image/rkme.py +++ b/learnware/specification/image/rkme.py @@ -337,6 +337,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" + rkme_to_save["type"] = self.__class__.__name__ json.dump( rkme_to_save, diff --git a/setup.py b/setup.py index 52a4299..bf5834c 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ REQUIRED = [ "geatpy>=2.7.0", "docker>=6.1.3", "rapidfuzz>=3.4.0", + "torch-optimizer>=0.3.0" ] if get_platform() != MACOS: From 29b1df9fde684f1a602f5f38fca30a9ae0b35462 Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 15:51:07 +0800 Subject: [PATCH 11/24] [MNT] Add more details about cnn_gp --- learnware/specification/image/cnn_gp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/learnware/specification/image/cnn_gp.py b/learnware/specification/image/cnn_gp.py index 2bea13d..4cd176d 100644 --- a/learnware/specification/image/cnn_gp.py +++ b/learnware/specification/image/cnn_gp.py @@ -7,6 +7,10 @@ import math __all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "ConvKP", "NonlinKP") """ +With this package, we are able to accurately and efficiently compute the kernel matrix corresponding to the NNGP during the search phase. + +Github Repository: https://github.com/cambridge-mlg/cnn-gp + References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen, ‘Deep Convolutional Networks as shallow Gaussian Processes’, in International Conference on Learning Representations, 2019. """ From e8d1550a74f14b9dbea78849cf80477577173230 Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 20:38:12 +0800 Subject: [PATCH 12/24] [ENH] add test for image rkme --- tests/test_specification/test_rkme.py | 32 ++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 4cdc246..a5e2eda 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -1,18 +1,19 @@ import os import json +import torch import unittest import tempfile import numpy as np import learnware -import learnware.specification as specification -from learnware.specification import RKMEStatSpecification +from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification class TestRKME(unittest.TestCase): def test_rkme(self): X = np.random.uniform(-10000, 10000, size=(5000, 200)) - rkme = specification.utils.generate_rkme_spec(X) + rkme = RKMEStatSpecification() + rkme.generate_stat_spec_from_data(X) with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: rkme_path = os.path.join(tempdir, "rkme.json") @@ -26,6 +27,31 @@ class TestRKME(unittest.TestCase): rkme2.load(rkme_path) assert rkme2.type == "RKMEStatSpecification" + def test_image_rkme(self): + def _test_image_rkme(X): + image_rkme = RKMEImageStatSpecification() + image_rkme.generate_stat_spec_from_data(X, resize=True) + + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + rkme_path = os.path.join(tempdir, "rkme.json") + rkme.save(rkme_path) + + with open(rkme_path, "r") as f: + data = json.load(f) + assert data["type"] == "RKMEImageStatSpecification" + + rkme2 = RKMEImageStatSpecification() + rkme2.load(rkme_path) + assert rkme2.type == "RKMEImageStatSpecification" + + _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) + _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) + _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)) / 255) + + _test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32))) + _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128))) + _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)) / 255) + if __name__ == "__main__": unittest.main() From 4399ba01c1b1358549edfd25c5442c4e1379e01b Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 30 Oct 2023 20:38:35 +0800 Subject: [PATCH 13/24] [MNT] Modify details --- learnware/specification/image/cnn_gp.py | 5 +---- learnware/specification/image/rkme.py | 4 ++-- learnware/specification/table/rkme.py | 1 - 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/learnware/specification/image/cnn_gp.py b/learnware/specification/image/cnn_gp.py index 4cd176d..6ceb7f6 100644 --- a/learnware/specification/image/cnn_gp.py +++ b/learnware/specification/image/cnn_gp.py @@ -11,7 +11,7 @@ With this package, we are able to accurately and efficiently compute the kernel Github Repository: https://github.com/cambridge-mlg/cnn-gp -References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen, ‘Deep Convolutional Networks as shallow Gaussian Processes’, in International Conference on Learning Representations, 2019. +References: [1] A. Garriga-Alonso, L. Aitchison, and C. E. Rasmussen. Deep Convolutional Networks as shallow Gaussian Processes. In: International Conference on Learning Representations (ICLR'19), 2019. """ @@ -187,9 +187,6 @@ class ReLU(NNGPKernel): return 0 -#### Combination classes - - class Sequential(NNGPKernel): def __init__(self, *mods): super().__init__() diff --git a/learnware/specification/image/rkme.py b/learnware/specification/image/rkme.py index 7511f06..3da562f 100644 --- a/learnware/specification/image/rkme.py +++ b/learnware/specification/image/rkme.py @@ -17,7 +17,7 @@ from torchvision.transforms import Resize from . import cnn_gp from ..base import BaseStatSpecification -from ..rkme import solve_qp, choose_device, setup_seed +from ..table.rkme import solve_qp, choose_device, setup_seed class RKMEImageStatSpecification(BaseStatSpecification): @@ -49,6 +49,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): ) setup_seed(0) + super(RKMEImageStatSpecification, self).__init__(type=self.__class__.__name__) def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) @@ -337,7 +338,6 @@ class RKMEImageStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - rkme_to_save["type"] = self.__class__.__name__ json.dump( rkme_to_save, diff --git a/learnware/specification/table/rkme.py b/learnware/specification/table/rkme.py index 9769800..a0422ae 100644 --- a/learnware/specification/table/rkme.py +++ b/learnware/specification/table/rkme.py @@ -428,7 +428,6 @@ class RKMEStatSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - rkme_to_save["type"] = self.type json.dump( rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), From 10dfd3d7c8d2dfe98438f289dc850512d690a030 Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 21:24:25 +0800 Subject: [PATCH 14/24] [MNT] [MNT] format code by black --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bf5834c..80e48bd 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ REQUIRED = [ "geatpy>=2.7.0", "docker>=6.1.3", "rapidfuzz>=3.4.0", - "torch-optimizer>=0.3.0" + "torch-optimizer>=0.3.0", ] if get_platform() != MACOS: From 4cfd2cb7cc88b0e349a02b2cbf5436bb74aaa231 Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 21:37:27 +0800 Subject: [PATCH 15/24] [ENH, Fix] Add generate_rkme_image_spec, fix device bugs --- learnware/specification/__init__.py | 2 +- learnware/specification/image/rkme.py | 18 ++++---- learnware/specification/utils.py | 66 +++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 9 deletions(-) diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index a5ff474..935210b 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,4 +1,4 @@ -from .utils import generate_stat_spec +from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification from .image import RKMEImageStatSpecification from .table import RKMEStatSpecification diff --git a/learnware/specification/image/rkme.py b/learnware/specification/image/rkme.py index 3da562f..d651b6a 100644 --- a/learnware/specification/image/rkme.py +++ b/learnware/specification/image/rkme.py @@ -14,6 +14,7 @@ import torch_optimizer from torch import nn from torch.utils.data import TensorDataset, DataLoader from torchvision.transforms import Resize +from tqdm import tqdm from . import cnn_gp from ..base import BaseStatSpecification @@ -33,7 +34,6 @@ class RKMEImageStatSpecification(BaseStatSpecification): A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. """ self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. - # torch.cuda.empty_cache() self.z = None self.beta = None @@ -67,9 +67,10 @@ class RKMEImageStatSpecification(BaseStatSpecification): K: int = 50, step_size: float = 0.01, steps: int = 100, - resize: bool = False, + resize: bool = True, nonnegative_beta: bool = True, reduce: bool = True, + verbose: bool = True, **kwargs, ): """Construct reduced set from raw dataset using iterative optimization. @@ -77,7 +78,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): Parameters ---------- X : np.ndarray or torch.tensor - Raw data in np.ndarray format. + Raw data in [N, C, H, W] format. K : int Size of the construced reduced set. step_size : float @@ -90,7 +91,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): True if weights for the reduced set are intended to be kept non-negative, by default False. reduce : bool, optional Whether shrink original data to a smaller set, by default True - + verbose : bool, optional + Whether to print training progress, by default True Returns ------- @@ -107,7 +109,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): if not torch.is_tensor(X): X = torch.from_numpy(X) - X = X.to(self.device) + X = X.to(self.device).float() X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan if torch.any(torch.isnan(X)): @@ -120,7 +122,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): X[i] = torch.where(is_nan, img_mean, img) if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: - X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH))(X) + X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None)(X) num_points = X.shape[0] X_shape = X.shape @@ -147,7 +149,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) - for i in range(steps): + for _ in tqdm(range(steps)) if verbose else range(steps): # Regenerate Random Models random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) @@ -384,7 +386,7 @@ def _get_zca_matrix(X, reg_coef=0.1): X_flat = X.reshape(X.shape[0], -1) cov = (X_flat.T @ X_flat) / X_flat.shape[0] reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] - u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda()) + u, s, _ = torch.svd(cov + reg_amount * torch.eye(cov.shape[0]).to(X.device)) inv_sqrt_zca_eigs = s ** (-0.5) whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u) diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index c3693b7..d480de9 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd from typing import Union +from .image import RKMEImageStatSpecification from .base import BaseStatSpecification from .table import RKMEStatSpecification from ..config import C @@ -99,6 +100,71 @@ def generate_rkme_spec( return rkme_spec +def generate_rkme_image_spec( + X: Union[np.ndarray, torch.Tensor], + reduced_set_size: int = 50, + step_size: float = 0.01, + steps: int = 100, + resize: bool = True, + nonnegative_beta: bool = True, + reduce: bool = True, + verbose: bool = True, + cuda_idx: int = None, +) -> RKMEImageStatSpecification: + """ + Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image. + Return a RKMEImageStatSpecification object, use .save() method to save as json file. + + Parameters + ---------- + X : np.ndarray, or torch.Tensor + Raw data in np.ndarray, or torch.Tensor format. + The shape of X: [N, C, H, W] + N: Number of images. + C: Number of channels. + H: Height of images. + W: Width of images.s + For example, if X has shape (100, 3, 32, 32), it means there are 100 samples, and each sample is a 3-channel (RGB) image of size 32x32. + reduced_set_size : int + Size of the construced reduced set. + step_size : float + Step size for gradient descent in the iterative optimization. + steps : int + Total rounds in the iterative optimization. + resize : bool + Whether to scale the image to the requested size, by default True. + nonnegative_beta : bool, optional + True if weights for the reduced set are intended to be kept non-negative, by default False. + reduce : bool, optional + Whether shrink original data to a smaller set, by default True + cuda_idx : int + A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. + None indicates that CUDA is automatically selected. + verbose : bool, optional + Whether to print training progress, by default True + + Returns + ------- + RKMEImageStatSpecification + A RKMEImageStatSpecification object + """ + + # Check cuda_idx + if not torch.cuda.is_available() or cuda_idx == -1: + cuda_idx = -1 + else: + num_cuda_devices = torch.cuda.device_count() + if cuda_idx is None or not (0 <= cuda_idx < num_cuda_devices): + cuda_idx = 0 + + # Generate rkme spec + rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx) + rkme_image_spec.generate_stat_spec_from_data( + X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose + ) + return rkme_image_spec + + def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification: """ Interface for users to generate statistical specification. From 9daba5e1adad5cd4a62fcf731cf7e66b6227504c Mon Sep 17 00:00:00 2001 From: shihy Date: Mon, 30 Oct 2023 21:38:42 +0800 Subject: [PATCH 16/24] [MNT] Use generate_rkme_image_spec --- tests/test_specification/test_rkme.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index a5e2eda..bdf8e74 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -7,12 +7,13 @@ import numpy as np import learnware from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification +from learnware.specification import generate_rkme_image_spec, generate_rkme_spec class TestRKME(unittest.TestCase): def test_rkme(self): X = np.random.uniform(-10000, 10000, size=(5000, 200)) - rkme = RKMEStatSpecification() + rkme = generate_rkme_spec(X) rkme.generate_stat_spec_from_data(X) with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: @@ -29,12 +30,11 @@ class TestRKME(unittest.TestCase): def test_image_rkme(self): def _test_image_rkme(X): - image_rkme = RKMEImageStatSpecification() - image_rkme.generate_stat_spec_from_data(X, resize=True) + image_rkme = generate_rkme_image_spec(X, steps=10) with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: rkme_path = os.path.join(tempdir, "rkme.json") - rkme.save(rkme_path) + image_rkme.save(rkme_path) with open(rkme_path, "r") as f: data = json.load(f) From f91c247159ae9ee6956d192bf93a0d66f9e7f9e1 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 15:39:56 +0800 Subject: [PATCH 17/24] [MNT] fix some details --- .../dataset_image_workflow/example_files/example_yaml.yaml | 2 +- learnware/market/easy2/organizer.py | 2 +- learnware/specification/__init__.py | 2 +- learnware/specification/regular/__init__.py | 1 + learnware/specification/regular/table/rkme.py | 3 +-- tests/test_specification/test_rkme.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/dataset_image_workflow/example_files/example_yaml.yaml b/examples/dataset_image_workflow/example_files/example_yaml.yaml index 2f2b4cd..c49ba4a 100644 --- a/examples/dataset_image_workflow/example_files/example_yaml.yaml +++ b/examples/dataset_image_workflow/example_files/example_yaml.yaml @@ -2,7 +2,7 @@ model: class_name: Model kwargs: {} stat_specifications: - - module_path: learnware.specification.image + - module_path: learnware.specification class_name: RKMEImageStatSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 830b5d3..9b6bf8c 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -20,7 +20,7 @@ from ... import utils from ...config import C as conf from ...logger import get_module_logger from ...learnware import Learnware, get_learnware_from_dirpath -from ...specification import RKMEStatSpecification, Specification +from ...specification import Specification from ..base import BaseOrganizer, BaseChecker from ...logger import get_module_logger diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index c54cafc..f086c61 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RKMEStatSpecification, RKMEImageStatSpecification +from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMEImageStatSpecification diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index ba6c866..29e78c2 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,2 +1,3 @@ from .table import RKMEStatSpecification from .image import RKMEImageStatSpecification +from .base import RegularStatsSpecification \ No newline at end of file diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index e001dd5..da3e3c6 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -26,8 +26,7 @@ from ....logger import get_module_logger logger = get_module_logger("rkme") if not _FAISS_INSTALLED: - logger.warning("Required faiss version >= 1.7.1 is not detected!") - logger.warning('Please run "conda install -c pytorch faiss-cpu" first.') + logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") class RKMEStatSpecification(RegularStatsSpecification): diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index bdf8e74..89905f6 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -5,13 +5,13 @@ import unittest import tempfile import numpy as np -import learnware from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification from learnware.specification import generate_rkme_image_spec, generate_rkme_spec class TestRKME(unittest.TestCase): def test_rkme(self): + pass X = np.random.uniform(-10000, 10000, size=(5000, 200)) rkme = generate_rkme_spec(X) rkme.generate_stat_spec_from_data(X) From 6b1694825c4adfcc7bbed30d141665919ed0261c Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 16:08:47 +0800 Subject: [PATCH 18/24] [MNT] rename table and image rkme --- README.md | 6 +-- docs/references/api.rst | 2 +- docs/start/client.rst | 8 ++-- docs/start/quick.rst | 6 +-- docs/workflow/identify.rst | 4 +- docs/workflow/submit.rst | 4 +- .../example_files/example_yaml.yaml | 2 +- examples/dataset_image_workflow/main.py | 8 ++-- examples/dataset_m5_workflow/example.yaml | 2 +- examples/dataset_m5_workflow/main.py | 2 +- examples/dataset_pfs_workflow/example.yaml | 2 +- examples/dataset_pfs_workflow/main.py | 2 +- .../learnware_example/example.yaml | 2 +- examples/workflow_by_code/main.py | 6 +-- learnware/learnware/__init__.py | 2 +- learnware/market/easy.py | 42 +++++++++---------- learnware/market/easy2/checker.py | 2 +- learnware/market/easy2/searcher.py | 40 +++++++++--------- learnware/reuse/job_selector.py | 8 ++-- learnware/specification/__init__.py | 2 +- learnware/specification/regular/__init__.py | 4 +- .../specification/regular/image/__init__.py | 2 +- learnware/specification/regular/image/rkme.py | 24 +++++------ .../specification/regular/table/__init__.py | 2 +- learnware/specification/regular/table/rkme.py | 17 +++++--- learnware/specification/utils.py | 22 +++++----- .../learnware_example/example.yaml | 2 +- tests/test_market/test_easy.py | 4 +- tests/test_specification/test_rkme.py | 14 +++---- .../learnware_example/example.yaml | 2 +- tests/test_workflow/test_workflow.py | 6 +-- 31 files changed, 128 insertions(+), 123 deletions(-) diff --git a/README.md b/README.md index 5629530..a01c2d7 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ is composed of the following four parts. - ``learnware.yaml`` - A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and + A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and the file name of your statistical specification file. - ``environment.yaml`` @@ -178,10 +178,10 @@ For example, the following code is designed to work with Reduced Set Kernel Embe ```python import learnware.specification as specification -user_spec = specification.RKMEStatSpecification() +user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/references/api.rst b/docs/references/api.rst index a2f723b..20de7bc 100644 --- a/docs/references/api.rst +++ b/docs/references/api.rst @@ -50,7 +50,7 @@ Specification .. autoclass:: learnware.specification.BaseStatSpecification :members: -.. autoclass:: learnware.specification.RKMEStatSpecification +.. autoclass:: learnware.specification.RKMETableSpecification :members: Model diff --git a/docs/start/client.rst b/docs/start/client.rst index cbe1f2e..664fdf1 100644 --- a/docs/start/client.rst +++ b/docs/start/client.rst @@ -117,13 +117,13 @@ You can search learnwares in official market using semantic specification. All t Statistical Specification Search --------------------------------- -You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMEStatSpecification`: +You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMETableSpecification`: .. code-block:: python import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() @@ -138,7 +138,7 @@ You can search learnware by providing a statistical specification. The statistic Combine Semantic and Statistical Search ---------------------------------------- -You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMEStatSpecification`: +You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMETableSpecification`: .. code-block:: python @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares senarioes=[], input_description={}, output_description={}) - stat_spec = specification.RKMEStatSpecification() + stat_spec = specification.RKMETableSpecification() stat_spec.load(os.path.join(unzip_path, "rkme.json")) specification = learnware.specification.Specification() specification.update_semantic_spec(semantic_spec) diff --git a/docs/start/quick.rst b/docs/start/quick.rst index 2140aaa..6d8a7a8 100644 --- a/docs/start/quick.rst +++ b/docs/start/quick.rst @@ -47,7 +47,7 @@ includes the following four components: - ``learnware.yaml`` - A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMEStatSpecification`` for Reduced Kernel Mean Embedding), and + A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMETableSpecification`` for Reduced Kernel Mean Embedding), and the file name of your statistical specification file. - ``environment.yaml`` or ``requirements.txt`` @@ -170,12 +170,12 @@ For example, the code below executes learnware search when using Reduced Set Ker import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() # unzip_path: directory for unzipped learnware zipfile user_spec.load(os.path.join(unzip_path, "rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/workflow/identify.rst b/docs/workflow/identify.rst index ffd7dbb..ed7bb55 100644 --- a/docs/workflow/identify.rst +++ b/docs/workflow/identify.rst @@ -73,10 +73,10 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb import learnware.specification as specification - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join("rkme.json")) user_info = BaseUserInfo( - semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} + semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} ) (sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) diff --git a/docs/workflow/submit.rst b/docs/workflow/submit.rst index 928e108..fe097c3 100644 --- a/docs/workflow/submit.rst +++ b/docs/workflow/submit.rst @@ -94,7 +94,7 @@ guaranteeing the security and privacy of your local original data. ------------------ Additionally, you are asked to prepare a configuration file in YAML format. -The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and +The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and the file name of your statistical specification file. The following ``learnware.yaml`` provides an example of how your learnware configuration file should be structured, based on our previous discussion: @@ -105,7 +105,7 @@ how your learnware configuration file should be structured, based on our previou kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: stat.json kwargs: {} diff --git a/examples/dataset_image_workflow/example_files/example_yaml.yaml b/examples/dataset_image_workflow/example_files/example_yaml.yaml index c49ba4a..9aaf820 100644 --- a/examples/dataset_image_workflow/example_files/example_yaml.yaml +++ b/examples/dataset_image_workflow/example_files/example_yaml.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEImageStatSpecification + class_name: RKMEImageSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_image_workflow/main.py b/examples/dataset_image_workflow/main.py index 3f14e05..74e125b 100644 --- a/examples/dataset_image_workflow/main.py +++ b/examples/dataset_image_workflow/main.py @@ -6,7 +6,7 @@ from get_data import * import os import random -from learnware.specification.image import RKMEImageStatSpecification +from learnware.specification.image import RKMEImageSpecification from learnware.reuse.averaging import AveragingReuser from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction from learnware.learnware import Learnware @@ -100,7 +100,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo X_sampled = X[indices] st = time.time() - user_spec = RKMEImageStatSpecification(cuda_idx=0) + user_spec = RKMEImageSpecification(cuda_idx=0) user_spec.generate_stat_spec_from_data(X=X_sampled) ed = time.time() logger.info("Stat spec generated in %.3f s" % (ed - st)) @@ -164,9 +164,9 @@ def test_search(gamma=0.1, load_market=True): user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) user_data = np.load(user_data_path) user_label = np.load(user_label_path) - user_stat_spec = RKMEImageStatSpecification(cuda_idx=0) + user_stat_spec = RKMEImageSpecification(cuda_idx=0) user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}) logger.info("Searching Market for user: %d" % i) sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( user_info diff --git a/examples/dataset_m5_workflow/example.yaml b/examples/dataset_m5_workflow/example.yaml index 6ca01c9..cd539c8 100644 --- a/examples/dataset_m5_workflow/example.yaml +++ b/examples/dataset_m5_workflow/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_m5_workflow/main.py b/examples/dataset_m5_workflow/main.py index a720b30..009b557 100644 --- a/examples/dataset_m5_workflow/main.py +++ b/examples/dataset_m5_workflow/main.py @@ -144,7 +144,7 @@ class M5DatasetWorkflow: user_spec_path = f"./user_spec/user_{idx}.json" user_spec.save(user_spec_path) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, diff --git a/examples/dataset_pfs_workflow/example.yaml b/examples/dataset_pfs_workflow/example.yaml index 6ca01c9..cd539c8 100644 --- a/examples/dataset_pfs_workflow/example.yaml +++ b/examples/dataset_pfs_workflow/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: rkme.json kwargs: {} \ No newline at end of file diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index a465241..48ff7d0 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -142,7 +142,7 @@ class PFSDatasetWorkflow: user_spec_path = f"./user_spec/user_{idx}.json" user_spec.save(user_spec_path) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, diff --git a/examples/workflow_by_code/learnware_example/example.yaml b/examples/workflow_by_code/learnware_example/example.yaml index 254bca4..32aa52e 100644 --- a/examples/workflow_by_code/learnware_example/example.yaml +++ b/examples/workflow_by_code/learnware_example/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: svm.json kwargs: {} \ No newline at end of file diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 29d2e69..2f62db0 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -148,9 +148,9 @@ class LearnwareMarketWorkflow: with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, @@ -175,7 +175,7 @@ class LearnwareMarketWorkflow: _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index 32ef7bd..d3dd704 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -37,7 +37,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: "stat_specifications": [ { "module_path": "learnware.specification", - "class_name": "RKMEStatSpecification", + "class_name": "RKMETableSpecification", "file_name": "stat_spec.json", "kwargs": {}, }, diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 591cf05..957efda 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -18,7 +18,7 @@ from .. import utils from ..config import C as conf from ..logger import get_module_logger from ..learnware import Learnware, get_learnware_from_dirpath -from ..specification import RKMEStatSpecification, Specification +from ..specification import RKMETableSpecification, Specification logger = get_module_logger("market", "INFO") @@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket): pass # check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") @@ -296,7 +296,7 @@ class EasyMarket(LearnwareMarket): def _calculate_rkme_spec_mixture_weight( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[List[float], float]: @@ -306,7 +306,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] A list of existing learnwares - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket): """ learnware_num = len(learnware_list) RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] if type(intermediate_K) == np.ndarray: @@ -365,7 +365,7 @@ class EasyMarket(LearnwareMarket): def _calculate_intermediate_K_and_C( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray]: @@ -375,7 +375,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares up till now - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket): """ num = intermediate_K.shape[0] - 1 RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) @@ -400,7 +400,7 @@ class EasyMarket(LearnwareMarket): def _search_by_rkme_spec_mixture_auto( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, weight_cutoff: float = 0.98, ) -> Tuple[float, List[float], List[Learnware]]: @@ -410,7 +410,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -488,7 +488,7 @@ class EasyMarket(LearnwareMarket): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -496,7 +496,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification Returns @@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -518,7 +518,7 @@ class EasyMarket(LearnwareMarket): def _search_by_rkme_spec_mixture_greedy( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, score_cutoff: float = 0.001, ) -> Tuple[float, List[float], List[Learnware]]: @@ -528,7 +528,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -588,7 +588,7 @@ class EasyMarket(LearnwareMarket): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -596,7 +596,7 @@ class EasyMarket(LearnwareMarket): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification user RKME statistical specification Returns @@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket): both lists are sorted by mmd dist """ RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] mmd_dist_list = [] for RKME in RKME_list: @@ -819,12 +819,12 @@ class EasyMarket(LearnwareMarket): # if len(learnware_list) == 0: learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) - if "RKMEStatSpecification" not in user_info.stat_info: + if "RKMETableSpecification" not in user_info.stat_info: return None, learnware_list, 0.0, None elif len(learnware_list) == 0: return [], [], 0.0, [] else: - user_rkme = user_info.stat_info["RKMEStatSpecification"] + user_rkme = user_info.stat_info["RKMETableSpecification"] learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 7f26b91..5bccfcd 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -77,7 +77,7 @@ class EasyStatisticalChecker(BaseChecker): input_shape = learnware_model.input_shape # Check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index 9d758fc..dcf3335 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -7,7 +7,7 @@ from typing import Tuple, List from .organizer import EasyOrganizer from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMEStatSpecification +from ...specification import RKMETableSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -227,7 +227,7 @@ class EasyTableSearcher(BaseSearcher): def _calculate_rkme_spec_mixture_weight( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[List[float], float]: @@ -237,7 +237,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] A list of existing learnwares - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -252,7 +252,7 @@ class EasyTableSearcher(BaseSearcher): """ learnware_num = len(learnware_list) RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] if type(intermediate_K) == np.ndarray: @@ -296,7 +296,7 @@ class EasyTableSearcher(BaseSearcher): def _calculate_intermediate_K_and_C( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, intermediate_K: np.ndarray = None, intermediate_C: np.ndarray = None, ) -> Tuple[np.ndarray, np.ndarray]: @@ -306,7 +306,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares up till now - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification intermediate_K : np.ndarray, optional Intermediate kernel matrix K, by default None @@ -321,7 +321,7 @@ class EasyTableSearcher(BaseSearcher): """ num = intermediate_K.shape[0] - 1 RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) @@ -331,7 +331,7 @@ class EasyTableSearcher(BaseSearcher): def _search_by_rkme_spec_mixture_auto( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, weight_cutoff: float = 0.98, ) -> Tuple[float, List[float], List[Learnware]]: @@ -341,7 +341,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -377,7 +377,7 @@ class EasyTableSearcher(BaseSearcher): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification Returns @@ -439,7 +439,7 @@ class EasyTableSearcher(BaseSearcher): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -449,7 +449,7 @@ class EasyTableSearcher(BaseSearcher): def _search_by_rkme_spec_mixture_greedy( self, learnware_list: List[Learnware], - user_rkme: RKMEStatSpecification, + user_rkme: RKMETableSpecification, max_search_num: int, score_cutoff: float = 0.001, ) -> Tuple[float, List[float], List[Learnware]]: @@ -459,7 +459,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification User RKME statistical specification max_search_num : int The maximum number of the returned learnwares @@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification + self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMEStatSpecification + user_rkme : RKMETableSpecification user RKME statistical specification Returns @@ -538,7 +538,7 @@ class EasyTableSearcher(BaseSearcher): both lists are sorted by mmd dist """ RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] mmd_dist_list = [] for RKME in RKME_list: @@ -558,7 +558,7 @@ class EasyTableSearcher(BaseSearcher): max_search_num: int = 5, search_method: str = "greedy", ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: - user_rkme = user_info.stat_info["RKMEStatSpecification"] + user_rkme = user_info.stat_info["RKMETableSpecification"] learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") @@ -631,7 +631,7 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] - elif "RKMEStatSpecification" in user_info.stat_info: + elif "RKMETableSpecification" in user_info.stat_info: return self.table_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index e786e15..21745fe 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score from learnware.learnware import Learnware import learnware.specification as specification from .base import BaseReuser -from ..specification import RKMEStatSpecification +from ..specification import RKMETableSpecification from ..logger import get_module_logger logger = get_module_logger("job_selector_reuse") @@ -86,7 +86,7 @@ class JobSelectorReuser(BaseReuser): return np.array([0] * user_data_num) else: learnware_rkme_spec_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in self.learnware_list ] @@ -154,7 +154,7 @@ class JobSelectorReuser(BaseReuser): return job_select_result def _calculate_rkme_spec_mixture_weight( - self, user_data: np.ndarray, task_rkme_list: List[RKMEStatSpecification], task_rkme_matrix: np.ndarray + self, user_data: np.ndarray, task_rkme_list: List[RKMETableSpecification], task_rkme_matrix: np.ndarray ) -> List[float]: """_summary_ @@ -162,7 +162,7 @@ class JobSelectorReuser(BaseReuser): ---------- user_data : np.ndarray Raw user data. - task_rkme_list : List[RKMEStatSpecification] + task_rkme_list : List[RKMETableSpecification] The list of learwares' rkmes whose mixture approximates the user's rkme task_rkme_matrix : np.ndarray Inner product matrix calculated from task_rkme_list. diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index f086c61..269fe5f 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,3 +1,3 @@ from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .base import Specification, BaseStatSpecification -from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMEImageStatSpecification +from .regular import RegularStatsSpecification, RKMETableSpecification, RKMEImageSpecification diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 29e78c2..4373eb0 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,3 +1,3 @@ -from .table import RKMEStatSpecification -from .image import RKMEImageStatSpecification +from .table import RKMETableSpecification, RKMEStatSpecification +from .image import RKMEImageSpecification from .base import RegularStatsSpecification \ No newline at end of file diff --git a/learnware/specification/regular/image/__init__.py b/learnware/specification/regular/image/__init__.py index fa46969..0a18ded 100644 --- a/learnware/specification/regular/image/__init__.py +++ b/learnware/specification/regular/image/__init__.py @@ -1 +1 @@ -from .rkme import RKMEImageStatSpecification +from .rkme import RKMEImageSpecification diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index e65dc27..e0454da 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -21,7 +21,7 @@ from ..base import BaseStatSpecification from ..table.rkme import solve_qp, choose_device, setup_seed -class RKMEImageStatSpecification(BaseStatSpecification): +class RKMEImageSpecification(BaseStatSpecification): # INNER_PRODUCT_COUNT = 0 IMAGE_WIDTH = 32 @@ -49,7 +49,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): ) setup_seed(0) - super(RKMEImageStatSpecification, self).__init__(type=self.__class__.__name__) + super(RKMEImageSpecification, self).__init__(type=self.__class__.__name__) def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) @@ -98,12 +98,12 @@ class RKMEImageStatSpecification(BaseStatSpecification): """ if ( - X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH + X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH ) and not resize: raise ValueError( "X should be in shape of [N, C, {0:d}, {0:d}]. " "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format( - RKMEImageStatSpecification.IMAGE_WIDTH + RKMEImageSpecification.IMAGE_WIDTH ) ) @@ -121,9 +121,9 @@ class RKMEImageStatSpecification(BaseStatSpecification): img_mean = torch.nanmean(img) X[i] = torch.where(is_nan, img_mean, img) - if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: + if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: X = Resize( - (RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None + (RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None )(X) num_points = X.shape[0] @@ -253,12 +253,12 @@ class RKMEImageStatSpecification(BaseStatSpecification): Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) return X_features, Y_features - def inner_prod(self, Phi2: RKMEImageStatSpecification) -> float: + def inner_prod(self, Phi2: RKMEImageSpecification) -> float: """Compute the inner product between two RKME Image specifications Parameters ---------- - Phi2 : RKMEImageStatSpecification + Phi2 : RKMEImageSpecification The other RKME Image specification. Returns @@ -269,7 +269,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): v = self._inner_prod_nngp(Phi2) return v - def _inner_prod_nngp(self, Phi2: RKMEImageStatSpecification) -> float: + def _inner_prod_nngp(self, Phi2: RKMEImageSpecification) -> float: beta_1 = self.beta.reshape(1, -1).detach().to(self.device) beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) @@ -283,15 +283,15 @@ class RKMEImageStatSpecification(BaseStatSpecification): K_zz = kernel_fn(Z1, Z2) v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() - # RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 + # RKMEImageSpecification.INNER_PRODUCT_COUNT += 1 return v - def dist(self, Phi2: RKMEImageStatSpecification, omit_term1: bool = False) -> float: + def dist(self, Phi2: RKMEImageSpecification, omit_term1: bool = False) -> float: """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications Parameters ---------- - Phi2 : RKMEImageStatSpecification + Phi2 : RKMEImageSpecification The other RKME specification. omit_term1 : bool, optional True if the inner product of self with itself can be omitted, by default False. diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index dc94b1e..3cc9bd0 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -1 +1 @@ -from .rkme import RKMEStatSpecification +from .rkme import RKMETableSpecification diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index da3e3c6..ab763d8 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -29,7 +29,7 @@ if not _FAISS_INSTALLED: logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") -class RKMEStatSpecification(RegularStatsSpecification): +class RKMETableSpecification(RegularStatsSpecification): """Reduced Kernel Mean Embedding (RKME) Specification""" def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): @@ -50,7 +50,7 @@ class RKMEStatSpecification(RegularStatsSpecification): torch.cuda.empty_cache() self.device = choose_device(cuda_idx=cuda_idx) setup_seed(0) - super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) + super(RKMETableSpecification, self).__init__(type=self.__class__.__name__) def get_beta(self) -> np.ndarray: """Move beta(RKME weights) back to memory accessible to the CPU. @@ -333,12 +333,12 @@ class RKMEStatSpecification(RegularStatsSpecification): else: logger.warning("Not enough candidates for herding!") - def inner_prod(self, Phi2: RKMEStatSpecification) -> float: + def inner_prod(self, Phi2: RKMETableSpecification) -> float: """Compute the inner product between two RKME specifications Parameters ---------- - Phi2 : RKMEStatSpecification + Phi2 : RKMETableSpecification The other RKME specification. Returns @@ -354,12 +354,12 @@ class RKMEStatSpecification(RegularStatsSpecification): return float(v) - def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float: + def dist(self, Phi2: RKMETableSpecification, omit_term1: bool = False) -> float: """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications Parameters ---------- - Phi2 : RKMEStatSpecification + Phi2 : RKMETableSpecification The other RKME specification. omit_term1 : bool, optional True if the inner product of self with itself can be omitted, by default False. @@ -463,6 +463,11 @@ class RKMEStatSpecification(RegularStatsSpecification): else: return False +class RKMEStatSpecification(RKMETableSpecification): + """nickname for RKMETableSpecification, for compatibility currently. + TODO: modify all learnware in database and remove this nickname + """ + pass def setup_seed(seed): """Fix a random seed for addressing reproducibility issues. diff --git a/learnware/specification/utils.py b/learnware/specification/utils.py index f0108ab..91fe226 100644 --- a/learnware/specification/utils.py +++ b/learnware/specification/utils.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Union from .base import BaseStatSpecification -from .regular import RKMEStatSpecification, RKMEImageStatSpecification +from .regular import RKMETableSpecification, RKMEImageSpecification from ..config import C @@ -42,10 +42,10 @@ def generate_rkme_spec( nonnegative_beta: bool = True, reduce: bool = True, cuda_idx: int = None, -) -> RKMEStatSpecification: +) -> RKMETableSpecification: """ Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification. - Return a RKMEStatSpecification object, use .save() method to save as json file. + Return a RKMETableSpecification object, use .save() method to save as json file. Parameters ---------- @@ -73,8 +73,8 @@ def generate_rkme_spec( Returns ------- - RKMEStatSpecification - A RKMEStatSpecification object + RKMETableSpecification + A RKMETableSpecification object """ # Convert data type X = convert_to_numpy(X) @@ -94,7 +94,7 @@ def generate_rkme_spec( cuda_idx = 0 # Generate rkme spec - rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx) + rkme_spec = RKMETableSpecification(gamma=gamma, cuda_idx=cuda_idx) rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) return rkme_spec @@ -109,10 +109,10 @@ def generate_rkme_image_spec( reduce: bool = True, verbose: bool = True, cuda_idx: int = None, -) -> RKMEImageStatSpecification: +) -> RKMEImageSpecification: """ Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image. - Return a RKMEImageStatSpecification object, use .save() method to save as json file. + Return a RKMEImageSpecification object, use .save() method to save as json file. Parameters ---------- @@ -144,8 +144,8 @@ def generate_rkme_image_spec( Returns ------- - RKMEImageStatSpecification - A RKMEImageStatSpecification object + RKMEImageSpecification + A RKMEImageSpecification object """ # Check cuda_idx @@ -157,7 +157,7 @@ def generate_rkme_image_spec( cuda_idx = 0 # Generate rkme spec - rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx) + rkme_image_spec = RKMEImageSpecification(cuda_idx=cuda_idx) rkme_image_spec.generate_stat_spec_from_data( X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose ) diff --git a/tests/test_market/learnware_example/example.yaml b/tests/test_market/learnware_example/example.yaml index 254bca4..32aa52e 100644 --- a/tests/test_market/learnware_example/example.yaml +++ b/tests/test_market/learnware_example/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: svm.json kwargs: {} \ No newline at end of file diff --git a/tests/test_market/test_easy.py b/tests/test_market/test_easy.py index 5f22729..16729e2 100644 --- a/tests/test_market/test_easy.py +++ b/tests/test_market/test_easy.py @@ -170,9 +170,9 @@ class TestMarket(unittest.TestCase): with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.rkme.RKMEStatSpecification() + user_spec = specification.rkme.RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 89905f6..613e40b 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -5,7 +5,7 @@ import unittest import tempfile import numpy as np -from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification +from learnware.specification import RKMETableSpecification, RKMEImageSpecification from learnware.specification import generate_rkme_image_spec, generate_rkme_spec @@ -22,11 +22,11 @@ class TestRKME(unittest.TestCase): with open(rkme_path, "r") as f: data = json.load(f) - assert data["type"] == "RKMEStatSpecification" + assert data["type"] == "RKMETableSpecification" - rkme2 = RKMEStatSpecification() + rkme2 = RKMETableSpecification() rkme2.load(rkme_path) - assert rkme2.type == "RKMEStatSpecification" + assert rkme2.type == "RKMETableSpecification" def test_image_rkme(self): def _test_image_rkme(X): @@ -38,11 +38,11 @@ class TestRKME(unittest.TestCase): with open(rkme_path, "r") as f: data = json.load(f) - assert data["type"] == "RKMEImageStatSpecification" + assert data["type"] == "RKMEImageSpecification" - rkme2 = RKMEImageStatSpecification() + rkme2 = RKMEImageSpecification() rkme2.load(rkme_path) - assert rkme2.type == "RKMEImageStatSpecification" + assert rkme2.type == "RKMEImageSpecification" _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) diff --git a/tests/test_workflow/learnware_example/example.yaml b/tests/test_workflow/learnware_example/example.yaml index 254bca4..32aa52e 100644 --- a/tests/test_workflow/learnware_example/example.yaml +++ b/tests/test_workflow/learnware_example/example.yaml @@ -3,6 +3,6 @@ model: kwargs: {} stat_specifications: - module_path: learnware.specification - class_name: RKMEStatSpecification + class_name: RKMETableSpecification file_name: svm.json kwargs: {} \ No newline at end of file diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 1da7db3..f4507c5 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -155,9 +155,9 @@ class TestAllWorkflow(unittest.TestCase): with zipfile.ZipFile(zip_path, "r") as zip_obj: zip_obj.extractall(path=unzip_dir) - user_spec = specification.RKMEStatSpecification() + user_spec = specification.RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) ( sorted_score_list, single_learnware_list, @@ -182,7 +182,7 @@ class TestAllWorkflow(unittest.TestCase): train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) - user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec}) + user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) From fda945f5cc76a6f765d3200f6fe110203a2f8fcb Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 16:38:38 +0800 Subject: [PATCH 19/24] [MNT, FIX] modify typehint for easysearch, and fix file not close warning --- learnware/market/easy2/organizer.py | 17 ++-------------- learnware/market/easy2/searcher.py | 20 +++++++++---------- learnware/specification/regular/__init__.py | 2 +- learnware/specification/regular/base.py | 2 ++ learnware/specification/regular/image/rkme.py | 11 +++------- .../specification/regular/table/__init__.py | 2 +- learnware/specification/regular/table/rkme.py | 7 ++++++- tests/test_specification/test_rkme.py | 1 - 8 files changed, 25 insertions(+), 37 deletions(-) diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 9b6bf8c..18f67eb 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -1,28 +1,15 @@ import os -import json import copy -import torch import zipfile -import traceback import tempfile -import numpy as np -import pandas as pd -from rapidfuzz import fuzz -from cvxopt import solvers, matrix from shutil import copyfile, rmtree -from typing import Tuple, Any, List, Union, Dict +from typing import Tuple, List, Union from .database_ops import DatabaseOperations -from ..base import LearnwareMarket, BaseUserInfo - - -from ... import utils +from ..base import BaseOrganizer, BaseChecker from ...config import C as conf from ...logger import get_module_logger from ...learnware import Learnware, get_learnware_from_dirpath -from ...specification import Specification - -from ..base import BaseOrganizer, BaseChecker from ...logger import get_module_logger logger = get_module_logger("easy_organizer") diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index dcf3335..aa741e3 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -2,12 +2,12 @@ import torch import numpy as np from rapidfuzz import fuzz from cvxopt import solvers, matrix -from typing import Tuple, List +from typing import Tuple, List, Union from .organizer import EasyOrganizer from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMETableSpecification +from ...specification import RKMETableSpecification, RKMEImageSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -188,7 +188,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): return final_result -class EasyTableSearcher(BaseSearcher): +class EasyStatSearcher(BaseSearcher): def _convert_dist_to_score( self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 ) -> List[float]: @@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMETableSpecification + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] User RKME statistical specification Returns @@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMETableSpecification + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] user RKME statistical specification Returns @@ -599,12 +599,12 @@ class EasySearcher(BaseSearcher): def __init__(self, organizer: EasyOrganizer = None): super(EasySearcher, self).__init__(organizer) self.semantic_searcher = EasyFuzzSemanticSearcher(organizer) - self.table_searcher = EasyTableSearcher(organizer) + self.stat_searcher = EasyStatSearcher(organizer) def reset(self, organizer): self.learnware_oganizer = organizer self.semantic_searcher.reset(organizer) - self.table_searcher.reset(organizer) + self.stat_searcher.reset(organizer) def __call__( self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" @@ -632,6 +632,6 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 4373eb0..eeb4b3f 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,3 +1,3 @@ from .table import RKMETableSpecification, RKMEStatSpecification from .image import RKMEImageSpecification -from .base import RegularStatsSpecification \ No newline at end of file +from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/base.py b/learnware/specification/regular/base.py index 48a7e1f..6916177 100644 --- a/learnware/specification/regular/base.py +++ b/learnware/specification/regular/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ..base import BaseStatSpecification diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index e0454da..1f05382 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -122,9 +122,7 @@ class RKMEImageSpecification(BaseStatSpecification): X[i] = torch.where(is_nan, img_mean, img) if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: - X = Resize( - (RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None - )(X) + X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X) num_points = X.shape[0] X_shape = X.shape @@ -343,11 +341,8 @@ class RKMEImageSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), - ) + with codecs.open(save_path, "w", encoding="utf-8") as fout: + json.dump(rkme_to_save, fout, separators=(",", ":")) def load(self, filepath: str) -> bool: """Load a RKME Image specification file in JSON format from the specified path. diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index 3cc9bd0..19fa956 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -1 +1 @@ -from .rkme import RKMETableSpecification +from .rkme import RKMETableSpecification, RKMEStatSpecification diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index ab763d8..ba76f6b 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -26,7 +26,9 @@ from ....logger import get_module_logger logger = get_module_logger("rkme") if not _FAISS_INSTALLED: - logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") + logger.warning( + "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" + ) class RKMETableSpecification(RegularStatsSpecification): @@ -463,12 +465,15 @@ class RKMETableSpecification(RegularStatsSpecification): else: return False + class RKMEStatSpecification(RKMETableSpecification): """nickname for RKMETableSpecification, for compatibility currently. TODO: modify all learnware in database and remove this nickname """ + pass + def setup_seed(seed): """Fix a random seed for addressing reproducibility issues. diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 613e40b..c77e654 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -11,7 +11,6 @@ from learnware.specification import generate_rkme_image_spec, generate_rkme_spec class TestRKME(unittest.TestCase): def test_rkme(self): - pass X = np.random.uniform(-10000, 10000, size=(5000, 200)) rkme = generate_rkme_spec(X) rkme.generate_stat_spec_from_data(X) From 5fb0f46ceb2222b41e11907394d0b94f98c6a25a Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 19:05:18 +0800 Subject: [PATCH 20/24] [FIX] fix json file not close bug in unitest --- learnware/specification/regular/table/rkme.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index ba76f6b..82c81a2 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -429,11 +429,8 @@ class RKMETableSpecification(RegularStatsSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].detach().cpu().numpy() rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), - ) + with codecs.open(save_path, "w", encoding="utf-8") as fout: + json.dump(rkme_to_save, fout, separators=(",", ":")) def load(self, filepath: str) -> bool: """Load a RKME specification file in JSON format from the specified path. From 8df0f3925ea71009c6a81fac97a708babc4c8b34 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 19:31:30 +0800 Subject: [PATCH 21/24] [FIX] fix CI yaml error --- .github/workflows/install_learnware_with_pip.yaml | 2 +- .github/workflows/install_learnware_with_source.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/install_learnware_with_pip.yaml b/.github/workflows/install_learnware_with_pip.yaml index 137909e..f22f3c9 100644 --- a/.github/workflows/install_learnware_with_pip.yaml +++ b/.github/workflows/install_learnware_with_pip.yaml @@ -25,7 +25,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Add conda to system path + - name: Add conda to system path run: | # $CONDA is an environment variable pointing to the root of the miniconda directory echo $CONDA/bin >> $GITHUB_PATH diff --git a/.github/workflows/install_learnware_with_source.yaml b/.github/workflows/install_learnware_with_source.yaml index e9589e3..bffc260 100644 --- a/.github/workflows/install_learnware_with_source.yaml +++ b/.github/workflows/install_learnware_with_source.yaml @@ -25,7 +25,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Add conda to system path + - name: Add conda to system path run: | # $CONDA is an environment variable pointing to the root of the miniconda directory echo $CONDA/bin >> $GITHUB_PATH From 72b9133774a142036e721ad0d5f7c169efc2ac43 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 19:34:48 +0800 Subject: [PATCH 22/24] [FIX] fix CI conda install error --- .github/workflows/install_learnware_with_pip.yaml | 2 +- .github/workflows/install_learnware_with_source.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/install_learnware_with_pip.yaml b/.github/workflows/install_learnware_with_pip.yaml index f22f3c9..6b24b74 100644 --- a/.github/workflows/install_learnware_with_pip.yaml +++ b/.github/workflows/install_learnware_with_pip.yaml @@ -33,7 +33,7 @@ jobs: - name: Create conda env for macos run: | conda create -n learnware python=${{ matrix.python-version }} - conda create activate learnware + conda activate learnware - name: Update pip to the latest version run: | diff --git a/.github/workflows/install_learnware_with_source.yaml b/.github/workflows/install_learnware_with_source.yaml index bffc260..5bac535 100644 --- a/.github/workflows/install_learnware_with_source.yaml +++ b/.github/workflows/install_learnware_with_source.yaml @@ -33,7 +33,7 @@ jobs: - name: Create conda env for macos run: | conda create -n learnware python=${{ matrix.python-version }} - conda create activate learnware + conda activate learnware - name: Update pip to the latest version run: | From 978b07077e0240df6ebd0478a402bf881f708c43 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 19:36:47 +0800 Subject: [PATCH 23/24] [FIX] update CI --- .github/workflows/install_learnware_with_pip.yaml | 2 +- .github/workflows/install_learnware_with_source.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/install_learnware_with_pip.yaml b/.github/workflows/install_learnware_with_pip.yaml index 6b24b74..ff5a34c 100644 --- a/.github/workflows/install_learnware_with_pip.yaml +++ b/.github/workflows/install_learnware_with_pip.yaml @@ -14,7 +14,7 @@ jobs: strategy: matrix: os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest] - python-version: [3.8, 3.9, 3.10] + python-version: [3.8, 3.9] steps: - name: Test learnware from pip diff --git a/.github/workflows/install_learnware_with_source.yaml b/.github/workflows/install_learnware_with_source.yaml index 5bac535..d0fb6af 100644 --- a/.github/workflows/install_learnware_with_source.yaml +++ b/.github/workflows/install_learnware_with_source.yaml @@ -13,8 +13,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest] - python-version: [3.8, 3.9, 3.10] + os: [ubuntu-20.04] + python-version: [3.8, 3.9] steps: - name: Test learnware from pip From b0aaae48e77fb5d49d2b7a1c31a2023580ea2115 Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 19:37:17 +0800 Subject: [PATCH 24/24] [FIX] update CI --- .github/workflows/install_learnware_with_pip.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/install_learnware_with_pip.yaml b/.github/workflows/install_learnware_with_pip.yaml index ff5a34c..4b4f86a 100644 --- a/.github/workflows/install_learnware_with_pip.yaml +++ b/.github/workflows/install_learnware_with_pip.yaml @@ -13,7 +13,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest] + os: [ubuntu-20.04] python-version: [3.8, 3.9] steps: