From 3581ab8a6a65f8b623457dcf96f4ff5bc7e109a2 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sun, 15 Oct 2023 15:52:16 +0800 Subject: [PATCH] [MNT] black format --- .../dataset_pfs_workflow/pfs/pfs_cross_transfer.py | 4 +--- learnware/client/container.py | 13 ++++--------- learnware/test/__init__.py | 2 +- learnware/test/module.py | 4 +--- tests/test_learnware_client/test_learnware.py | 2 +- tests/test_learnware_client/test_load.py | 4 ++-- tests/test_learnware_client/test_reuse.py | 1 - tests/test_workflow/test_workflow.py | 2 +- 8 files changed, 11 insertions(+), 21 deletions(-) diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 5f69127..93a3fa3 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,9 +85,7 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[ - split:, - ], + train_xs[split:,], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/client/container.py b/learnware/client/container.py index 418d453..10afe88 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -127,18 +127,16 @@ class LearnwaresContainer: ) for _learnware, _zippath in zip(learnware_list, learnware_zippaths) ] - + # We should first register the destroy method atexit.register(self.cleanup) self.init_env() - + def init_env(self): model_list = [_learnware.get_model() for _learnware in self.learnware_list] with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: executor.map(self._initialize_model_container, model_list) - - def cleanup(self): for _learnware in self.learnware_list: self._destroy_model_container(_learnware.get_model()) @@ -148,17 +146,14 @@ class LearnwaresContainer: try: model.init_env_and_metadata() except Exception as err: - logger.warning(f"build env {model.conda_env} failed due to {err}") + logger.warning(f"build env {model.conda_env} failed due to {err}") @staticmethod def _destroy_model_container(model: ModelEnvContainer): try: model.remove_env() except Exception as err: - logger.warning(f"remove env {model.conda_env} failed due to {err}") - + logger.warning(f"remove env {model.conda_env} failed due to {err}") def get_learnware_list_with_container(self): return self.learnware_list - - \ No newline at end of file diff --git a/learnware/test/__init__.py b/learnware/test/__init__.py index ceb0614..a048b3f 100644 --- a/learnware/test/__init__.py +++ b/learnware/test/__init__.py @@ -1 +1 @@ -from .module import get_semantic_specification \ No newline at end of file +from .module import get_semantic_specification diff --git a/learnware/test/module.py b/learnware/test/module.py index e260bc7..677456c 100644 --- a/learnware/test/module.py +++ b/learnware/test/module.py @@ -1,5 +1,3 @@ - - def get_semantic_specification(): semantic_specification = dict() semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} @@ -8,4 +6,4 @@ def get_semantic_specification(): semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} semantic_specification["Name"] = {"Type": "String", "Values": "test"} semantic_specification["Description"] = {"Type": "String", "Values": "test"} - return semantic_specification \ No newline at end of file + return semantic_specification diff --git a/tests/test_learnware_client/test_learnware.py b/tests/test_learnware_client/test_learnware.py index a1c79e1..67bb61f 100644 --- a/tests/test_learnware_client/test_learnware.py +++ b/tests/test_learnware_client/test_learnware.py @@ -3,7 +3,7 @@ from learnware.test import get_semantic_specification if __name__ == "__main__": semantic_specification = get_semantic_specification() - + zip_path = "test.zip" client = LearnwareClient() client.install_environment(zip_path) diff --git a/tests/test_learnware_client/test_load.py b/tests/test_learnware_client/test_load.py index 67981dc..39562f7 100644 --- a/tests/test_learnware_client/test_load.py +++ b/tests/test_learnware_client/test_load.py @@ -24,7 +24,7 @@ class TestLearnwareLoad(unittest.TestCase): self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] def test_load_single_learnware_by_zippath(self): - for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): self.client.download_learnware(learnware_id, zip_path) learnware_list = [ @@ -39,7 +39,7 @@ class TestLearnwareLoad(unittest.TestCase): print(learnware.id, learnware.predict(input_array)) def test_load_multi_learnware_by_zippath(self): - for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): self.client.download_learnware(learnware_id, zip_path) learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env") diff --git a/tests/test_learnware_client/test_reuse.py b/tests/test_learnware_client/test_reuse.py index d71d38c..62191d5 100644 --- a/tests/test_learnware_client/test_reuse.py +++ b/tests/test_learnware_client/test_reuse.py @@ -7,7 +7,6 @@ from learnware.learnware.reuse import AveragingReuser from learnware.test.module import get_semantic_specification if __name__ == "__main__": - semantic_specification = get_semantic_specification() zip_paths = [ "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic.zip", diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index 48fa42f..1711c8a 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -70,7 +70,7 @@ class TestAllWorkflow(unittest.TestCase): env_file = os.path.join(dir_path, "environment.yaml") copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file) - + zip_file = dir_path + ".zip" # zip -q -r -j zip_file dir_path with zipfile.ZipFile(zip_file, "w") as zip_obj: