From 3d0ca966dfd074dee96980669df49f20c9476906 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 12 Oct 2023 16:50:01 +0800 Subject: [PATCH 01/12] [MNT] add test download --- tests/test_client/test_download.py | 33 +++++++++++++++ .../test_learnware.py} | 0 tests/test_client/test_reuse.py | 42 +++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 tests/test_client/test_download.py rename tests/{test_learnware_upload/test_upload.py => test_client/test_learnware.py} (100%) create mode 100644 tests/test_client/test_reuse.py diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py new file mode 100644 index 0000000..5dba8ae --- /dev/null +++ b/tests/test_client/test_download.py @@ -0,0 +1,33 @@ +import os +import numpy as np + +import learnware +from learnware.client import LearnwareClient +from learnware.client.container import ModelEnvContainer, LearnwaresContainer +from learnware.learnware.reuse import AveragingReuser + + +if __name__ == "__main__": + email = "liujd@lamda.nju.edu.cn" + token = "f7e647146a314c6e8b4e2e1079c4bca4" + + client = LearnwareClient() + client.login(email, token) + + learnware_ids = ["00000084", "00000154", "00000155"] + zip_paths = ["1.zip", "2.zip", "3.zip"] + root = os.path.dirname(__file__) + for i in range(len(learnware_ids)): + zip_paths[i] = os.path.join(root, zip_paths[i]) + client.download_learnware(learnware_ids[i], zip_paths[i]) + + learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for idx, learnware in enumerate(learnware_list): + print(f"learnware_{idx}", reuser.predict(learnware)) \ No newline at end of file diff --git a/tests/test_learnware_upload/test_upload.py b/tests/test_client/test_learnware.py similarity index 100% rename from tests/test_learnware_upload/test_upload.py rename to tests/test_client/test_learnware.py diff --git a/tests/test_client/test_reuse.py b/tests/test_client/test_reuse.py new file mode 100644 index 0000000..699a138 --- /dev/null +++ b/tests/test_client/test_reuse.py @@ -0,0 +1,42 @@ +import zipfile +import numpy as np + +from learnware.learnware import get_learnware_from_dirpath, Learnware +from learnware.market import EasyMarket +from learnware.client.container import ModelEnvContainer, LearnwaresContainer +from learnware.learnware.reuse import AveragingReuser + +if __name__ == "__main__": + semantic_specification = dict() + semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} + semantic_specification["Task"] = {"Type": "Class", "Values": ["Ranking"]} + semantic_specification["Library"] = {"Type": "Class", "Values": ["Scikit-learn"]} + semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} + semantic_specification["Name"] = {"Type": "String", "Values": "test"} + semantic_specification["Description"] = {"Type": "String", "Values": "test"} + + zip_paths = [ + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic.zip", + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic.zip", + ] + dir_paths = [ + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic", + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic", + ] + + learnware_list = [] + for id, (zip_path, dir_path) in enumerate(zip(zip_paths, dir_paths)): + with zipfile.ZipFile(zip_path, "r") as z_file: + z_file.extractall(dir_path) + + learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) + learnware_list.append(learnware) + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote") + input_array = np.random.randint(0, 3, size=(20, 9)) + print(reuser.predict(input_array).argmax(axis=1)) + for id, ind_learner in enumerate(learnware_list): + print(f"learner_{id}", reuser.predict(input_array).argmax(axis=1)) From b5713a6a4449ec4fca7ab5541436e364e12a28c3 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 12 Oct 2023 21:52:41 +0800 Subject: [PATCH 02/12] [FIX] fix bugs about virtual env --- learnware/client/learnware_client.py | 4 ++-- learnware/client/utils.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 723267f..3ec9bcf 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -388,9 +388,9 @@ class LearnwareClient: package_utils.filter_nonexist_pip_packages_file(requirements_path, requirements_path_filter) if conda_env is not None: - self.system(f"conda create --name {conda_env}") + self.system(f"conda create -y --name {conda_env} python=3.8") self.system( - f"conda run --no-capture-output python3 -m pip install -r {requirements_path_filter}" + f"conda run --name {conda_env} --no-capture-output python3 -m pip install -r {requirements_path_filter}" ) else: self.system(f"python3 -m pip install -r {requirements_path_filter}") diff --git a/learnware/client/utils.py b/learnware/client/utils.py index b6d9c8a..4639c3b 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -61,12 +61,14 @@ def install_environment(zip_path, conda_env): requirements_file=requirements_path, output_file=requirements_path_filter ) logger.info(f"create empty conda env [{conda_env}]") - system_execute(args=["conda", "create", "--name", f"{conda_env}", "python=3.8"]) + system_execute(args=["conda", "create", "-y", "--name", f"{conda_env}", "python=3.8"]) logger.info(f"install pip requirements for conda env [{conda_env}]") system_execute( args=[ "conda", "run", + "-n", + f"{conda_env}", "--no-capture-output", "python3", "-m", @@ -80,4 +82,4 @@ def install_environment(zip_path, conda_env): raise Exception("Environment.yaml or requirements.txt not found in the learnware zip file.") logger.info(f"install learnware package for conda env [{conda_env}]") - system_execute(args=["conda", "run", "--no-capture-output", "python3", "-m", "pip", "install", "learnware"]) + system_execute(args=["conda", "run", "-n", f"{conda_env}", "--no-capture-output", "python3", "-m", "pip", "install", "learnware"]) From 3e1d1997c0acb44079b2f54034937e02718cba7d Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 12 Oct 2023 23:34:28 +0800 Subject: [PATCH 03/12] [FIX] fix bugs in LearnwareClient --- learnware/client/container.py | 2 +- learnware/client/learnware_client.py | 5 +-- tests/test_client/test_download.py | 55 +++++++++++++++++++++++----- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/learnware/client/container.py b/learnware/client/container.py index 15236b0..8e7573a 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -31,7 +31,7 @@ class ModelEnvContainer(BaseModel): with open(model_path, "wb") as model_fp: pickle.dump(self.model_config, model_fp) - + system_execute( [ "conda", diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 3ec9bcf..1c6cfed 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -337,11 +337,8 @@ class LearnwareClient: if load_model: learnware_obj.instantiate_model() - pass - + return learnware_obj - pass - pass def system(self, command): retcd = os.system(command) diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index 5dba8ae..7314ac8 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -1,12 +1,55 @@ import os +import zipfile import numpy as np import learnware +from learnware.learnware import get_learnware_from_dirpath from learnware.client import LearnwareClient from learnware.client.container import ModelEnvContainer, LearnwaresContainer from learnware.learnware.reuse import AveragingReuser +def test_container(zip_paths): + semantic_specification = dict() + semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} + semantic_specification["Task"] = {"Type": "Class", "Values": ["Ranking"]} + semantic_specification["Library"] = {"Type": "Class", "Values": ["Scikit-learn"]} + semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} + semantic_specification["Name"] = {"Type": "String", "Values": "test"} + semantic_specification["Description"] = {"Type": "String", "Values": "test"} + + learnware_list = [] + for id, zip_path in enumerate(zip_paths): + dir_path = zip_path[:-4] + with zipfile.ZipFile(zip_path, "r") as z_file: + z_file.extractall(dir_path) + + learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) + learnware_list.append(learnware) + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for idx, learnware in enumerate(learnware_list): + print(f"learnware_{idx}", learnware.predict(input_array)) + + +def test_load(zip_paths): + learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for idx, learnware in enumerate(learnware_list): + print(f"learnware_{idx}", learnware.predict(input_array)) + + if __name__ == "__main__": email = "liujd@lamda.nju.edu.cn" token = "f7e647146a314c6e8b4e2e1079c4bca4" @@ -21,13 +64,5 @@ if __name__ == "__main__": zip_paths[i] = os.path.join(root, zip_paths[i]) client.download_learnware(learnware_ids[i], zip_paths[i]) - learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - - with LearnwaresContainer(learnware_list, zip_paths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for idx, learnware in enumerate(learnware_list): - print(f"learnware_{idx}", reuser.predict(learnware)) \ No newline at end of file + test_container(zip_paths) + # test_load(zip_paths) \ No newline at end of file From f75606bf672922db6e55451ca5043be4e3af1024 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 12 Oct 2023 23:35:03 +0800 Subject: [PATCH 04/12] [MNT] format code --- .../pfs/pfs_cross_transfer.py | 4 +++- learnware/client/container.py | 2 +- learnware/client/learnware_client.py | 2 +- learnware/client/utils.py | 15 ++++++++++++- tests/test_client/test_download.py | 22 +++++++++---------- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 93a3fa3..5f69127 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,7 +85,9 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[split:,], + train_xs[ + split:, + ], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/client/container.py b/learnware/client/container.py index 8e7573a..15236b0 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -31,7 +31,7 @@ class ModelEnvContainer(BaseModel): with open(model_path, "wb") as model_fp: pickle.dump(self.model_config, model_fp) - + system_execute( [ "conda", diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 1c6cfed..d2cd8f1 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -337,7 +337,7 @@ class LearnwareClient: if load_model: learnware_obj.instantiate_model() - + return learnware_obj def system(self, command): diff --git a/learnware/client/utils.py b/learnware/client/utils.py index 4639c3b..a48fe45 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -82,4 +82,17 @@ def install_environment(zip_path, conda_env): raise Exception("Environment.yaml or requirements.txt not found in the learnware zip file.") logger.info(f"install learnware package for conda env [{conda_env}]") - system_execute(args=["conda", "run", "-n", f"{conda_env}", "--no-capture-output", "python3", "-m", "pip", "install", "learnware"]) + system_execute( + args=[ + "conda", + "run", + "-n", + f"{conda_env}", + "--no-capture-output", + "python3", + "-m", + "pip", + "install", + "learnware", + ] + ) diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index 7314ac8..7d2cc53 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -17,7 +17,7 @@ def test_container(zip_paths): semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} semantic_specification["Name"] = {"Type": "String", "Values": "test"} semantic_specification["Description"] = {"Type": "String", "Values": "test"} - + learnware_list = [] for id, zip_path in enumerate(zip_paths): dir_path = zip_path[:-4] @@ -26,43 +26,43 @@ def test_container(zip_paths): learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) learnware_list.append(learnware) - + with LearnwaresContainer(learnware_list, zip_paths) as env_container: learnware_list = env_container.get_learnware_list_with_container() reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) - + for idx, learnware in enumerate(learnware_list): print(f"learnware_{idx}", learnware.predict(input_array)) - + def test_load(zip_paths): learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - + with LearnwaresContainer(learnware_list, zip_paths) as env_container: learnware_list = env_container.get_learnware_list_with_container() reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) - + for idx, learnware in enumerate(learnware_list): print(f"learnware_{idx}", learnware.predict(input_array)) - + if __name__ == "__main__": email = "liujd@lamda.nju.edu.cn" token = "f7e647146a314c6e8b4e2e1079c4bca4" - + client = LearnwareClient() client.login(email, token) - + learnware_ids = ["00000084", "00000154", "00000155"] zip_paths = ["1.zip", "2.zip", "3.zip"] root = os.path.dirname(__file__) for i in range(len(learnware_ids)): zip_paths[i] = os.path.join(root, zip_paths[i]) client.download_learnware(learnware_ids[i], zip_paths[i]) - + test_container(zip_paths) - # test_load(zip_paths) \ No newline at end of file + # test_load(zip_paths) From fc18d9a151cf07c197a79bfa1f21dec247183198 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 00:37:36 +0800 Subject: [PATCH 05/12] [MNT] modify load_learnware in LearnwareClient --- learnware/client/learnware_client.py | 55 ++++++++++++++-------------- tests/test_client/test_download.py | 53 +++++---------------------- 2 files changed, 37 insertions(+), 71 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index d2cd8f1..2d40d70 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -64,10 +64,9 @@ class LearnwareClient: self.host = C.backend_host else: self.host = host - pass self.chunk_size = 1024 * 1024 - pass + self.tempdir_list = [] def login(self, email, token): url = f"{self.host}/auth/login_by_token" @@ -305,40 +304,36 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] def load_learnware(self, learnware_file: str, load_model: bool = True): - with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - with zipfile.ZipFile(learnware_file, "r") as z_file: - z_file.extractall(tempdir) - pass + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name + + with zipfile.ZipFile(learnware_file, "r") as z_file: + z_file.extractall(tempdir) - yaml_file = C.learnware_folder_config["yaml_file"] + yaml_file = C.learnware_folder_config["yaml_file"] - with open(os.path.join(tempdir, yaml_file), "r") as fin: - learnware_info = yaml.safe_load(fin) - pass + with open(os.path.join(tempdir, yaml_file), "r") as fin: + learnware_info = yaml.safe_load(fin) - learnware_id = learnware_info.get("id") - if learnware_id is None: - learnware_id = "test_id" - pass + learnware_id = learnware_info.get("id") + if learnware_id is None: + learnware_id = "test_id" - semantic_specification = learnware_info.get("semantic_specification") - if semantic_specification is None: - semantic_specification = {} - pass - else: - semantic_file = semantic_specification.get("file_name") + semantic_specification = learnware_info.get("semantic_specification") + if semantic_specification is None: + semantic_specification = {} + else: + semantic_file = semantic_specification.get("file_name") - with open(os.path.join(tempdir, semantic_file), "r") as fin: - semantic_specification = json.load(fin) - pass - pass + with open(os.path.join(tempdir, semantic_file), "r") as fin: + semantic_specification = json.load(fin) - learnware_obj = learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) + learnware_obj = learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - if load_model: - learnware_obj.instantiate_model() + if load_model: + learnware_obj.instantiate_model() - return learnware_obj + return learnware_obj def system(self, command): retcd = os.system(command) @@ -424,3 +419,7 @@ class LearnwareClient: logger.info("test ok") pass + + def __del__(self): + for tempdir in self.tempdir_list: + tempdir.cleanup() \ No newline at end of file diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index 7d2cc53..4a21286 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -9,47 +9,6 @@ from learnware.client.container import ModelEnvContainer, LearnwaresContainer from learnware.learnware.reuse import AveragingReuser -def test_container(zip_paths): - semantic_specification = dict() - semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} - semantic_specification["Task"] = {"Type": "Class", "Values": ["Ranking"]} - semantic_specification["Library"] = {"Type": "Class", "Values": ["Scikit-learn"]} - semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} - semantic_specification["Name"] = {"Type": "String", "Values": "test"} - semantic_specification["Description"] = {"Type": "String", "Values": "test"} - - learnware_list = [] - for id, zip_path in enumerate(zip_paths): - dir_path = zip_path[:-4] - with zipfile.ZipFile(zip_path, "r") as z_file: - z_file.extractall(dir_path) - - learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) - learnware_list.append(learnware) - - with LearnwaresContainer(learnware_list, zip_paths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for idx, learnware in enumerate(learnware_list): - print(f"learnware_{idx}", learnware.predict(input_array)) - - -def test_load(zip_paths): - learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - - with LearnwaresContainer(learnware_list, zip_paths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for idx, learnware in enumerate(learnware_list): - print(f"learnware_{idx}", learnware.predict(input_array)) - - if __name__ == "__main__": email = "liujd@lamda.nju.edu.cn" token = "f7e647146a314c6e8b4e2e1079c4bca4" @@ -64,5 +23,13 @@ if __name__ == "__main__": zip_paths[i] = os.path.join(root, zip_paths[i]) client.download_learnware(learnware_ids[i], zip_paths[i]) - test_container(zip_paths) - # test_load(zip_paths) + learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) From 6e5af20ecc01063a12df3800e1a2f0e72fb85ea4 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 00:37:59 +0800 Subject: [PATCH 06/12] [MNT] format code --- learnware/client/learnware_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 2d40d70..df187e5 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -306,7 +306,7 @@ class LearnwareClient: def load_learnware(self, learnware_file: str, load_model: bool = True): self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) tempdir = self.tempdir_list[-1].name - + with zipfile.ZipFile(learnware_file, "r") as z_file: z_file.extractall(tempdir) @@ -419,7 +419,7 @@ class LearnwareClient: logger.info("test ok") pass - + def __del__(self): for tempdir in self.tempdir_list: - tempdir.cleanup() \ No newline at end of file + tempdir.cleanup() From 768bdf48a643b33f14423dadcf05cdb97a960ffa Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 09:45:07 +0800 Subject: [PATCH 07/12] [MNT] replace __exit__ by atexit.register --- learnware/client/container.py | 22 +++++++++++----------- learnware/client/learnware_client.py | 4 +++- tests/test_client/test_download.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/learnware/client/container.py b/learnware/client/container.py index 8dcb195..c924722 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -1,5 +1,6 @@ import os import pickle +import atexit import tempfile import shortuuid from concurrent.futures import ProcessPoolExecutor @@ -126,6 +127,12 @@ class LearnwaresContainer: ) for _learnware, _zippath in zip(learnware_list, learnware_zippaths) ] + + model_list = [_learnware.get_model() for _learnware in self.learnware_list] + with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: + executor.map(self._initialize_model_container, model_list) + + atexit.register(self.cleanup) @staticmethod def _initialize_model_container(model: ModelEnvContainer): @@ -135,16 +142,9 @@ class LearnwaresContainer: def _destroy_model_container(model: ModelEnvContainer): model.remove_env() - def __enter__(self): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._initialize_model_container, model_list) - return self - - def __exit__(self, type, value, trace): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._destroy_model_container, model_list) - def get_learnware_list_with_container(self): return self.learnware_list + + def cleanup(self): + for _learnware in self.learnware_list: + self._destroy_model_container(_learnware.get_model()) \ No newline at end of file diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index d1b1887..6760ae6 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -2,6 +2,7 @@ import os import numpy as np import yaml import json +import atexit import zipfile import hashlib import requests @@ -71,6 +72,7 @@ class LearnwareClient: self.chunk_size = 1024 * 1024 self.tempdir_list = [] + atexit.register(self.cleanup) def login(self, email, token): url = f"{self.host}/auth/login_by_token" @@ -439,6 +441,6 @@ class LearnwareClient: return result - def __del__(self): + def cleanup(self): for tempdir in self.tempdir_list: tempdir.cleanup() diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index 4a21286..d74749a 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -25,11 +25,11 @@ if __name__ == "__main__": learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - with LearnwaresContainer(learnware_list, zip_paths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for learnware in learnware_list: - print(learnware.id, learnware.predict(input_array)) + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) From c807116bb7d882faff65f5e453af25b194a9d215 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 10:18:33 +0800 Subject: [PATCH 08/12] [MNT] modify load_learnware in LearnwareClient for LearnwaresContainer --- learnware/client/learnware_client.py | 108 ++++++++++++++++----------- tests/test_client/test_download.py | 32 +++++--- 2 files changed, 86 insertions(+), 54 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 6760ae6..47e93f0 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -9,7 +9,7 @@ import requests import tempfile from enum import Enum from tqdm import tqdm -from typing import List +from typing import Union, List from ..config import C from .. import learnware @@ -309,37 +309,72 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] - def load_learnware(self, learnware_file: str, load_model: bool = True): - self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) - tempdir = self.tempdir_list[-1].name + def load_learnware(self, learnware_file: Union[str, List[str]], load_option: str = "conda_env"): + """Load learnware - with zipfile.ZipFile(learnware_file, "r") as z_file: - z_file.extractall(tempdir) - - yaml_file = C.learnware_folder_config["yaml_file"] - - with open(os.path.join(tempdir, yaml_file), "r") as fin: - learnware_info = yaml.safe_load(fin) - - learnware_id = learnware_info.get("id") - if learnware_id is None: - learnware_id = "test_id" + Parameters + ---------- + learnware_file : Union[str, List[str]] + learnware zip path or learnware zip path list + load_option : str + the option for loading learnwares + - "normal": load learnware without installing environment + - "conda_env": load learnware with installing conda virtual environment + + Returns + ------- + Learnware + The contructed learnware object or object list + """ + if load_option not in ["normal", "conda_env"]: + raise ValueError(f"load_option must be one of ['normal', 'conda_env'], but got {load_option}") + + def _get_learnware_obj(learnware_zippath): + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name + + with zipfile.ZipFile(learnware_zippath, "r") as z_file: + z_file.extractall(tempdir) - semantic_specification = learnware_info.get("semantic_specification") - if semantic_specification is None: - semantic_specification = {} + yaml_file = C.learnware_folder_config["yaml_file"] + with open(os.path.join(tempdir, yaml_file), "r") as fin: + learnware_info = yaml.safe_load(fin) + + learnware_id = learnware_info.get("id") + if learnware_id is None: + learnware_id = "test_id" + + semantic_specification = learnware_info.get("semantic_specification") + if semantic_specification is None: + semantic_specification = {} + else: + semantic_file = semantic_specification.get("file_name") + + with open(os.path.join(tempdir, semantic_file), "r") as fin: + semantic_specification = json.load(fin) + + return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) + + if isinstance(learnware_file, str): + zip_paths = [learnware_file] + elif isinstance(learnware_file, list): + zip_paths = learnware_file + + learnware_list = [] + for zip_path in zip_paths: + learnware_obj = _get_learnware_obj(zip_path) + if load_option == "normal": + learnware_obj.instantiate_model() + learnware_list.append(learnware_obj) + + if load_option == "conda_env": + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() + + if len(learnware_list) == 1: + return learnware_list[0] else: - semantic_file = semantic_specification.get("file_name") - - with open(os.path.join(tempdir, semantic_file), "r") as fin: - semantic_specification = json.load(fin) - - learnware_obj = learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - - if load_model: - learnware_obj.instantiate_model() - - return learnware_obj + return learnware_list def system(self, command): retcd = os.system(command) @@ -426,21 +461,6 @@ class LearnwareClient: logger.info("test ok") pass - def reuse_learnware( - self, - input_array: np.ndarray, - learnware_list: List[Learnware], - learnware_zippaths: List[str], - reuser: BaseReuser, - ): - logger.info(f"reuse learnare list {learnware_list} with reuser {reuser}") - with LearnwaresContainer(learnware_list, learnware_zippaths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser.reset(learnware_list=learnware_list) - result = reuser.predict(input_array) - - return result - def cleanup(self): for tempdir in self.tempdir_list: tempdir.cleanup() diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py index d74749a..a2d3ab3 100644 --- a/tests/test_client/test_download.py +++ b/tests/test_client/test_download.py @@ -9,6 +9,26 @@ from learnware.client.container import ModelEnvContainer, LearnwaresContainer from learnware.learnware.reuse import AveragingReuser +def test_single_learnware(client, zip_paths): + learnware_list = [client.load_learnware(zippath, load_option="conda_env") for zippath in zip_paths] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + +def test_multi_learnware(client, zip_paths): + learnware_list = client.load_learnware(zip_paths, load_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + if __name__ == "__main__": email = "liujd@lamda.nju.edu.cn" token = "f7e647146a314c6e8b4e2e1079c4bca4" @@ -23,13 +43,5 @@ if __name__ == "__main__": zip_paths[i] = os.path.join(root, zip_paths[i]) client.download_learnware(learnware_ids[i], zip_paths[i]) - learnware_list = [client.load_learnware(file, load_model=False) for file in zip_paths] - - env_container = LearnwaresContainer(learnware_list, zip_paths) - learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for learnware in learnware_list: - print(learnware.id, learnware.predict(input_array)) + test_single_learnware(zip_paths) + test_multi_learnware(zip_paths) From 9806c1d4618c5e2e9100d7b56a948772c8511551 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 10:48:22 +0800 Subject: [PATCH 09/12] [MNT] add TestLearnwareLoad --- tests/test_client/test_download.py | 47 --------------------------- tests/test_client/test_load.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 47 deletions(-) delete mode 100644 tests/test_client/test_download.py create mode 100644 tests/test_client/test_load.py diff --git a/tests/test_client/test_download.py b/tests/test_client/test_download.py deleted file mode 100644 index a2d3ab3..0000000 --- a/tests/test_client/test_download.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import zipfile -import numpy as np - -import learnware -from learnware.learnware import get_learnware_from_dirpath -from learnware.client import LearnwareClient -from learnware.client.container import ModelEnvContainer, LearnwaresContainer -from learnware.learnware.reuse import AveragingReuser - - -def test_single_learnware(client, zip_paths): - learnware_list = [client.load_learnware(zippath, load_option="conda_env") for zippath in zip_paths] - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for learnware in learnware_list: - print(learnware.id, learnware.predict(input_array)) - - -def test_multi_learnware(client, zip_paths): - learnware_list = client.load_learnware(zip_paths, load_option="conda_env") - reuser = AveragingReuser(learnware_list, mode="vote_by_label") - input_array = np.random.random(size=(20, 13)) - print(reuser.predict(input_array)) - - for learnware in learnware_list: - print(learnware.id, learnware.predict(input_array)) - - -if __name__ == "__main__": - email = "liujd@lamda.nju.edu.cn" - token = "f7e647146a314c6e8b4e2e1079c4bca4" - - client = LearnwareClient() - client.login(email, token) - - learnware_ids = ["00000084", "00000154", "00000155"] - zip_paths = ["1.zip", "2.zip", "3.zip"] - root = os.path.dirname(__file__) - for i in range(len(learnware_ids)): - zip_paths[i] = os.path.join(root, zip_paths[i]) - client.download_learnware(learnware_ids[i], zip_paths[i]) - - test_single_learnware(zip_paths) - test_multi_learnware(zip_paths) diff --git a/tests/test_client/test_load.py b/tests/test_client/test_load.py new file mode 100644 index 0000000..73be704 --- /dev/null +++ b/tests/test_client/test_load.py @@ -0,0 +1,51 @@ +import os +import unittest +import zipfile +import numpy as np + +import learnware +from learnware.learnware import get_learnware_from_dirpath +from learnware.client import LearnwareClient +from learnware.client.container import ModelEnvContainer, LearnwaresContainer +from learnware.learnware.reuse import AveragingReuser + + +class TestLearnwareLoad(unittest.TestCase): + + def setUp(self): + unittest.TestCase.setUpClass() + email = "liujd@lamda.nju.edu.cn" + token = "f7e647146a314c6e8b4e2e1079c4bca4" + + self.client = LearnwareClient() + self.client.login(email, token) + + learnware_ids = ["00000084", "00000154", "00000155"] + zip_paths = ["1.zip", "2.zip", "3.zip"] + root = os.path.dirname(__file__) + for i in range(len(learnware_ids)): + zip_paths[i] = os.path.join(root, zip_paths[i]) + self.client.download_learnware(learnware_ids[i], zip_paths[i]) + self.zip_paths = zip_paths + + def test_single_learnware(self): + learnware_list = [self.client.load_learnware(zippath, load_option="conda_env") for zippath in self.zip_paths] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_multi_learnware(self): + learnware_list = self.client.load_learnware(self.zip_paths, load_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 6ca46e2bf2e346e53b3fd7090c668cea28a2dea7 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 15:58:25 +0800 Subject: [PATCH 10/12] [MNT] add runnable_option in load_learnware --- learnware/client/learnware_client.py | 78 +++++++++++++++++++--------- tests/test_client/test_load.py | 40 ++++++++++---- 2 files changed, 83 insertions(+), 35 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 47e93f0..d3d79c9 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -1,5 +1,5 @@ import os -import numpy as np +import uuid import yaml import json import atexit @@ -7,6 +7,7 @@ import zipfile import hashlib import requests import tempfile +import numpy as np from enum import Enum from tqdm import tqdm from typing import Union, List @@ -309,31 +310,44 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] - def load_learnware(self, learnware_file: Union[str, List[str]], load_option: str = "conda_env"): - """Load learnware + def load_learnware(self, learnware_path: Union[str, List[str]] = None, learnware_id: Union[str, List[str]] = None, runnable_option: str = None): + """Load learnware by learnware zip file or learnware id (zip file has higher priority) Parameters ---------- - learnware_file : Union[str, List[str]] + learnware_path : Union[str, List[str]] learnware zip path or learnware zip path list - load_option : str - the option for loading learnwares - - "normal": load learnware without installing environment - - "conda_env": load learnware with installing conda virtual environment + learnware_id : Union[str, List[str]] + learnware id or learnware id list + runnable_option : str + the option for instantiating learnwares + - "normal": instantiate learnware without installing environment + - "conda_env": instantiate learnware with installing conda virtual environment Returns ------- Learnware The contructed learnware object or object list """ - if load_option not in ["normal", "conda_env"]: - raise ValueError(f"load_option must be one of ['normal', 'conda_env'], but got {load_option}") + if runnable_option is not None and runnable_option not in ["normal", "conda_env"]: + raise logger.warning(f"runnable_option must be one of ['normal', 'conda_env'], but got {runnable_option}") + + if learnware_path is None and learnware_id is None: + raise ValueError("Requires one of learnware_path or learnware_id") - def _get_learnware_obj(learnware_zippath): + def _get_learnware_by_id(_learnware_id): self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) tempdir = self.tempdir_list[-1].name + zip_path = os.path.join(tempdir, f"{str(uuid.uuid4())}.zip") + self.download_learnware(_learnware_id, zip_path) + return zip_path, _get_learnware_by_path(zip_path, tempdir=tempdir) + + def _get_learnware_by_path(_learnware_zippath, tempdir=None): + if tempdir is None: + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name - with zipfile.ZipFile(learnware_zippath, "r") as z_file: + with zipfile.ZipFile(_learnware_zippath, "r") as z_file: z_file.extractall(tempdir) yaml_file = C.learnware_folder_config["yaml_file"] @@ -354,22 +368,36 @@ class LearnwareClient: semantic_specification = json.load(fin) return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - - if isinstance(learnware_file, str): - zip_paths = [learnware_file] - elif isinstance(learnware_file, list): - zip_paths = learnware_file learnware_list = [] - for zip_path in zip_paths: - learnware_obj = _get_learnware_obj(zip_path) - if load_option == "normal": - learnware_obj.instantiate_model() - learnware_list.append(learnware_obj) + zip_paths = [] + if learnware_path is not None: + if isinstance(learnware_path, str): + zip_paths = [learnware_path] + elif isinstance(learnware_path, list): + zip_paths = learnware_path + + for zip_path in zip_paths: + learnware_obj = _get_learnware_by_path(zip_path) + learnware_list.append(learnware_obj) + elif learnware_id is not None: + if isinstance(learnware_id, str): + id_list = [learnware_id] + elif isinstance(learnware_id, list): + id_list = learnware_id + + for idx in id_list: + zip_path, learnware_obj = _get_learnware_by_id(idx) + zip_paths.append(zip_path) + learnware_list.append(learnware_obj) - if load_option == "conda_env": - env_container = LearnwaresContainer(learnware_list, zip_paths) - learnware_list = env_container.get_learnware_list_with_container() + if runnable_option is not None: + if runnable_option == "normal": + for i in range(len(learnware_list)): + learnware_list[i].instantiate_model() + elif runnable_option == "conda_env": + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() if len(learnware_list) == 1: return learnware_list[0] diff --git a/tests/test_client/test_load.py b/tests/test_client/test_load.py index 73be704..1706e00 100644 --- a/tests/test_client/test_load.py +++ b/tests/test_client/test_load.py @@ -20,16 +20,15 @@ class TestLearnwareLoad(unittest.TestCase): self.client = LearnwareClient() self.client.login(email, token) - learnware_ids = ["00000084", "00000154", "00000155"] - zip_paths = ["1.zip", "2.zip", "3.zip"] root = os.path.dirname(__file__) - for i in range(len(learnware_ids)): - zip_paths[i] = os.path.join(root, zip_paths[i]) - self.client.download_learnware(learnware_ids[i], zip_paths[i]) - self.zip_paths = zip_paths + self.learnware_ids = ["00000084", "00000154", "00000155"] + self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] - def test_single_learnware(self): - learnware_list = [self.client.load_learnware(zippath, load_option="conda_env") for zippath in self.zip_paths] + def test_load_single_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = [self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") for zippath in self.zip_paths] reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) @@ -37,8 +36,29 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) - def test_multi_learnware(self): - learnware_list = self.client.load_learnware(self.zip_paths, load_option="conda_env") + def test_load_multi_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_single_learnware_by_id(self): + learnware_list = [self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_multi_learnware_by_id(self): + learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env") reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) From b56be51e27e1fc35c991de07743882480e27adb8 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 15:58:57 +0800 Subject: [PATCH 11/12] [FIX] fix bugs --- tests/test_client/test_reuse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client/test_reuse.py b/tests/test_client/test_reuse.py index 699a138..5e84f5d 100644 --- a/tests/test_client/test_reuse.py +++ b/tests/test_client/test_reuse.py @@ -35,7 +35,7 @@ if __name__ == "__main__": with LearnwaresContainer(learnware_list, zip_paths) as env_container: learnware_list = env_container.get_learnware_list_with_container() - reuser = AveragingReuser(learnware_list, mode="vote") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.randint(0, 3, size=(20, 9)) print(reuser.predict(input_array).argmax(axis=1)) for id, ind_learner in enumerate(learnware_list): From e062a181cff4c0ce96b97a9c0e18b6c0651ca169 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 13 Oct 2023 15:59:12 +0800 Subject: [PATCH 12/12] [MNT] format code --- learnware/client/container.py | 8 ++++---- learnware/client/learnware_client.py | 21 +++++++++++++-------- tests/test_client/test_load.py | 22 +++++++++++++--------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/learnware/client/container.py b/learnware/client/container.py index c924722..2221523 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -127,11 +127,11 @@ class LearnwaresContainer: ) for _learnware, _zippath in zip(learnware_list, learnware_zippaths) ] - + model_list = [_learnware.get_model() for _learnware in self.learnware_list] with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: executor.map(self._initialize_model_container, model_list) - + atexit.register(self.cleanup) @staticmethod @@ -144,7 +144,7 @@ class LearnwaresContainer: def get_learnware_list_with_container(self): return self.learnware_list - + def cleanup(self): for _learnware in self.learnware_list: - self._destroy_model_container(_learnware.get_model()) \ No newline at end of file + self._destroy_model_container(_learnware.get_model()) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index d3d79c9..8f07408 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -310,7 +310,12 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] - def load_learnware(self, learnware_path: Union[str, List[str]] = None, learnware_id: Union[str, List[str]] = None, runnable_option: str = None): + def load_learnware( + self, + learnware_path: Union[str, List[str]] = None, + learnware_id: Union[str, List[str]] = None, + runnable_option: str = None, + ): """Load learnware by learnware zip file or learnware id (zip file has higher priority) Parameters @@ -334,14 +339,14 @@ class LearnwareClient: if learnware_path is None and learnware_id is None: raise ValueError("Requires one of learnware_path or learnware_id") - + def _get_learnware_by_id(_learnware_id): self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) tempdir = self.tempdir_list[-1].name zip_path = os.path.join(tempdir, f"{str(uuid.uuid4())}.zip") self.download_learnware(_learnware_id, zip_path) return zip_path, _get_learnware_by_path(zip_path, tempdir=tempdir) - + def _get_learnware_by_path(_learnware_zippath, tempdir=None): if tempdir is None: self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) @@ -368,7 +373,7 @@ class LearnwareClient: semantic_specification = json.load(fin) return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - + learnware_list = [] zip_paths = [] if learnware_path is not None: @@ -376,7 +381,7 @@ class LearnwareClient: zip_paths = [learnware_path] elif isinstance(learnware_path, list): zip_paths = learnware_path - + for zip_path in zip_paths: learnware_obj = _get_learnware_by_path(zip_path) learnware_list.append(learnware_obj) @@ -385,12 +390,12 @@ class LearnwareClient: id_list = [learnware_id] elif isinstance(learnware_id, list): id_list = learnware_id - + for idx in id_list: zip_path, learnware_obj = _get_learnware_by_id(idx) zip_paths.append(zip_path) learnware_list.append(learnware_obj) - + if runnable_option is not None: if runnable_option == "normal": for i in range(len(learnware_list)): @@ -398,7 +403,7 @@ class LearnwareClient: elif runnable_option == "conda_env": env_container = LearnwaresContainer(learnware_list, zip_paths) learnware_list = env_container.get_learnware_list_with_container() - + if len(learnware_list) == 1: return learnware_list[0] else: diff --git a/tests/test_client/test_load.py b/tests/test_client/test_load.py index 1706e00..67981dc 100644 --- a/tests/test_client/test_load.py +++ b/tests/test_client/test_load.py @@ -11,7 +11,6 @@ from learnware.learnware.reuse import AveragingReuser class TestLearnwareLoad(unittest.TestCase): - def setUp(self): unittest.TestCase.setUpClass() email = "liujd@lamda.nju.edu.cn" @@ -27,8 +26,11 @@ class TestLearnwareLoad(unittest.TestCase): def test_load_single_learnware_by_zippath(self): for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): self.client.download_learnware(learnware_id, zip_path) - - learnware_list = [self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") for zippath in self.zip_paths] + + learnware_list = [ + self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") + for zippath in self.zip_paths + ] reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) @@ -39,7 +41,7 @@ class TestLearnwareLoad(unittest.TestCase): def test_load_multi_learnware_by_zippath(self): for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): self.client.download_learnware(learnware_id, zip_path) - + learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env") reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) @@ -47,9 +49,11 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) - + def test_load_single_learnware_by_id(self): - learnware_list = [self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids] + learnware_list = [ + self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids + ] reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) @@ -57,7 +61,7 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) - def test_load_multi_learnware_by_id(self): + def test_load_multi_learnware_by_id(self): learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env") reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) @@ -67,5 +71,5 @@ class TestLearnwareLoad(unittest.TestCase): print(learnware.id, learnware.predict(input_array)) -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main()