diff --git a/examples/example_pfs/main.py b/examples/example_pfs/main.py index ae20c98..5d3ae6a 100644 --- a/examples/example_pfs/main.py +++ b/examples/example_pfs/main.py @@ -8,7 +8,7 @@ from shutil import copyfile, rmtree import learnware from learnware.market import EasyMarket, BaseUserInfo from learnware.market import database_ops -from learnware.learnware import Learnware, JobSelectorReuser, EnsembleReuser +from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser import learnware.specification as specification from pfs import Dataloader @@ -163,7 +163,7 @@ class PFSDatasetWorkflow: job_selector_score = pfs.score(test_y, job_selector_predict_y) print(f"mixture reuse loss (job selector): {job_selector_score}") - reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list) + reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list) ensemble_predict_y = reuse_ensemble.predict(user_data=test_x) ensemble_score = pfs.score(test_y, ensemble_predict_y) print(f"mixture reuse loss (ensemble): {ensemble_score}\n") diff --git a/examples/workflow_by_code/learnware_example/example_init.py b/examples/workflow_by_code/learnware_example/example_init.py index 82d0cb4..cc42047 100644 --- a/examples/workflow_by_code/learnware_example/example_init.py +++ b/examples/workflow_by_code/learnware_example/example_init.py @@ -6,6 +6,7 @@ from learnware.model import BaseModel class SVM(BaseModel): def __init__(self): + super(SVM, self).__init__(input_shape=(20,), output_shape=()) dir_path = os.path.dirname(os.path.abspath(__file__)) self.model = joblib.load(os.path.join(dir_path, "svm.pkl")) diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 01cc4ac..01f09f9 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -161,11 +161,17 @@ class LearnwareMarketWorkflow: user_info = BaseUserInfo( id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec} ) - sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + ( + sorted_score_list, + single_learnware_list, + mixture_score, + mixture_learnware_list, + ) = easy_market.search_learnware(user_info) print(f"search result of user{idx}:") for score, learnware in zip(sorted_score_list, single_learnware_list): print(f"score: {score}, learnware_id: {learnware.id}") + print(f"mixture_score: {mixture_score}\n") mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) print(f"mixture_learnware: {mixture_id}\n") diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index d1cf547..8095254 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -2,7 +2,7 @@ import os import copy from .base import Learnware, BaseReuser -from .reuse import JobSelectorReuser, EnsembleReuser +from .reuse import JobSelectorReuser, AveragingReuser from .utils import get_stat_spec_from_config, get_model_from_config from ..specification import Specification diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index 8207b21..ccd1da8 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -226,7 +226,7 @@ class JobSelectorReuser(BaseReuser): return model -class EnsembleReuser(BaseReuser): +class AveragingReuser(BaseReuser): """Baseline Multiple Learnware Reuser uing Ensemble Method""" def __init__(self, learnware_list: List[Learnware], mode="mean"): @@ -237,7 +237,7 @@ class EnsembleReuser(BaseReuser): learnware_list : List[Learnware] The learnware list, which should have RKME Specification for each learnweare """ - super(EnsembleReuser, self).__init__(learnware_list) + super(AveragingReuser, self).__init__(learnware_list) self.mode = mode def predict(self, user_data: np.ndarray) -> np.ndarray: diff --git a/learnware/market/easy.py b/learnware/market/easy.py index ed86e87..b8d9e1d 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -78,12 +78,12 @@ class EasyMarket(BaseMarket): try: learnware.instantiate_model() except Exception as e: - logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {repr(e)}") + logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}") return cls.NONUSABLE_LEARNWARE try: learnware_model = learnware.get_model() - inputs = np.random.randn((10, *learnware_model.input_shape)) + inputs = np.random.randn(10, *learnware_model.input_shape) outputs = learnware.predict(inputs) if outputs.shape[1:] != learnware_model.output_shape: logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")