Browse Source

[MNT] format code by black

tags/v0.3.2
Gene 2 years ago
parent
commit
05005d528a
3 changed files with 25 additions and 15 deletions
  1. +3
    -1
      examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py
  2. +17
    -9
      learnware/market/base.py
  3. +5
    -5
      learnware/market/easy2/checker.py

+ 3
- 1
examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py View File

@@ -85,7 +85,9 @@ def get_split_errs(algo):
split = train_xs.shape[0] - proportion_list[tmp]
model.fit(
train_xs[split:,],
train_xs[
split:,
],
train_ys[split:],
eval_set=[(val_xs, val_ys)],
early_stopping_rounds=50,


+ 17
- 9
learnware/market/base.py View File

@@ -62,7 +62,7 @@ class LearnwareMarket:
self.learnware_organizer.reload_market(rebuild=rebuild)
self.learnware_searcher = BaseSearcher() if searcher is None else searcher
self.learnware_searcher.reset(organizer=self.learnware_organizer)
if checker_list is None:
self.learnware_checker = {"BaseChecker": BaseChecker()}
else:
@@ -78,11 +78,11 @@ class LearnwareMarket:
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)
pending_learnware = get_learnware_from_dirpath(
id="pending", semantic_spec=semantic_specification, learnware_dirpath=tempdir
)
final_status = BaseChecker.INVALID_LEARNWARE
checker_names = list(self.learnware_checker.keys()) if checker_names is None else checker_names

@@ -93,16 +93,20 @@ class LearnwareMarket:

if check_status == BaseChecker.INVALID_LEARNWARE:
return BaseChecker.INVALID_LEARNWARE
return final_status
except Exception as err:
logger.warning(f"Check learnware failed! Due to {err}.")
return BaseChecker.INVALID_LEARNWARE

def add_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> Tuple[str, bool]:
def add_learnware(
self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
) -> Tuple[str, bool]:
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.add_learnware(zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs)
return self.learnware_organizer.add_learnware(
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def search_learnware(self, user_info: BaseUserInfo, **kwargs) -> Tuple[Any, List[Learnware]]:
return self.learnware_searcher(user_info, **kwargs)
@@ -110,9 +114,13 @@ class LearnwareMarket:
def delete_learnware(self, id: str, **kwargs) -> bool:
return self.learnware_organizer.delete_learnware(id, **kwargs)

def update_learnware(self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
def update_learnware(
self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
) -> bool:
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.update_learnware(id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs)
return self.learnware_organizer.update_learnware(
id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)

def get_learnware_ids(self, top: int = None, **kwargs):
return self.learnware_organizer.get_learnware_ids(top, **kwargs)


+ 5
- 5
learnware/market/easy2/checker.py View File

@@ -17,18 +17,18 @@ class EasySemanticChecker(BaseChecker):
value = semantic_spec[key]["Values"]
valid_type = C["semantic_specs"][key]["Type"]
assert semantic_spec[key]["Type"] == valid_type, f"{key} type mismatch"
if valid_type == "Class":
valid_list = C["semantic_specs"][key]["Values"]
assert len(value) == 1, f"{key} must be unique"
assert value[0] in valid_list, f"{key} must be in {valid_list}"
elif valid_type == "Tag":
valid_list = C["semantic_specs"][key]["Values"]
assert len(value) >= 1, f"{key} cannot be empty"
for v in value:
assert v in valid_list, f"{key} must be in {valid_list}"
elif valid_type == "String":
assert isinstance(value, str), f"{key} must be string"
assert len(value) >= 1, f"{key} cannot be empty"
@@ -89,7 +89,7 @@ class EasyStatisticalChecker(BaseChecker):
# Check output
if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
if outputs.shape[1:] != learnware_model.output_shape:
logger.warning(f"The learnware [{learnware.id}] output dimention mismatch!")
return self.INVALID_LEARNWARE
@@ -112,4 +112,4 @@ class EasyStatisticalChecker(BaseChecker):
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.")
return self.INVALID_LEARNWARE

return self.USABLE_LEARWARE
return self.USABLE_LEARWARE

Loading…
Cancel
Save