Browse Source

[FIX] Fix typo and remove mock.py

tags/v0.3.2
shihy 2 years ago
parent
commit
4ea360bc0e
3 changed files with 3 additions and 86 deletions
  1. +1
    -1
      examples/dataset_image_workflow/benchmarks/dataset/data.py
  2. +2
    -2
      examples/dataset_image_workflow/main.py
  3. +0
    -83
      examples/dataset_image_workflow/mock.py

+ 1
- 1
examples/dataset_image_workflow/benchmarks/dataset/data.py View File

@@ -8,7 +8,7 @@ from torchvision.transforms import transforms
from torch.utils.data import TensorDataset

from .utils import cached
from examples.dataset_cifar_workflow.benchmarks.dataset.utils import split_dataset, build_transforms
from examples.dataset_image_workflow.benchmarks.dataset.utils import split_dataset, build_transforms

cache_root = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', '..', 'cache'))



+ 2
- 2
examples/dataset_image_workflow/main.py View File

@@ -29,7 +29,7 @@ class ImageDatasetWorkflow:
learnware.init()
assert not rebuild

market_id = "dataset_cifar_workflow" if market_id is None else market_id
market_id = "dataset_image_workflow" if market_id is None else market_id
orders = np.stack([np.random.permutation(10) for _ in range(market_size)])

print("Using market_id", market_id)
@@ -48,7 +48,7 @@ class ImageDatasetWorkflow:
def evaluate(self, user_size=100, market_id=None, faster=True):
learnware.init()

market_id = "dataset_cifar_workflow" if market_id is None else market_id
market_id = "dataset_image_workflow" if market_id is None else market_id
orders = np.stack([np.random.permutation(10) for _ in range(user_size)])

print("Using market_id", market_id)


+ 0
- 83
examples/dataset_image_workflow/mock.py View File

@@ -1,83 +0,0 @@
import os.path
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets
from torchvision.transforms import transforms

import learnware
from examples.dataset_cifar_workflow.benchmarks.dataset import user_data, split_dataset
from examples.dataset_image_workflow.get_data import get_zca_matrix, transform_data
from learnware import setup_seed
from learnware.specification import generate_rkme_image_spec, RKMEImageSpecification


def f(d):
return np.exp(-d / 0.00005)

def get_spec(path, order=None):
if path is not None and os.path.exists(path):
spec = RKMEImageSpecification()
spec.load(path)
return spec, spec.msg

test_user, spec_user, _, order = user_data(order=order)
loader = DataLoader(spec_user, batch_size=3000, shuffle=True)
sampled_X, _ = next(iter(loader))
spec = generate_rkme_image_spec(sampled_X, whitening=False)
spec.msg = order

if path is not None:
spec.save(path)

return spec, order

DATA_ROOT = "cache"
def get_cifar10(output_channels=3, image_size=32, z_score=True, order=None):
ds_train = datasets.CIFAR10(DATA_ROOT, train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor(), transforms.Resize([image_size, image_size])]))
X_train = ds_train.data
y_train = ds_train.targets
ds_test = datasets.CIFAR10(DATA_ROOT, train=False, download=True, transform=transforms.Compose(
[transforms.ToTensor(), transforms.Resize([image_size, image_size])]))

X_test = ds_test.data
y_test = ds_test.targets

X_train = torch.Tensor(np.moveaxis(X_train, 3, 1))
y_train = torch.Tensor(y_train).long()
X_test = torch.Tensor(np.moveaxis(X_test, 3, 1))
y_test = torch.Tensor(y_test).long()

if output_channels == 1:
X_train = torch.mean(X_train, 1, keepdim=True)
X_test = torch.mean(X_test, 1, keepdim=True)

if z_score:
X_test = (X_test - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (torch.std(X_train, [0, 2, 3], keepdim=True))
X_train = (X_train - torch.mean(X_train, [0, 2, 3], keepdim=True)) / (
torch.std(X_train, [0, 2, 3], keepdim=True))

whitening_mat = get_zca_matrix(X_train, reg_coef=0.1)
train_X = transform_data(X_train, whitening_mat)
test_X = transform_data(X_train, whitening_mat)

selected_data_indexes, order = split_dataset(y_test, 10000, split="user", order=order)

return TensorDataset(test_X[selected_data_indexes], y_test[selected_data_indexes]), order




if __name__ == "__main__":
# old1, order1 = get_spec("spec_1_V100.json", order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# old2, order2 = get_spec("spec_2_A100.json", order=[2, 3, 4, 5, 6, 7, 0, 1, 8, 9])

old3, order3 = get_spec("spec_3_A100.json", order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
old4, order4 = get_spec("spec_6_A100.json", order=[2, 3, 4, 5, 6, 7, 0, 1, 8, 9])

print(order3, order4)
print(f(old3.dist(old4)))


Loading…
Cancel
Save