From 0e705c87cc6e1457de01b05f8f26499b6fedd372 Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 21 Apr 2023 15:13:00 +0800 Subject: [PATCH 1/2] [ENH] Modify learnware check --- learnware/market/easy.py | 65 +++++++++++++++++++--------------------- learnware/model/base.py | 5 ++-- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/learnware/market/easy.py b/learnware/market/easy.py index f08b40a..bd68fe1 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -19,9 +19,9 @@ logger = get_module_logger("market", "INFO") class EasyMarket(BaseMarket): - INVALID_LEARNWARE = "INVALID" - NONUSABLE_LEARNWARE = "NONUSABLE" - USABLE_LEARWARE = "USABLE" + INVALID_LEARNWARE = -1 + NONUSABLE_LEARNWARE = 0 + USABLE_LEARWARE = 1 def __init__(self, market_id: str = "default", rebuild: bool = False): """Initialize Learnware Market. @@ -82,15 +82,15 @@ class EasyMarket(BaseMarket): return cls.NONUSABLE_LEARNWARE try: - spec_data = learnware.specification.stat_spec["RKMEStatSpecification"].get_z() - except Exception: - logger.warning(f"The learnware [{learnware.id}] statistic specification is not avaliable!") - return cls.INVALID_LEARNWARE + learnware_model = learnware.get_model() + inputs = np.random.randn((10, *learnware_model.input_shape)) + outputs = learnware.predict(inputs) + if outputs.shape[1:] != learnware_model.output_shape: + logger.warning(f"The learnware [{learnware.id}] input and output dimention is error") + return cls.NONUSABLE_LEARNWARE - try: - pred_spec = learnware.predict(spec_data) - except Exception: - logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable") + except Exception as e: + logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}") return cls.NONUSABLE_LEARNWARE return cls.USABLE_LEARWARE @@ -112,9 +112,9 @@ class EasyMarket(BaseMarket): Returns ------- - Tuple[str, bool] + Tuple[str, int] - str indicating model_id - - bool indicating whether the learnware is added successfully. + - int indicating what the flag of learnware is added. """ if not os.path.exists(zip_path): @@ -160,38 +160,33 @@ class EasyMarket(BaseMarket): with zipfile.ZipFile(target_zip_dir, "r") as z_file: z_file.extractall(target_folder_dir) logger.info("Learnware move to %s, and unzip to %s" % (target_zip_dir, target_folder_dir)) + try: new_learnware = get_learnware_from_dirpath( id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir ) except: - new_learnware = None - - if new_learnware is None: try: os.remove(target_zip_dir) rmtree(target_folder_dir) except: pass - return None, False - else: - check_flag = self.check_learnware(new_learnware) - if not check_flag == self.INVALID_LEARNWARE: - self.learnware_list[id] = new_learnware - self.learnware_zip_list[id] = target_zip_dir - self.learnware_folder_list[id] = target_folder_dir - self.count += 1 - add_learnware_to_db( - market_id=self.market_id, - id=id, - semantic_spec=semantic_spec, - zip_path=target_zip_dir, - folder_path=target_folder_dir, - use_flag=check_flag, - ) - return id, True - else: - return None, False + return None, self.INVALID_LEARNWARE + + check_flag = self.check_learnware(new_learnware) + self.learnware_list[id] = new_learnware + self.learnware_zip_list[id] = target_zip_dir + self.learnware_folder_list[id] = target_folder_dir + self.count += 1 + add_learnware_to_db( + market_id=self.market_id, + id=id, + semantic_spec=semantic_spec, + zip_path=target_zip_dir, + folder_path=target_folder_dir, + use_flag=check_flag, + ) + return id, check_flag def _convert_dist_to_score( self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 diff --git a/learnware/model/base.py b/learnware/model/base.py index fe6b56e..db895e9 100644 --- a/learnware/model/base.py +++ b/learnware/model/base.py @@ -3,8 +3,9 @@ from abc import abstractmethod class BaseModel: - def __init__(self): - pass + def __init__(self, input_shape: tuple, output_shape: tuple): + self.input_shape = input_shape + self.output_shape = output_shape def fit(self, X: np.ndarray, y: np.ndarray): pass From 649091e3da478b7d05ca614d88170c65e430da54 Mon Sep 17 00:00:00 2001 From: xiey Date: Fri, 21 Apr 2023 15:16:06 +0800 Subject: [PATCH 2/2] [FIX] fix bugs in M5 --- examples/example_m5/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/example_m5/main.py b/examples/example_m5/main.py index 761582c..6b7544e 100644 --- a/examples/example_m5/main.py +++ b/examples/example_m5/main.py @@ -45,10 +45,10 @@ class M5DatasetWorkflow: def _init_learnware_market(self): """initialize learnware market""" - database_ops.clear_learnware_table() + # database_ops.clear_learnware_table() learnware.init() - easy_market = EasyMarket() + easy_market = EasyMarket(rebuild=True) print("Total Item:", len(easy_market)) zip_path_list = [] @@ -130,7 +130,12 @@ class M5DatasetWorkflow: user_info = BaseUserInfo( id=f"user_{idx}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec} ) - sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info) + ( + sorted_score_list, + single_learnware_list, + mixture_score, + mixture_learnware_list, + ) = easy_market.search_learnware(user_info) print(f"search result of user{idx}:") print(