|
|
|
@@ -4,7 +4,7 @@ from get_data import * |
|
|
|
import os |
|
|
|
import random |
|
|
|
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction |
|
|
|
from learnware.learnware import Learnware, JobSelectorReuser, EnsembleReuser |
|
|
|
from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser |
|
|
|
import time |
|
|
|
|
|
|
|
from learnware.market import EasyMarket, BaseUserInfo |
|
|
|
@@ -173,7 +173,7 @@ def test_search(gamma=0.1, load_market=True): |
|
|
|
job_selector_score_list.append(reuse_score) |
|
|
|
print(f"mixture reuse loss: {reuse_score}\n") |
|
|
|
|
|
|
|
reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list, mode="vote") |
|
|
|
reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") |
|
|
|
ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) |
|
|
|
ensemble_score = eval_prediction(ensemble_predict_y, user_label) |
|
|
|
ensemble_score_list.append(ensemble_score) |
|
|
|
@@ -198,4 +198,4 @@ def test_search(gamma=0.1, load_market=True): |
|
|
|
if __name__ == "__main__": |
|
|
|
# prepare_data() |
|
|
|
# prepare_model() |
|
|
|
test_search(load_market=True) |
|
|
|
test_search(load_market=False) |