Browse Source

[MNT] Save test data copy to disk.

tags/v0.3.2
shihy 2 years ago
parent
commit
bb3686b636
2 changed files with 33 additions and 9 deletions
  1. +15
    -8
      examples/dataset_image_workflow/benchmarks/utils.py
  2. +18
    -1
      examples/dataset_image_workflow/main.py

+ 15
- 8
examples/dataset_image_workflow/benchmarks/utils.py View File

@@ -1,5 +1,6 @@
import json
import os
import pickle
import zipfile
from collections import defaultdict
from shutil import rmtree
@@ -86,7 +87,7 @@ def build_learnware(name: str, market: LearnwareMarket, order, model_name="conv"
# build specification
loader = DataLoader(spec_set, batch_size=3000, shuffle=True)
sampled_X, _ = next(iter(loader))
spec = generate_rkme_image_spec(sampled_X, whitening=False, experimental=True)
spec = generate_rkme_image_spec(sampled_X, whitening=False, experimental=False)

# add to market
model_dir = os.path.abspath(os.path.join(__file__, "..", "models"))
@@ -177,24 +178,30 @@ def train_model(model: nn.Module, train_set: Dataset, valid_set: Dataset,

def build_specification(name: str, cache_id, order, sampled_size=3000):
cache_dir = os.path.abspath(os.path.join(
os.path.dirname( __file__ ), '..', 'cache', 'spec'))
os.path.dirname(__file__), '..', 'cache'))
os.makedirs(cache_dir, exist_ok=True)
cache_path = os.path.join(cache_dir, "spec_{}.json".format(cache_id))
spec_cache_path = os.path.join(cache_dir, 'spec', "spec_{}.json".format(cache_id))

if os.path.exists(cache_path):
if os.path.exists(spec_cache_path):
spec = RKMEImageSpecification()
spec.load(cache_path)
spec.load(spec_cache_path)

test_dataset, spec_dataset, _, _ = user_data(indices=torch.asarray(spec.msg))
else:
test_dataset, spec_dataset, indices, _ = user_data(order=order)
loader = DataLoader(spec_dataset, batch_size=sampled_size, shuffle=True)
sampled_X, _ = next(iter(loader))
spec = generate_rkme_image_spec(sampled_X, whitening=False, experimental=True)
spec = generate_rkme_image_spec(sampled_X, whitening=False, experimental=False)

spec.msg = indices.tolist()
spec.save(cache_path)

spec.save(spec_cache_path)

# Save test_dataset to disk, spec_dataset is same as test_dataset for now
X, y = next(iter(DataLoader(test_dataset, batch_size=len(test_dataset))))
with open(os.path.join(cache_dir, 'test_data', "user{}_X.pkl".format(cache_id)), "wb") as f:
pickle.dump(X.detach().cpu().numpy(), f)
with open(os.path.join(cache_dir, 'test_data', "user{}_y.pkl".format(cache_id)), "wb") as f:
pickle.dump(y.detach().cpu().numpy(), f)
return spec, test_dataset




+ 18
- 1
examples/dataset_image_workflow/main.py View File

@@ -1,4 +1,5 @@
import os
import random
from datetime import datetime

import fire
@@ -29,6 +30,8 @@ class ImageDatasetWorkflow:
learnware.init()
assert not rebuild

np.random.seed(0)
random.seed(0)
market_id = "dataset_image_workflow" if market_id is None else market_id
orders = np.stack([np.random.permutation(10) for _ in range(market_size)])

@@ -48,12 +51,20 @@ class ImageDatasetWorkflow:
def evaluate(self, user_size=100, market_id=None, faster=True):
learnware.init()

np.random.seed(1)
random.seed(1)
market_id = "dataset_image_workflow" if market_id is None else market_id
orders = np.stack([np.random.permutation(10) for _ in range(user_size)])

print("Using market_id", market_id)
market = instantiate_learnware_market(name="easy", market_id=market_id, rebuild=False)

# Create Folder to save data
train_data_cache_folder = os.path.abspath(os.path.join(__file__, '..', "cache", "train_data"))
test_data_cache_folder = os.path.abspath(os.path.join(__file__, '..', "cache", "test_data"))
os.makedirs(train_data_cache_folder, exist_ok=True)
os.makedirs(test_data_cache_folder, exist_ok=True)

device = choose_device(0)
if faster:
faster_train(device)
@@ -97,7 +108,13 @@ class ImageDatasetWorkflow:
job_loss, job_acc = evaluate(reuse_job_selector, dataset)
unlabeled.record("Job Selector", job_acc, job_loss)

train_set, valid_set, spec_set, order = uploader_data(order=order)
train_set, _, _, _ = uploader_data(order=order)
X, y = next(iter(DataLoader(train_set, batch_size=len(train_set))))
with open(os.path.join(train_data_cache_folder, "user{}_X.pkl".format(i)), "wb") as f:
pickle.dump(X.detach().cpu().numpy(), f)
with open(os.path.join(train_data_cache_folder, "user{}_y.pkl".format(i)), "wb") as f:
pickle.dump(y.detach().cpu().numpy(), f)

for labeled_size in tqdm.tqdm([100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]):
loader = DataLoader(train_set, batch_size=labeled_size, shuffle=True)
X, y = next(iter(loader))


Loading…
Cancel
Save