Browse Source

[ENH, Fix] Add generate_rkme_image_spec, fix device bugs

tags/v0.3.2
shihy 2 years ago
parent
commit
4cfd2cb7cc
3 changed files with 77 additions and 9 deletions
  1. +1
    -1
      learnware/specification/__init__.py
  2. +10
    -8
      learnware/specification/image/rkme.py
  3. +66
    -0
      learnware/specification/utils.py

+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,4 +1,4 @@
from .utils import generate_stat_spec
from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec
from .base import Specification, BaseStatSpecification from .base import Specification, BaseStatSpecification
from .image import RKMEImageStatSpecification from .image import RKMEImageStatSpecification
from .table import RKMEStatSpecification from .table import RKMEStatSpecification

+ 10
- 8
learnware/specification/image/rkme.py View File

@@ -14,6 +14,7 @@ import torch_optimizer
from torch import nn from torch import nn
from torch.utils.data import TensorDataset, DataLoader from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import Resize from torchvision.transforms import Resize
from tqdm import tqdm


from . import cnn_gp from . import cnn_gp
from ..base import BaseStatSpecification from ..base import BaseStatSpecification
@@ -33,7 +34,6 @@ class RKMEImageStatSpecification(BaseStatSpecification):
A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used. A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used.
""" """
self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility. self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility.
# torch.cuda.empty_cache()


self.z = None self.z = None
self.beta = None self.beta = None
@@ -67,9 +67,10 @@ class RKMEImageStatSpecification(BaseStatSpecification):
K: int = 50, K: int = 50,
step_size: float = 0.01, step_size: float = 0.01,
steps: int = 100, steps: int = 100,
resize: bool = False,
resize: bool = True,
nonnegative_beta: bool = True, nonnegative_beta: bool = True,
reduce: bool = True, reduce: bool = True,
verbose: bool = True,
**kwargs, **kwargs,
): ):
"""Construct reduced set from raw dataset using iterative optimization. """Construct reduced set from raw dataset using iterative optimization.
@@ -77,7 +78,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
Parameters Parameters
---------- ----------
X : np.ndarray or torch.tensor X : np.ndarray or torch.tensor
Raw data in np.ndarray format.
Raw data in [N, C, H, W] format.
K : int K : int
Size of the construced reduced set. Size of the construced reduced set.
step_size : float step_size : float
@@ -90,7 +91,8 @@ class RKMEImageStatSpecification(BaseStatSpecification):
True if weights for the reduced set are intended to be kept non-negative, by default False. True if weights for the reduced set are intended to be kept non-negative, by default False.
reduce : bool, optional reduce : bool, optional
Whether shrink original data to a smaller set, by default True Whether shrink original data to a smaller set, by default True

verbose : bool, optional
Whether to print training progress, by default True
Returns Returns
------- -------


@@ -107,7 +109,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):


if not torch.is_tensor(X): if not torch.is_tensor(X):
X = torch.from_numpy(X) X = torch.from_numpy(X)
X = X.to(self.device)
X = X.to(self.device).float()


X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan
if torch.any(torch.isnan(X)): if torch.any(torch.isnan(X)):
@@ -120,7 +122,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
X[i] = torch.where(is_nan, img_mean, img) X[i] = torch.where(is_nan, img_mean, img)


if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH))(X)
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None)(X)


num_points = X.shape[0] num_points = X.shape[0]
X_shape = X.shape X_shape = X.shape
@@ -147,7 +149,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):


optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)


for i in range(steps):
for _ in tqdm(range(steps)) if verbose else range(steps):
# Regenerate Random Models # Regenerate Random Models
random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1])) random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1]))


@@ -384,7 +386,7 @@ def _get_zca_matrix(X, reg_coef=0.1):
X_flat = X.reshape(X.shape[0], -1) X_flat = X.reshape(X.shape[0], -1)
cov = (X_flat.T @ X_flat) / X_flat.shape[0] cov = (X_flat.T @ X_flat) / X_flat.shape[0]
reg_amount = reg_coef * torch.trace(cov) / cov.shape[0] reg_amount = reg_coef * torch.trace(cov) / cov.shape[0]
u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda())
u, s, _ = torch.svd(cov + reg_amount * torch.eye(cov.shape[0]).to(X.device))
inv_sqrt_zca_eigs = s ** (-0.5) inv_sqrt_zca_eigs = s ** (-0.5)
whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u) whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u)




+ 66
- 0
learnware/specification/utils.py View File

@@ -3,6 +3,7 @@ import numpy as np
import pandas as pd import pandas as pd
from typing import Union from typing import Union


from .image import RKMEImageStatSpecification
from .base import BaseStatSpecification from .base import BaseStatSpecification
from .table import RKMEStatSpecification from .table import RKMEStatSpecification
from ..config import C from ..config import C
@@ -99,6 +100,71 @@ def generate_rkme_spec(
return rkme_spec return rkme_spec




def generate_rkme_image_spec(
    X: Union[np.ndarray, torch.Tensor],
    reduced_set_size: int = 50,
    step_size: float = 0.01,
    steps: int = 100,
    resize: bool = True,
    nonnegative_beta: bool = True,
    reduce: bool = True,
    verbose: bool = True,
    cuda_idx: Union[int, None] = None,
) -> RKMEImageStatSpecification:
    """
    Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image.
    Return a RKMEImageStatSpecification object, use .save() method to save as json file.

    Parameters
    ----------
    X : np.ndarray, or torch.Tensor
        Raw data in np.ndarray, or torch.Tensor format.
        The shape of X: [N, C, H, W]
            N: Number of images.
            C: Number of channels.
            H: Height of images.
            W: Width of images.
        For example, if X has shape (100, 3, 32, 32), it means there are 100 samples,
        and each sample is a 3-channel (RGB) image of size 32x32.
    reduced_set_size : int, optional
        Size of the constructed reduced set, by default 50.
    step_size : float, optional
        Step size for gradient descent in the iterative optimization, by default 0.01.
    steps : int, optional
        Total rounds in the iterative optimization, by default 100.
    resize : bool, optional
        Whether to scale the image to the requested size, by default True.
    nonnegative_beta : bool, optional
        True if weights for the reduced set are intended to be kept non-negative, by default True.
    reduce : bool, optional
        Whether shrink original data to a smaller set, by default True.
    verbose : bool, optional
        Whether to print training progress, by default True.
    cuda_idx : int, optional
        A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used.
        None indicates that CUDA is automatically selected.
        If CUDA is unavailable, -1 is always used. If CUDA is available and the value is
        None or not a valid device index, device 0 is used.

    Returns
    -------
    RKMEImageStatSpecification
        A RKMEImageStatSpecification object
    """

    # Check cuda_idx: normalise it to -1 (CPU) or a valid CUDA device index.
    if not torch.cuda.is_available() or cuda_idx == -1:
        cuda_idx = -1
    else:
        num_cuda_devices = torch.cuda.device_count()
        # None (auto-select) and out-of-range indices both fall back to device 0.
        if cuda_idx is None or not (0 <= cuda_idx < num_cuda_devices):
            cuda_idx = 0

    # Generate rkme spec. Keyword arguments keep this call correct even if the
    # parameter order of generate_stat_spec_from_data changes.
    rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx)
    rkme_image_spec.generate_stat_spec_from_data(
        X,
        K=reduced_set_size,
        step_size=step_size,
        steps=steps,
        resize=resize,
        nonnegative_beta=nonnegative_beta,
        reduce=reduce,
        verbose=verbose,
    )
    return rkme_image_spec


def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification: def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification:
""" """
Interface for users to generate statistical specification. Interface for users to generate statistical specification.


Loading…
Cancel
Save