|
|
|
@@ -10,8 +10,13 @@ from learnware.market import database_ops |
|
|
|
from learnware.learnware import Learnware |
|
|
|
import learnware.specification as specification |
|
|
|
|
|
|
|
from shutil import copyfile, rmtree |
|
|
|
import zipfile |
|
|
|
|
|
|
|
origin_data_root = "./data/origin_data" |
|
|
|
processed_data_root = "./data/processed_data" |
|
|
|
tmp_dir = "./data/tmp" |
|
|
|
learnware_pool_dir = "./data/learnware_pool" |
|
|
|
dataset = "cifar10" |
|
|
|
n_uploaders = 50 |
|
|
|
n_users = 10 |
|
|
|
@@ -98,17 +103,24 @@ def prepare_model(): |
|
|
|
print("Model saved to '%s'" % (model_save_path)) |
|
|
|
|
|
|
|
|
|
|
|
def prepare_learnware(): |
|
|
|
pass |
|
|
|
def prepare_learnware(data_path, model_path, init_file_path, yaml_path): |
|
|
|
X = np.load(data_path) |
|
|
|
user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) |
|
|
|
print(user_spec.shape) |
|
|
|
|
|
|
|
|
|
|
|
def prepare_market(): |
|
|
|
image_market = EasyMarket(rebuild=True) |
|
|
|
os.makedirs(learnware_pool_dir) |
|
|
|
for i in range(n_uploaders): |
|
|
|
data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i)) |
|
|
|
model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) |
|
|
|
init_file_path = "./example_init.py" |
|
|
|
yaml_file_path = "./example_yaml.yaml" |
|
|
|
prepare_learnware(data_path, model_path, init_file_path, yaml_file_path) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
prepare_data() |
|
|
|
prepare_model() |
|
|
|
# prepare_data() |
|
|
|
# prepare_model() |
|
|
|
prepare_market() |