From 05005d528a88982b4bc620e1015ebbaece2ed246 Mon Sep 17 00:00:00 2001 From: Gene Date: Sat, 28 Oct 2023 21:45:51 +0800 Subject: [PATCH] [MNT] format code by black --- .../pfs/pfs_cross_transfer.py | 4 ++- learnware/market/base.py | 26 ++++++++++++------- learnware/market/easy2/checker.py | 10 +++---- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 93a3fa3..5f69127 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,7 +85,9 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[split:,], + train_xs[ + split:, + ], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/market/base.py b/learnware/market/base.py index 1bc2d2d..1c26cea 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -62,7 +62,7 @@ class LearnwareMarket: self.learnware_organizer.reload_market(rebuild=rebuild) self.learnware_searcher = BaseSearcher() if searcher is None else searcher self.learnware_searcher.reset(organizer=self.learnware_organizer) - + if checker_list is None: self.learnware_checker = {"BaseChecker": BaseChecker()} else: @@ -78,11 +78,11 @@ class LearnwareMarket: with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: with zipfile.ZipFile(zip_path, mode="r") as z_file: z_file.extractall(tempdir) - + pending_learnware = get_learnware_from_dirpath( id="pending", semantic_spec=semantic_specification, learnware_dirpath=tempdir ) - + final_status = BaseChecker.INVALID_LEARNWARE checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names @@ -93,16 +93,20 @@ class LearnwareMarket: if check_status == BaseChecker.INVALID_LEARNWARE: return BaseChecker.INVALID_LEARNWARE - + return final_status - + except Exception as err: logger.warning(f"Check learnware failed! Due to {err}.") return BaseChecker.INVALID_LEARNWARE - def add_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> Tuple[str, bool]: + def add_learnware( + self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs + ) -> Tuple[str, bool]: check_status = self.check_learnware(zip_path, semantic_spec, checker_names) - return self.learnware_organizer.add_learnware(zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) + return self.learnware_organizer.add_learnware( + zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs + ) def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]: return self.learnware_searcher(user_info, **kwargs) @@ -110,9 +114,13 @@ class LearnwareMarket: def delete_learnware(self, id: str, **kwargs) -> bool: return self.learnware_organizer.delete_learnware(id, **kwargs) - def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: + def update_learnware( + self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs + ) -> bool: check_status = self.check_learnware(zip_path, semantic_spec, checker_names) - return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs) + return self.learnware_organizer.update_learnware( + id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs + ) def get_learnware_ids(self, top: int = None, **kwargs): return self.learnware_organizer.get_learnware_ids(top, **kwargs) diff --git a/learnware/market/easy2/checker.py b/learnware/market/easy2/checker.py index 25ee452..7f26b91 100644 --- a/learnware/market/easy2/checker.py +++ b/learnware/market/easy2/checker.py @@ -17,18 +17,18 @@ class EasySemanticChecker(BaseChecker): value = semantic_spec[key]["Values"] valid_type = C["semantic_specs"][key]["Type"] assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch" - + if valid_type == "Class": valid_list = C["semantic_specs"][key]["Values"] assert len(value) == 1, f"{key} must be unique" assert value[0] in valid_list, f"{key} must be in {valid_list}" - + elif valid_type == "Tag": valid_list = C["semantic_specs"][key]["Values"] assert len(value) >= 1, f"{key} cannot be empty" for v in value: assert v in valid_list, f"{key} must be in {valid_list}" - + elif valid_type == "String": assert isinstance(value, str), f"{key} must be string" assert len(value) >= 1, f"{key} cannot be empty" @@ -89,7 +89,7 @@ class EasyStatisticalChecker(BaseChecker): # Check output if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) - + if outputs.shape[1:] != learnware_model.output_shape: logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!") return self.INVALID_LEARNWARE @@ -112,4 +112,4 @@ class EasyStatisticalChecker(BaseChecker): logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.") return self.INVALID_LEARNWARE - return self.USABLE_LEARWARE \ No newline at end of file + return self.USABLE_LEARWARE