From 1ad8bc5643b29927ac3a690aef52286232a7d2fb Mon Sep 17 00:00:00 2001 From: chenzx Date: Fri, 21 Apr 2023 17:02:39 +0800 Subject: [PATCH] [MNT] Update image example --- .../example_image/example_files/example_init.py | 1 + examples/example_image/main.py | 15 +++++++-------- learnware/learnware/reuse.py | 2 +- learnware/market/easy.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/example_image/example_files/example_init.py b/examples/example_image/example_files/example_init.py index e75a116..b318ee8 100644 --- a/examples/example_image/example_files/example_init.py +++ b/examples/example_image/example_files/example_init.py @@ -8,6 +8,7 @@ import torch class Model(BaseModel): def __init__(self): + super().__init__(input_shape=(3, 32, 32), output_shape=(10,)) dir_path = os.path.dirname(os.path.abspath(__file__)) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = ConvModel(channel=3, n_random_features=10).to(self.device) diff --git a/examples/example_image/main.py b/examples/example_image/main.py index e96094a..ed1fb69 100644 --- a/examples/example_image/main.py +++ b/examples/example_image/main.py @@ -167,13 +167,12 @@ def test_search(gamma=0.1, load_market=True): acc_list.append(acc) logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) # test reuse - """ + reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list) reuse_predict = reuse_baseline.predict(user_data=user_data) reuse_score = eval_prediction(reuse_predict, user_label) job_selector_score_list.append(reuse_score) print(f"mixture reuse loss: {reuse_score}\n") - """ reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) @@ -188,10 +187,10 @@ def test_search(gamma=0.1, load_market=True): % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list)) ) logger.info("Average performance improvement: %.3f" % (np.mean(improve_list))) - # logger.info( - # "Average Job Selector Reuse Performance: %.3f +/- %.3f" - # % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) - # ) + logger.info( + "Average Job Selector Reuse Performance: %.3f +/- %.3f" + % (np.mean(job_selector_score_list), np.std(job_selector_score_list)) + ) logger.info( "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list)) ) @@ -199,5 +198,5 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": # prepare_data() - # prepare_model() - test_search(load_market=True) + prepare_model() + test_search(load_market=False) diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index 8d1d0fe..fbdcc44 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -91,7 +91,7 @@ class JobSelectorReuser(BaseReuser): task_spec = learnware_rkme_spec_list[i] task_herding_num = max(5, int(self.herding_num * task_mixture_weight[i])) task_val_num = task_herding_num // 5 - + if self.use_herding: herding_X_i = task_spec.herding(task_herding_num).detach().cpu().numpy() else: diff --git a/learnware/market/easy.py b/learnware/market/easy.py index b8d9e1d..4fbe859 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -476,7 +476,7 @@ class EasyMarket(BaseMarket): """ learnware_num = len(learnware_list) if learnware_num == 0: - return [], [] + return None, [], [] if learnware_num < max_search_num: logger.warning("Available Learnware num less than search_num!") max_search_num = learnware_num