From 4e249efab0e2efd337dfaa400177ca88070d7ee5 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 21 Apr 2023 16:55:20 +0800 Subject: [PATCH] [FIX] Fix bugs in M5 --- examples/example_image/main.py | 6 +++--- examples/example_m5/example_init.py | 3 ++- examples/example_pfs/example_init.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/example_image/main.py b/examples/example_image/main.py index 70589dc..90a34e5 100644 --- a/examples/example_image/main.py +++ b/examples/example_image/main.py @@ -4,7 +4,7 @@ from get_data import * import os import random from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction -from learnware.learnware import Learnware, JobSelectorReuser, EnsembleReuser +from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser import time from learnware.market import EasyMarket, BaseUserInfo @@ -173,7 +173,7 @@ def test_search(gamma=0.1, load_market=True): job_selector_score_list.append(reuse_score) print(f"mixture reuse loss: {reuse_score}\n") - reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list, mode="vote") + reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote") ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) ensemble_score = eval_prediction(ensemble_predict_y, user_label) ensemble_score_list.append(ensemble_score) @@ -198,4 +198,4 @@ def test_search(gamma=0.1, load_market=True): if __name__ == "__main__": # prepare_data() # prepare_model() - test_search(load_market=True) + test_search(load_market=False) diff --git a/examples/example_m5/example_init.py b/examples/example_m5/example_init.py index d875d96..70f366d 100644 --- a/examples/example_m5/example_init.py +++ b/examples/example_m5/example_init.py @@ -1,13 +1,14 @@ import os import joblib import numpy as np +import lightgbm as lgb from learnware.model import BaseModel class Model(BaseModel): def __init__(self): dir_path = os.path.dirname(os.path.abspath(__file__)) - self.model = joblib.load(os.path.join(dir_path, "model.out")) + self.model = lgb.Booster(model_file=os.path.join(dir_path, "model.out")) def fit(self, X: np.ndarray, y: np.ndarray): pass diff --git a/examples/example_pfs/example_init.py b/examples/example_pfs/example_init.py index d875d96..88b788a 100644 --- a/examples/example_pfs/example_init.py +++ b/examples/example_pfs/example_init.py @@ -6,6 +6,7 @@ from learnware.model import BaseModel class Model(BaseModel): def __init__(self): + super(Model, self).__init__(input_shape=(31,), output_shape=()) dir_path = os.path.dirname(os.path.abspath(__file__)) self.model = joblib.load(os.path.join(dir_path, "model.out"))