Browse Source

[ENH, Fix] Add generate_rkme_image_spec, fix device bugs

tags/v0.3.2
shihy 2 years ago
parent
commit
4cfd2cb7cc
3 changed files with 77 additions and 9 deletions
  1. +1
    -1
      learnware/specification/__init__.py
  2. +10
    -8
      learnware/specification/image/rkme.py
  3. +66
    -0
      learnware/specification/utils.py

+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,4 +1,4 @@
from .utils import generate_stat_spec
from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec
from .base import Specification, BaseStatSpecification
from .image import RKMEImageStatSpecification
from .table import RKMEStatSpecification

+ 10
- 8
learnware/specification/image/rkme.py View File

@@ -14,6 +14,7 @@ import torch_optimizer
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import Resize
from tqdm import tqdm

from . import cnn_gp
from ..base import BaseStatSpecification
@@ -33,7 +34,6 @@ class RKMEImageStatSpecification(BaseStatSpecification):
A flag indicating whether use CUDA during RKME computation. -1 indicates CUDA not used.
"""
self.RKME_IMAGE_VERSION = 1 # Please maintain backward compatibility.
# torch.cuda.empty_cache()

self.z = None
self.beta = None
@@ -67,9 +67,10 @@ class RKMEImageStatSpecification(BaseStatSpecification):
K: int = 50,
step_size: float = 0.01,
steps: int = 100,
resize: bool = False,
resize: bool = True,
nonnegative_beta: bool = True,
reduce: bool = True,
verbose: bool = True,
**kwargs,
):
"""Construct reduced set from raw dataset using iterative optimization.
@@ -77,7 +78,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
Parameters
----------
X : np.ndarray or torch.tensor
Raw data in np.ndarray format.
Raw data in [N, C, H, W] format.
K : int
Size of the constructed reduced set.
step_size : float
@@ -90,7 +91,8 @@ class RKMEImageStatSpecification(BaseStatSpecification):
True if weights for the reduced set are intended to be kept non-negative, by default True.
reduce : bool, optional
Whether shrink original data to a smaller set, by default True

verbose : bool, optional
Whether to print training progress, by default True
Returns
-------

@@ -107,7 +109,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):

if not torch.is_tensor(X):
X = torch.from_numpy(X)
X = X.to(self.device)
X = X.to(self.device).float()

X[torch.isinf(X) | torch.isneginf(X) | torch.isposinf(X) | torch.isneginf(X)] = torch.nan
if torch.any(torch.isnan(X)):
@@ -120,7 +122,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
X[i] = torch.where(is_nan, img_mean, img)

if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH))(X)
X = Resize((RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None)(X)

num_points = X.shape[0]
X_shape = X.shape
@@ -147,7 +149,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):

optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)

for i in range(steps):
for _ in tqdm(range(steps)) if verbose else range(steps):
# Regenerate Random Models
random_models = list(self._generate_models(n_models=self.n_models, channel=X.shape[1]))

@@ -384,7 +386,7 @@ def _get_zca_matrix(X, reg_coef=0.1):
X_flat = X.reshape(X.shape[0], -1)
cov = (X_flat.T @ X_flat) / X_flat.shape[0]
reg_amount = reg_coef * torch.trace(cov) / cov.shape[0]
u, s, _ = torch.svd(cov.cuda() + reg_amount * torch.eye(cov.shape[0]).cuda())
u, s, _ = torch.svd(cov + reg_amount * torch.eye(cov.shape[0]).to(X.device))
inv_sqrt_zca_eigs = s ** (-0.5)
whitening_transform = torch.einsum("ij,j,kj->ik", u, inv_sqrt_zca_eigs, u)



+ 66
- 0
learnware/specification/utils.py View File

@@ -3,6 +3,7 @@ import numpy as np
import pandas as pd
from typing import Union

from .image import RKMEImageStatSpecification
from .base import BaseStatSpecification
from .table import RKMEStatSpecification
from ..config import C
@@ -99,6 +100,71 @@ def generate_rkme_spec(
return rkme_spec


def generate_rkme_image_spec(
    X: Union[np.ndarray, torch.Tensor],
    reduced_set_size: int = 50,
    step_size: float = 0.01,
    steps: int = 100,
    resize: bool = True,
    nonnegative_beta: bool = True,
    reduce: bool = True,
    verbose: bool = True,
    cuda_idx: Union[int, None] = None,
) -> RKMEImageStatSpecification:
    """
    Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image.
    Return a RKMEImageStatSpecification object, use .save() method to save as json file.

    Parameters
    ----------
    X : np.ndarray, or torch.Tensor
        Raw data in np.ndarray, or torch.Tensor format.
        The shape of X: [N, C, H, W]
            N: Number of images.
            C: Number of channels.
            H: Height of images.
            W: Width of images.
        For example, if X has shape (100, 3, 32, 32), it means there are 100 samples,
        and each sample is a 3-channel (RGB) image of size 32x32.
    reduced_set_size : int, optional
        Size of the constructed reduced set, by default 50.
    step_size : float, optional
        Step size for gradient descent in the iterative optimization, by default 0.01.
    steps : int, optional
        Total rounds in the iterative optimization, by default 100.
    resize : bool, optional
        Whether to scale the image to the requested size, by default True.
    nonnegative_beta : bool, optional
        True if weights for the reduced set are intended to be kept non-negative, by default True.
    reduce : bool, optional
        Whether shrink original data to a smaller set, by default True.
    verbose : bool, optional
        Whether to print training progress, by default True.
    cuda_idx : int or None, optional
        A flag indicating whether use CUDA during RKME computation.
        -1 indicates CUDA not used.
        None indicates that CUDA is automatically selected (device 0 when CUDA is available).
        An index outside ``[0, torch.cuda.device_count())`` silently falls back to device 0,
        and any value is replaced by -1 when CUDA is unavailable.

    Returns
    -------
    RKMEImageStatSpecification
        A RKMEImageStatSpecification object
    """

    # Normalise cuda_idx: -1 means "no CUDA"; otherwise it must name an existing device.
    if not torch.cuda.is_available() or cuda_idx == -1:
        cuda_idx = -1
    elif cuda_idx is None or not (0 <= cuda_idx < torch.cuda.device_count()):
        # Unspecified or out-of-range index: fall back to the first CUDA device.
        cuda_idx = 0

    # Generate rkme spec. Options are forwarded by keyword so that a reordering of
    # the callee's parameters cannot silently mis-bind them.
    rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx)
    rkme_image_spec.generate_stat_spec_from_data(
        X,
        K=reduced_set_size,
        step_size=step_size,
        steps=steps,
        resize=resize,
        nonnegative_beta=nonnegative_beta,
        reduce=reduce,
        verbose=verbose,
    )
    return rkme_image_spec


def generate_stat_spec(X: np.ndarray) -> BaseStatSpecification:
"""
Interface for users to generate statistical specification.


Loading…
Cancel
Save