Browse Source

[FIX] Fix bugs in M5

tags/v0.3.2
Gene 3 years ago
parent
commit
4e249efab0
3 changed files with 6 additions and 4 deletions
  1. +3
    -3
      examples/example_image/main.py
  2. +2
    -1
      examples/example_m5/example_init.py
  3. +1
    -0
      examples/example_pfs/example_init.py

+ 3
- 3
examples/example_image/main.py View File

@@ -4,7 +4,7 @@ from get_data import *
import os
import random
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
from learnware.learnware import Learnware, JobSelectorReuser, EnsembleReuser
from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser
import time

from learnware.market import EasyMarket, BaseUserInfo
@@ -173,7 +173,7 @@ def test_search(gamma=0.1, load_market=True):
job_selector_score_list.append(reuse_score)
print(f"mixture reuse loss: {reuse_score}\n")

reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list, mode="vote")
reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote")
ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
ensemble_score = eval_prediction(ensemble_predict_y, user_label)
ensemble_score_list.append(ensemble_score)
@@ -198,4 +198,4 @@ def test_search(gamma=0.1, load_market=True):
if __name__ == "__main__":
# prepare_data()
# prepare_model()
test_search(load_market=True)
test_search(load_market=False)

+ 2
- 1
examples/example_m5/example_init.py View File

@@ -1,13 +1,14 @@
import os
import joblib
import numpy as np
import lightgbm as lgb
from learnware.model import BaseModel


class Model(BaseModel):
def __init__(self):
dir_path = os.path.dirname(os.path.abspath(__file__))
self.model = joblib.load(os.path.join(dir_path, "model.out"))
self.model = lgb.Booster(model_file=os.path.join(dir_path, "model.out"))

def fit(self, X: np.ndarray, y: np.ndarray):
pass


+ 1
- 0
examples/example_pfs/example_init.py View File

@@ -6,6 +6,7 @@ from learnware.model import BaseModel

class Model(BaseModel):
def __init__(self):
super(Model, self).__init__(input_shape=(31,), output_shape=())
dir_path = os.path.dirname(os.path.abspath(__file__))
self.model = joblib.load(os.path.join(dir_path, "model.out"))



Loading…
Cancel
Save