From 7e36e63f59deb2165a847896df67c3384adca64b Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 24 Oct 2023 20:47:04 +0800 Subject: [PATCH 1/3] [ENH] add check_learnware in client --- learnware/client/learnware_client.py | 122 +++++---------------------- 1 file changed, 19 insertions(+), 103 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 1f43540..8a6efb0 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -19,7 +19,8 @@ from .container import LearnwaresContainer from ..market.easy import EasyMarket from ..logger import get_module_logger from ..specification import Specification -from ..learnware import BaseReuser, Learnware +from ..learnware import BaseReuser, Learnware, get_learnware_from_dirpath +from ..test import get_semantic_specification CHUNK_SIZE = 1024 * 1024 logger = get_module_logger(module_name="LearnwareClient") @@ -264,29 +265,6 @@ class LearnwareClient: raise Exception("delete failed: " + json.dumps(result)) pass - def check_learnware(self, path, semantic_specification): - if os.path.isfile(path): - with tempfile.TemporaryDirectory() as tempdir: - with zipfile.ZipFile(path, "r") as z_file: - z_file.extractall(tempdir) - pass - return self.check_learnware_folder(tempdir, semantic_specification) - pass - else: - return self.check_learnware_folder(path, semantic_specification) - pass - pass - - def check_learnware_folder(self, folder, semantic_specification): - learnware_obj = learnware.get_learnware_from_dirpath("test_id", semantic_specification, folder) - - check_result = EasyMarket.check_learnware(learnware_obj) - if check_result == EasyMarket.USABLE_LEARWARE: - return True - else: - return False - pass - def create_semantic_specification( self, name, description, data_type, task_type, library_type, senarioes, input_description, output_description ): @@ -408,91 +386,29 @@ class LearnwareClient: return learnware_list[0] else: return learnware_list - - def system(self, command): - retcd = os.system(command) - if retcd != 0: - raise RuntimeError(f"Command {command} failed with return code {retcd}") - pass - - def install_environment(self, zip_path, conda_env=None): - """Install environment of a learnware - - Parameters - ---------- - zip_path : str - Path of the learnware zip file - conda_env : optional - If it is not None, a new conda environment will be created with the given name; - If it is None, use current environment. - - Raises - ------ - Exception - Lack of the environment configuration file. - """ - with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - with zipfile.ZipFile(zip_path, "r") as z_file: - logger.info(f"zip_file namelist: {z_file.namelist}") - if "environment.yaml" in z_file.namelist(): - z_file.extract("environment.yaml", tempdir) - yaml_path = os.path.join(tempdir, "environment.yaml") - yaml_path_filter = os.path.join(tempdir, "environment_filter.yaml") - package_utils.filter_nonexist_conda_packages_file(yaml_path, yaml_path_filter) - # create environment - if conda_env is not None: - self.system(f"conda env update --name {conda_env} --file {yaml_path_filter}") - pass - else: - self.system(f"conda env update --file {yaml_path_filter}") - pass - pass - elif "requirements.txt" in z_file.namelist(): - z_file.extract("requirements.txt", tempdir) - requirements_path = os.path.join(tempdir, "requirements.txt") - requirements_path_filter = os.path.join(tempdir, "requirements_filter.txt") - package_utils.filter_nonexist_pip_packages_file(requirements_path, requirements_path_filter) - - if conda_env is not None: - self.system(f"conda create -y --name {conda_env} python=3.8") - self.system( - f"conda run --name {conda_env} --no-capture-output python3 -m pip install -r {requirements_path_filter}" - ) - else: - self.system(f"python3 -m pip install -r {requirements_path_filter}") - pass - pass - else: - raise Exception("Environment.yaml or requirements.txt not found in the learnware zip file.") - pass - pass - pass - - def test_learnware(self, zip_path, semantic_specification=None): + + @staticmethod + def check_learnware(zip_path, semantic_specification=None): if semantic_specification is None: - semantic_specification = dict() - pass - + semantic_specification = get_semantic_specification() + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(zip_path, mode="r") as z_file: z_file.extractall(tempdir) - pass - - learnware_obj = learnware.get_learnware_from_dirpath("test_id", semantic_specification, tempdir) - - if learnware_obj is None: + + learnware = get_learnware_from_dirpath( + id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir + ) + + if learnware is None: raise Exception("The learnware is not valid.") - - learnware_obj.instantiate_model() - - if len(semantic_specification) > 0: - if EasyMarket.check_learnware(learnware_obj) != EasyMarket.USABLE_LEARWARE: + + with LearnwaresContainer(learnware, zip_path) as env_container: + learnware = env_container.get_learnwares_with_container()[0] + if EasyMarket.check_learnware(learnware) == EasyMarket.USABLE_LEARWARE: + logger.info("The learnware passed the local test.") + else: raise Exception("The learnware is not usable.") - pass - pass - - logger.info("test ok") - pass def cleanup(self): for tempdir in self.tempdir_list: From ffaecc60f76ebc7285cce8a47ab4290086bc5428 Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 24 Oct 2023 20:47:19 +0800 Subject: [PATCH 2/3] [ENH] add test for check_learnware --- .../test_check_learnware.py | 27 +++++++++++++++++++ tests/test_learnware_client/test_learnware.py | 23 ---------------- 2 files changed, 27 insertions(+), 23 deletions(-) create mode 100644 tests/test_learnware_client/test_check_learnware.py delete mode 100644 tests/test_learnware_client/test_learnware.py diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py new file mode 100644 index 0000000..df77c41 --- /dev/null +++ b/tests/test_learnware_client/test_check_learnware.py @@ -0,0 +1,27 @@ +import os +import unittest +import tempfile + + +from learnware.client import LearnwareClient + + +class TestCheckLearnware(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUpClass() + email = "liujd@lamda.nju.edu.cn" + token = "f7e647146a314c6e8b4e2e1079c4bca4" + + self.client = LearnwareClient() + self.client.login(email, token) + self.learnware_id = "00000154" + + def test_check_learnware(self): + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + self.zip_path = os.path.join(tempdir, "test.zip") + self.client.download_learnware(self.learnware_id, self.zip_path) + LearnwareClient.check_learnware(self.zip_path) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_learnware_client/test_learnware.py b/tests/test_learnware_client/test_learnware.py deleted file mode 100644 index 4fee31a..0000000 --- a/tests/test_learnware_client/test_learnware.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import zipfile -import tempfile -from learnware.learnware import get_learnware_from_dirpath -from learnware.test import get_semantic_specification -from learnware.client.container import LearnwaresContainer -from learnware.market import EasyMarket - -if __name__ == "__main__": - semantic_specification = get_semantic_specification() - - zip_path = "rf_tic.zip" - with tempfile.TemporaryDirectory(suffix="learnware") as tempdir: - learnware_dirpath = os.path.join(tempdir, "test") - with zipfile.ZipFile(zip_path, "r") as z_file: - z_file.extractall(learnware_dirpath) - learnware = get_learnware_from_dirpath( - id="test", semantic_spec=semantic_specification, learnware_dirpath=learnware_dirpath - ) - - with LearnwaresContainer(learnware, zip_path) as env_container: - learnware = env_container.get_learnwares_with_container()[0] - print(EasyMarket.check_learnware(learnware)) From c7d44c6831e3935a3eca52388c50cb98865b9672 Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 24 Oct 2023 20:47:35 +0800 Subject: [PATCH 3/3] [MNT] format code --- learnware/client/learnware_client.py | 10 +++---- learnware/market/easy.py | 29 ++++++++++++------- .../test_check_learnware.py | 2 +- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 8a6efb0..e3dd28e 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -386,23 +386,23 @@ class LearnwareClient: return learnware_list[0] else: return learnware_list - + @staticmethod def check_learnware(zip_path, semantic_specification=None): if semantic_specification is None: semantic_specification = get_semantic_specification() - + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(zip_path, mode="r") as z_file: z_file.extractall(tempdir) - + learnware = get_learnware_from_dirpath( id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir ) - + if learnware is None: raise Exception("The learnware is not valid.") - + with LearnwaresContainer(learnware, zip_path) as env_container: learnware = env_container.get_learnwares_with_container()[0] if EasyMarket.check_learnware(learnware) == EasyMarket.USABLE_LEARWARE: diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 26f8fb0..7776d6c 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -699,6 +699,7 @@ class EasyMarket(BaseMarket): List[Learnware] The list of returned learnwares """ + def _match_semantic_spec_tag(semantic_spec1, semantic_spec2) -> bool: """Judge if tags of two semantic specs are consistent @@ -737,8 +738,8 @@ class EasyMarket(BaseMarket): elif semantic_spec1[key]["Type"] == "Tag": if not (set(v1) & set(v2)): return False - return True - + return True + matched_learnware_tag = [] final_result = [] user_semantic_spec = user_info.get_semantic_spec() @@ -747,15 +748,21 @@ class EasyMarket(BaseMarket): learnware_semantic_spec = learnware.get_specification().get_semantic_spec() if _match_semantic_spec_tag(user_semantic_spec, learnware_semantic_spec): matched_learnware_tag.append(learnware) - + if len(matched_learnware_tag) > 0: if "Name" in user_semantic_spec: name_user = user_semantic_spec["Name"]["Values"].lower() if len(name_user) > 0: # Exact search - name_list = [learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower() for learnware in matched_learnware_tag] - des_list = [learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower() for learnware in matched_learnware_tag] - + name_list = [ + learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower() + for learnware in matched_learnware_tag + ] + des_list = [ + learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower() + for learnware in matched_learnware_tag + ] + matched_learnware_exact = [] for i in range(len(name_list)): if name_user in name_list[i] or name_user in des_list[i]: @@ -771,9 +778,11 @@ class EasyMarket(BaseMarket): if final_score >= min_score: matched_learnware_fuzz.append(matched_learnware_tag[i]) fuzz_scores.append(final_score) - + # Sort by score - sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[:max_num] + sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[ + :max_num + ] final_result = [matched_learnware_fuzz[idx] for idx in sort_idx] else: final_result = matched_learnware_exact @@ -782,9 +791,7 @@ class EasyMarket(BaseMarket): else: final_result = matched_learnware_tag - logger.info( - "semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)) - ) + logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) return final_result def search_learnware( diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index df77c41..aa21543 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -24,4 +24,4 @@ class TestCheckLearnware(unittest.TestCase): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()