diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index fdddd21..f3cbe61 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -140,6 +140,33 @@ class LearnwareClient: return result["data"]["learnware_id"] + @require_login + def update_learnware(self, learnware_id, semantic_specification, learnware_zip_path=None): + assert self._check_semantic_specification(semantic_specification)[0], "Semantic specification check failed!" + + url_update = f"{self.host}/user/update_learnware" + payload = {"learnware_id": learnware_id, "semantic_specification": json.dumps(semantic_specification)} + + if learnware_zip_path is None: + response = requests.post( + url_update, + files={"learnware_file": None}, + data=payload, + headers=self.headers, + ) + else: + response = requests.post( + url_update, + files={"learnware_file": open(learnware_zip_path, "rb")}, + data=payload, + headers=self.headers, + ) + + result = response.json() + + if result["code"] != 0: + raise Exception("update failed: " + json.dumps(result)) + def download_learnware(self, learnware_id, save_path): url = f"{self.host}/engine/download_learnware" @@ -275,8 +302,8 @@ class LearnwareClient: "Type": "String", "Values": description if description is not None else "", } - semantic_specification["Input"] = input_description - semantic_specification["Output"] = output_description + semantic_specification["Input"] = {} if input_description is None else input_description + semantic_specification["Output"] = {} if output_description is None else output_description return semantic_specification @@ -351,7 +378,7 @@ class LearnwareClient: semantic_specification = json.load(fin) return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - + learnware_list = [] if learnware_path is not None: zip_paths = [learnware_path] if isinstance(learnware_path, str) else learnware_path diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 90ee9d1..eb6fe75 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -47,6 +47,10 @@ class EasySemanticChecker(BaseChecker): if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: assert semantic_spec["Output"] is not None, "Lack of output semantics" dim = semantic_spec["Output"]["Dimension"] + assert ( + dim > 1 or semantic_spec["Task"]["Values"][0] == "Regression" + ), "Classification task must have dimension > 1" + for k, v in semantic_spec["Output"]["Description"].items(): assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" assert isinstance(v, str), "Description must be string" @@ -110,6 +114,11 @@ class EasyStatChecker(BaseChecker): if spec_type == "RKMETableSpecification": stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) + if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): + raise ValueError( + f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" + ) + if stat_spec.get_z().shape[1:] != input_shape: message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification." logger.warning(message) @@ -118,6 +127,10 @@ class EasyStatChecker(BaseChecker): elif spec_type == "RKMETextSpecification": inputs = EasyStatChecker._generate_random_text_list(10) elif spec_type == "RKMEImageSpecification": + if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): + raise ValueError( + f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" + ) inputs = np.random.randint(0, 255, size=(10, *input_shape)) else: raise ValueError(f"not supported spec type for spec_type = {spec_type}") @@ -155,19 +168,39 @@ class EasyStatChecker(BaseChecker): # Check output shape if outputs[0].shape != learnware_model.output_shape: - message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" + message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" logger.warning(message) return self.INVALID_LEARNWARE, message - # Check output dimension - if semantic_spec["Task"]["Values"][0] in [ - "Classification", - "Regression", - ] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): - message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" + # Check output dimension for regression + if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( + semantic_spec["Output"]["Dimension"] + ): + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" logger.warning(message) return self.INVALID_LEARNWARE, message + # Check output dimension for classification + if semantic_spec["Task"]["Values"][0] == "Classification": + model_output_shape = learnware_model.output_shape[0] + semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) + + if model_output_shape == 1: + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if isinstance(outputs, list): + outputs = np.array(outputs) + + if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)): + message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + else: + if model_output_shape != semantic_output_shape: + message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message + except Exception as e: message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." logger.warning(message) diff --git a/tests/test_learnware_client/test_all_learnware.py b/tests/test_learnware_client/test_all_learnware.py index 37b9046..2303089 100644 --- a/tests/test_learnware_client/test_all_learnware.py +++ b/tests/test_learnware_client/test_all_learnware.py @@ -1,5 +1,6 @@ import os import json +import zipfile import unittest import tempfile @@ -48,8 +49,11 @@ class TestAllLearnware(unittest.TestCase): for idx in learnware_ids: zip_path = os.path.join(tempdir, f"test_{idx}.zip") self.client.download_learnware(idx, zip_path) + with zipfile.ZipFile(zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) try: - LearnwareClient.check_learnware(zip_path) + LearnwareClient.check_learnware(zip_path, semantic_spec) print(f"check learnware {idx} succeed") except: failed_ids.append(idx) diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index 218e222..59f0820 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -1,4 +1,6 @@ import os +import json +import zipfile import unittest import tempfile @@ -12,39 +14,59 @@ class TestCheckLearnware(unittest.TestCase): self.client = LearnwareClient() def test_check_learnware_pip(self): - learnware_id = "00000154" + learnware_id = "00000208" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_conda(self): learnware_id = "00000148" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_dependency(self): learnware_id = "00000147" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_image(self): learnware_id = "00000677" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) def test_check_learnware_text(self): learnware_id = "00000662" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path) - LearnwareClient.check_learnware(self.zip_path) + + with zipfile.ZipFile(self.zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + LearnwareClient.check_learnware(self.zip_path, semantic_spec) if __name__ == "__main__":