From fb133fc24dffe8c4e152db676fc9df037b730112 Mon Sep 17 00:00:00 2001 From: chenzx Date: Tue, 18 Apr 2023 11:29:05 +0800 Subject: [PATCH] [MNT] Update image example --- examples/example_image/example_yaml.yaml | 8 ++++++++ examples/example_image/main.py | 20 +++++++++++++++---- .../example_pfs/pfs/pfs_cross_transfer.py | 4 ++-- learnware/learnware/reuse.py | 4 +++- 4 files changed, 29 insertions(+), 7 deletions(-) create mode 100644 examples/example_image/example_yaml.yaml diff --git a/examples/example_image/example_yaml.yaml b/examples/example_image/example_yaml.yaml new file mode 100644 index 0000000..6ca01c9 --- /dev/null +++ b/examples/example_image/example_yaml.yaml @@ -0,0 +1,8 @@ +model: + class_name: Model + kwargs: {} +stat_specifications: + - module_path: learnware.specification + class_name: RKMEStatSpecification + file_name: rkme.json + kwargs: {} \ No newline at end of file diff --git a/examples/example_image/main.py b/examples/example_image/main.py index 9f5168a..060cda9 100644 --- a/examples/example_image/main.py +++ b/examples/example_image/main.py @@ -10,8 +10,13 @@ from learnware.market import database_ops from learnware.learnware import Learnware import learnware.specification as specification +from shutil import copyfile, rmtree +import zipfile + origin_data_root = "./data/origin_data" processed_data_root = "./data/processed_data" +tmp_dir = "./data/tmp" +learnware_pool_dir = "./data/learnware_pool" dataset = "cifar10" n_uploaders = 50 n_users = 10 @@ -98,17 +103,24 @@ def prepare_model(): print("Model saved to '%s'" % (model_save_path)) -def prepare_learnware(): - pass +def prepare_learnware(data_path, model_path, init_file_path, yaml_path): + X = np.load(data_path) + user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0) + print(user_spec.shape) def prepare_market(): + image_market = EasyMarket(rebuild=True) + os.makedirs(learnware_pool_dir) for i in range(n_uploaders): data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i)) model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i)) + init_file_path = "./example_init.py" + yaml_file_path = "./example_yaml.yaml" + prepare_learnware(data_path, model_path, init_file_path, yaml_file_path) if __name__ == "__main__": - prepare_data() - prepare_model() + # prepare_data() + # prepare_model() prepare_market() diff --git a/examples/example_pfs/pfs/pfs_cross_transfer.py b/examples/example_pfs/pfs/pfs_cross_transfer.py index a106fb7..93a3fa3 100644 --- a/examples/example_pfs/pfs/pfs_cross_transfer.py +++ b/examples/example_pfs/pfs/pfs_cross_transfer.py @@ -67,7 +67,7 @@ def get_split_errs(algo): for tmp in range(len(proportion_list)): model = lgb.LGBMModel( boosting_type="gbdt", - num_leaves=2 ** 7 - 1, + num_leaves=2**7 - 1, learning_rate=0.01, objective="rmse", metric="rmse", @@ -119,7 +119,7 @@ def get_errors(algo): if algo == "lgb": model = lgb.LGBMModel( boosting_type="gbdt", - num_leaves=2 ** 7 - 1, + num_leaves=2**7 - 1, learning_rate=0.01, objective="rmse", metric="rmse", diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index 9e7b4d3..eb20d7f 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -208,6 +208,8 @@ class ReuseBaseline: booster="gbtree", seed=0, ) - model.fit(org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=-1, early_stopping_rounds=300) + model.fit( + org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=-1, early_stopping_rounds=300 + ) return model