|
|
|
@@ -22,7 +22,7 @@ processed_data_root = "./data/processed_data" |
|
|
|
tmp_dir = "./data/tmp" |
|
|
|
learnware_pool_dir = "./data/learnware_pool" |
|
|
|
dataset = "cifar10" |
|
|
|
n_uploaders = 20 |
|
|
|
n_uploaders = 50 |
|
|
|
n_users = 20 |
|
|
|
n_classes = 10 |
|
|
|
data_root = os.path.join(origin_data_root, dataset) |
|
|
|
@@ -110,7 +110,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo |
|
|
|
|
|
|
|
|
|
|
|
def prepare_market(): |
|
|
|
image_market = EasyMarket(market_id='cifar10',rebuild=True) |
|
|
|
image_market = EasyMarket(market_id="cifar10", rebuild=True) |
|
|
|
try: |
|
|
|
rmtree(learnware_pool_dir) |
|
|
|
except: |
|
|
|
@@ -136,10 +136,10 @@ def prepare_market(): |
|
|
|
|
|
|
|
def test_search(gamma=0.1, load_market=True): |
|
|
|
if load_market: |
|
|
|
image_market = EasyMarket(market_id="image") |
|
|
|
image_market = EasyMarket(market_id="cifar10") |
|
|
|
else: |
|
|
|
prepare_market() |
|
|
|
image_market = EasyMarket(market_id="image") |
|
|
|
image_market = EasyMarket(market_id="cifar10") |
|
|
|
logger.info("Number of items in the market: %d" % len(image_market)) |
|
|
|
|
|
|
|
select_list = [] |
|
|
|
@@ -203,6 +203,6 @@ def test_search(gamma=0.1, load_market=True): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
# prepare_data() |
|
|
|
prepare_data() |
|
|
|
prepare_model() |
|
|
|
test_search(load_market=False) |