From 1d4d0516e9eb846d51e992e728c937e02e559364 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 15:27:29 +0800 Subject: [PATCH 1/7] [MNT] add log for input_shape in StatChecker --- learnware/market/easy/checker.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 90ee9d1..6a0720a 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -110,6 +110,11 @@ class EasyStatChecker(BaseChecker): if spec_type == "RKMETableSpecification": stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) + if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): + raise ValueError( + f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" + ) + if stat_spec.get_z().shape[1:] != input_shape: message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification." logger.warning(message) @@ -118,6 +123,10 @@ class EasyStatChecker(BaseChecker): elif spec_type == "RKMETextSpecification": inputs = EasyStatChecker._generate_random_text_list(10) elif spec_type == "RKMEImageSpecification": + if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): + raise ValueError( + f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" + ) inputs = np.random.randint(0, 255, size=(10, *input_shape)) else: raise ValueError(f"not supported spec type for spec_type = {spec_type}") From 71ebe0f1e6f743b3bbb1ea0ebb7600a7dd4b6495 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 16:49:35 +0800 Subject: [PATCH 2/7] [MNT] modify check for output dimension --- learnware/market/easy/checker.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 6a0720a..db19ca3 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -164,19 +164,34 @@ class EasyStatChecker(BaseChecker): # Check output shape if outputs[0].shape != learnware_model.output_shape: - message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" + message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" logger.warning(message) return self.INVALID_LEARNWARE, message - # Check output dimension - if semantic_spec["Task"]["Values"][0] in [ - "Classification", - "Regression", - ] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): - message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" + # Check output dimension for regression + if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( + semantic_spec["Output"]["Dimension"] + ): + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" logger.warning(message) return self.INVALID_LEARNWARE, message + # Check output dimension for classification + if semantic_spec["Task"]["Values"][0] == "Classification": + model_output_shape = learnware_model.output_shape[0] + semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) + + if model_output_shape == 1: + if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): + message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + else: + if model_output_shape != semantic_output_shape: + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + except Exception as e: message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." logger.warning(message) From f2def1ef65cef8aad7fb9e0b6d3957a9763ce231 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 17:20:45 +0800 Subject: [PATCH 3/7] [FIX] fix bugs in create_semantic_specification --- learnware/client/learnware_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index b19b055..d6a01b9 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -265,8 +265,8 @@ class LearnwareClient: "Type": "String", "Values": description if description is not None else "", } - semantic_specification["Input"] = input_description - semantic_specification["Output"] = output_description + semantic_specification["Input"] = {} if input_description is None else input_description + semantic_specification["Output"] = {} if output_description is None else output_description return semantic_specification @@ -341,7 +341,7 @@ class LearnwareClient: semantic_specification = json.load(fin) return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - + learnware_list = [] if learnware_path is not None: zip_paths = [learnware_path] if isinstance(learnware_path, str) else learnware_path From 8ba941bf000dde780de83ff412bbde8f6b9967fe Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 18:04:56 +0800 Subject: [PATCH 4/7] [MNT] modify check details for classification tasks --- learnware/market/easy/checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index db19ca3..7e0983e 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -47,6 +47,10 @@ class EasySemanticChecker(BaseChecker): if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: assert semantic_spec["Output"] is not None, "Lack of output semantics" dim = semantic_spec["Output"]["Dimension"] + assert ( + dim > 1 or semantic_spec["Task"]["Values"][0] == "Regression" + ), "Classification task must have dimension > 1" + for k, v in semantic_spec["Output"]["Description"].items(): assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" assert isinstance(v, str), "Description must be string" From 22ab5dfdc4874f76a7b7067df564e3004a23beba Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 18:06:18 +0800 Subject: [PATCH 5/7] [MNT] add semantic_spec in tests --- .../test_all_learnware.py | 6 +++- .../test_check_learnware.py | 32 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/tests/test_learnware_client/test_all_learnware.py b/tests/test_learnware_client/test_all_learnware.py index 37b9046..2303089 100644 --- a/tests/test_learnware_client/test_all_learnware.py +++ b/tests/test_learnware_client/test_all_learnware.py @@ -1,5 +1,6 @@ import os import json +import zipfile import unittest import tempfile @@ -48,8 +49,11 @@ class TestAllLearnware(unittest.TestCase): for idx in learnware_ids: zip_path = os.path.join(tempdir, f"test_{idx}.zip") self.client.download_learnware(idx, zip_path) + with zipfile.ZipFile(zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) try: - LearnwareClient.check_learnware(zip_path) + LearnwareClient.check_learnware(zip_path, semantic_spec) print(f"check learnware {idx} succeed") except: failed_ids.append(idx) diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index 218e222..77a44c3 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -1,4 +1,6 @@ import os +import json +import zipfile import unittest import tempfile @@ -16,35 +18,55 @@ class TestCheckLearnware(unittest.TestCase): with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_conda(self): learnware_id = "00000148" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_dependency(self): learnware_id = "00000147" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_image(self): learnware_id = "00000677" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_text(self): learnware_id = "00000662" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) if __name__ == "__main__": From 11d28330daea36f9f285fda401c26ec2b477dda8 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 21:19:34 +0800 Subject: [PATCH 6/7] [ENH] add update_learnware api --- learnware/client/learnware_client.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index d6a01b9..b396add 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -139,6 +139,33 @@ class LearnwareClient: return result["data"]["learnware_id"] + @require_login + def update_learnware(self, learnware_id, semantic_specification, learnware_zip_path=None): + assert self._check_semantic_specification(semantic_specification)[0], "Semantic specification check failed!" + + url_update = f"{self.host}/user/update_learnware" + payload = {"learnware_id": learnware_id, "semantic_specification": json.dumps(semantic_specification)} + + if learnware_zip_path is None: + response = requests.post( + url_update, + files={"learnware_file": None}, + data=payload, + headers=self.headers, + ) + else: + response = requests.post( + url_update, + files={"learnware_file": open(learnware_zip_path, "rb")}, + data=payload, + headers=self.headers, + ) + + result = response.json() + + if result["code"] != 0: + raise Exception("update failed: " + json.dumps(result)) + def download_learnware(self, learnware_id, save_path): url = f"{self.host}/engine/download_learnware" From 77eeceb7e845b9af659d5c515fe1484c551567d7 Mon Sep 17 00:00:00 2001 From: Gene Date: Sun, 26 Nov 2023 22:21:34 +0800 Subject: [PATCH 7/7] [MNT] modify details --- learnware/market/easy/checker.py | 7 ++++++- tests/test_learnware_client/test_check_learnware.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 7e0983e..eb6fe75 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -186,7 +186,12 @@ class EasyStatChecker(BaseChecker): semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) if model_output_shape == 1: - if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if isinstance(outputs, list): + outputs = np.array(outputs) + + if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)): message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" logger.warning(message) return self.INVALID_LEARNWARE, message diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index 77a44c3..59f0820 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -14,7 +14,7 @@ class TestCheckLearnware(unittest.TestCase): self.client = LearnwareClient() def test_check_learnware_pip(self): - learnware_id = "00000154" + learnware_id = "00000208" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path)