Browse Source

Merge branch 'dev' of git.nju.edu.cn:learnware/learnware-market into dev

tags/v0.3.2
chenzx 2 years ago
parent
commit
cfb4f4c94d
6 changed files with 15 additions and 8 deletions
  1. +2
    -2
      examples/example_pfs/main.py
  2. +1
    -0
      examples/workflow_by_code/learnware_example/example_init.py
  3. +7
    -1
      examples/workflow_by_code/main.py
  4. +1
    -1
      learnware/learnware/__init__.py
  5. +2
    -2
      learnware/learnware/reuse.py
  6. +2
    -2
      learnware/market/easy.py

+ 2
- 2
examples/example_pfs/main.py View File

@@ -8,7 +8,7 @@ from shutil import copyfile, rmtree
import learnware
from learnware.market import EasyMarket, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware, JobSelectorReuser, EnsembleReuser
from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser
import learnware.specification as specification
from pfs import Dataloader

@@ -163,7 +163,7 @@ class PFSDatasetWorkflow:
job_selector_score = pfs.score(test_y, job_selector_predict_y)
print(f"mixture reuse loss (job selector): {job_selector_score}")

reuse_ensemble = EnsembleReuser(learnware_list=mixture_learnware_list)
reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list)
ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
ensemble_score = pfs.score(test_y, ensemble_predict_y)
print(f"mixture reuse loss (ensemble): {ensemble_score}\n")


+ 1
- 0
examples/workflow_by_code/learnware_example/example_init.py View File

@@ -6,6 +6,7 @@ from learnware.model import BaseModel

class SVM(BaseModel):
def __init__(self):
super(SVM, self).__init__(input_shape=(20,), output_shape=())
dir_path = os.path.dirname(os.path.abspath(__file__))
self.model = joblib.load(os.path.join(dir_path, "svm.pkl"))



+ 7
- 1
examples/workflow_by_code/main.py View File

@@ -161,11 +161,17 @@ class LearnwareMarketWorkflow:
user_info = BaseUserInfo(
id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
)
sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = easy_market.search_learnware(user_info)

print(f"search result of user{idx}:")
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
print(f"mixture_score: {mixture_score}\n")
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")



+ 1
- 1
learnware/learnware/__init__.py View File

@@ -2,7 +2,7 @@ import os
import copy

from .base import Learnware, BaseReuser
from .reuse import JobSelectorReuser, EnsembleReuser
from .reuse import JobSelectorReuser, AveragingReuser

from .utils import get_stat_spec_from_config, get_model_from_config
from ..specification import Specification


+ 2
- 2
learnware/learnware/reuse.py View File

@@ -226,7 +226,7 @@ class JobSelectorReuser(BaseReuser):
return model


class EnsembleReuser(BaseReuser):
class AveragingReuser(BaseReuser):
"""Baseline Multiple Learnware Reuser uing Ensemble Method"""

def __init__(self, learnware_list: List[Learnware], mode="mean"):
@@ -237,7 +237,7 @@ class EnsembleReuser(BaseReuser):
learnware_list : List[Learnware]
The learnware list, which should have RKME Specification for each learnweare
"""
super(EnsembleReuser, self).__init__(learnware_list)
super(AveragingReuser, self).__init__(learnware_list)
self.mode = mode

def predict(self, user_data: np.ndarray) -> np.ndarray:


+ 2
- 2
learnware/market/easy.py View File

@@ -78,12 +78,12 @@ class EasyMarket(BaseMarket):
try:
learnware.instantiate_model()
except Exception as e:
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {repr(e)}")
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}")
return cls.NONUSABLE_LEARNWARE

try:
learnware_model = learnware.get_model()
inputs = np.random.randn((10, *learnware_model.input_shape))
inputs = np.random.randn(10, *learnware_model.input_shape)
outputs = learnware.predict(inputs)
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] input and output dimention is error")


Loading…
Cancel
Save