Browse Source

[MNT] Update image example

tags/v0.3.2
chenzx 3 years ago
parent
commit
1ad8bc5643
4 changed files with 10 additions and 10 deletions
  1. +1
    -0
      examples/example_image/example_files/example_init.py
  2. +7
    -8
      examples/example_image/main.py
  3. +1
    -1
      learnware/learnware/reuse.py
  4. +1
    -1
      learnware/market/easy.py

+ 1
- 0
examples/example_image/example_files/example_init.py View File

@@ -8,6 +8,7 @@ import torch

class Model(BaseModel):
def __init__(self):
super().__init__(input_shape=(3, 32, 32), output_shape=(10,))
dir_path = os.path.dirname(os.path.abspath(__file__))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = ConvModel(channel=3, n_random_features=10).to(self.device)


+ 7
- 8
examples/example_image/main.py View File

@@ -167,13 +167,12 @@ def test_search(gamma=0.1, load_market=True):
acc_list.append(acc)
logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))
# test reuse
"""
reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list)
reuse_predict = reuse_baseline.predict(user_data=user_data)
reuse_score = eval_prediction(reuse_predict, user_label)
job_selector_score_list.append(reuse_score)
print(f"mixture reuse loss: {reuse_score}\n")
"""

reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote")
ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
@@ -188,10 +187,10 @@ def test_search(gamma=0.1, load_market=True):
% (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
)
logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
# logger.info(
# "Average Job Selector Reuse Performance: %.3f +/- %.3f"
# % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
# )
logger.info(
"Average Job Selector Reuse Performance: %.3f +/- %.3f"
% (np.mean(job_selector_score_list), np.std(job_selector_score_list))
)
logger.info(
"Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
)
@@ -199,5 +198,5 @@ def test_search(gamma=0.1, load_market=True):

if __name__ == "__main__":
# prepare_data()
# prepare_model()
test_search(load_market=True)
prepare_model()
test_search(load_market=False)

+ 1
- 1
learnware/learnware/reuse.py View File

@@ -91,7 +91,7 @@ class JobSelectorReuser(BaseReuser):
task_spec = learnware_rkme_spec_list[i]
task_herding_num = max(5, int(self.herding_num * task_mixture_weight[i]))
task_val_num = task_herding_num // 5
if self.use_herding:
herding_X_i = task_spec.herding(task_herding_num).detach().cpu().numpy()
else:


+ 1
- 1
learnware/market/easy.py View File

@@ -476,7 +476,7 @@ class EasyMarket(BaseMarket):
"""
learnware_num = len(learnware_list)
if learnware_num == 0:
return [], []
return None, [], []
if learnware_num < max_search_num:
logger.warning("Available Learnware num less than search_num!")
max_search_num = learnware_num


Loading…
Cancel
Save