Browse Source

[MNT] Update image example

tags/v0.3.2
chenzx 3 years ago
parent
commit
fb133fc24d
4 changed files with 29 additions and 7 deletions
  1. +8
    -0
      examples/example_image/example_yaml.yaml
  2. +16
    -4
      examples/example_image/main.py
  3. +2
    -2
      examples/example_pfs/pfs/pfs_cross_transfer.py
  4. +3
    -1
      learnware/learnware/reuse.py

+ 8
- 0
examples/example_image/example_yaml.yaml View File

@@ -0,0 +1,8 @@
model:
class_name: Model
kwargs: {}
stat_specifications:
- module_path: learnware.specification
class_name: RKMEStatSpecification
file_name: rkme.json
kwargs: {}

+ 16
- 4
examples/example_image/main.py View File

@@ -10,8 +10,13 @@ from learnware.market import database_ops
from learnware.learnware import Learnware
import learnware.specification as specification

from shutil import copyfile, rmtree
import zipfile

origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "cifar10"
n_uploaders = 50
n_users = 10
@@ -98,17 +103,24 @@ def prepare_model():
print("Model saved to '%s'" % (model_save_path))


def prepare_learnware():
pass
def prepare_learnware(data_path, model_path, init_file_path, yaml_path):
X = np.load(data_path)
user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0)
print(user_spec.shape)


def prepare_market():
image_market = EasyMarket(rebuild=True)
os.makedirs(learnware_pool_dir)
for i in range(n_uploaders):
data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
init_file_path = "./example_init.py"
yaml_file_path = "./example_yaml.yaml"
prepare_learnware(data_path, model_path, init_file_path, yaml_file_path)


if __name__ == "__main__":
prepare_data()
prepare_model()
# prepare_data()
# prepare_model()
prepare_market()

+ 2
- 2
examples/example_pfs/pfs/pfs_cross_transfer.py View File

@@ -67,7 +67,7 @@ def get_split_errs(algo):
for tmp in range(len(proportion_list)):
model = lgb.LGBMModel(
boosting_type="gbdt",
num_leaves=2 ** 7 - 1,
num_leaves=2**7 - 1,
learning_rate=0.01,
objective="rmse",
metric="rmse",
@@ -119,7 +119,7 @@ def get_errors(algo):
if algo == "lgb":
model = lgb.LGBMModel(
boosting_type="gbdt",
num_leaves=2 ** 7 - 1,
num_leaves=2**7 - 1,
learning_rate=0.01,
objective="rmse",
metric="rmse",


+ 3
- 1
learnware/learnware/reuse.py View File

@@ -208,6 +208,8 @@ class ReuseBaseline:
booster="gbtree",
seed=0,
)
model.fit(org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=-1, early_stopping_rounds=300)
model.fit(
org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=-1, early_stopping_rounds=300
)

return model

Loading…
Cancel
Save