|
|
|
@@ -14,6 +14,7 @@ import torch_optimizer |
|
|
|
from torch import nn |
|
|
|
from torch.utils.data import TensorDataset, DataLoader |
|
|
|
from torchvision.transforms import Resize |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
from . import cnn_gp |
|
|
|
from ..base import BaseStatSpecification |
|
|
|
@@ -33,7 +34,6 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
A flag indicating whether to use CUDA during RKME computation. -1 indicates CUDA is not used. |
|
|
|
""" |
|
|
|
self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. |
|
|
|
# torch.cuda.empty_cache() |
|
|
|
|
|
|
|
self.z = None |
|
|
|
self.beta = None |
|
|
|
@@ -67,9 +67,10 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
K: int = 50, |
|
|
|
step_size: float = 0.01, |
|
|
|
steps: int = 100, |
|
|
|
resize: bool = False, |
|
|
|
resize: bool = True, |
|
|
|
nonnegative_beta: bool = True, |
|
|
|
reduce: bool = True, |
|
|
|
verbose: bool = True, |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
"""Construct reduced set from raw dataset using iterative optimization. |
|
|
|
@@ -77,7 +78,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
X : np.ndarray or torch.tensor |
|
|
|
Raw data in np.ndarray format. |
|
|
|
Raw data in [N, C, H, W] format. |
|
|
|
K : int |
|
|
|
Size of the constructed reduced set. |
|
|
|
step_size : float |
|
|
|
@@ -90,7 +91,8 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
True if weights for the reduced set are intended to be kept non-negative, by default True. |
|
|
|
reduce : bool, optional |
|
|
|
Whether to shrink the original data to a smaller set, by default True |
|
|
|
|
|
|
|
verbose : bool, optional |
|
|
|
Whether to print training progress, by default True |
|
|
|
Returns |
|
|
|
------- |
|
|
|
|
|
|
|
@@ -107,7 +109,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
|
|
|
|
if not torch.is_tensor(X): |
|
|
|
X = torch.from_numpy(X) |
|
|
|
X = X.to(self.device) |
|
|
|
X = X.to(self.device).float() |
|
|
|
|
|
|
|
X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan |
|
|
|
if torch.any(torch.isnan(X)): |
|
|
|
@@ -120,7 +122,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
X[i] = torch.where(is_nan, img_mean, img) |
|
|
|
|
|
|
|
if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: |
|
|
|
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH))(X) |
|
|
|
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None)(X) |
|
|
|
|
|
|
|
num_points = X.shape[0] |
|
|
|
X_shape = X.shape |
|
|
|
@@ -147,7 +149,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): |
|
|
|
|
|
|
|
optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) |
|
|
|
|
|
|
|
for i in range(steps): |
|
|
|
for _ in tqdm(range(steps)) if verbose else range(steps): |
|
|
|
# Regenerate Random Models |
|
|
|
random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) |
|
|
|
|
|
|
|
@@ -384,7 +386,7 @@ def _get_zca_matrix(X, reg_coef=0.1): |
|
|
|
X_flat = X.reshape(X.shape[0], -1) |
|
|
|
cov = (X_flat.T @ X_flat) / X_flat.shape[0] |
|
|
|
reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] |
|
|
|
u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda()) |
|
|
|
u, s, _ = torch.svd(cov + reg_amount * torch.eye(cov.shape[0]).to(X.device)) |
|
|
|
inv_sqrt_zca_eigs = s ** (-0.5) |
|
|
|
whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u) |
|
|
|
|
|
|
|
|