diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 2876863..2b8a84b 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -381,14 +381,14 @@ class LearnwareClient: @staticmethod def _check_semantic_specification(semantic_spec): - return EasySemanticChecker.check_semantic_spec(semantic_spec) != BaseChecker.INVALID_LEARNWARE + return EasySemanticChecker.check_semantic_spec(semantic_spec)[0] != BaseChecker.INVALID_LEARNWARE @staticmethod def _check_stat_specification(learnware): from ..market import CondaChecker stat_checker = CondaChecker(inner_checker=EasyStatChecker()) - return stat_checker(learnware) != BaseChecker.INVALID_LEARNWARE + return stat_checker(learnware)[0] != BaseChecker.INVALID_LEARNWARE @staticmethod def check_learnware(learnware_zip_path, semantic_specification=None): diff --git a/learnware/client/package_utils.py b/learnware/client/package_utils.py index 077fcac..0492f72 100644 --- a/learnware/client/package_utils.py +++ b/learnware/client/package_utils.py @@ -74,12 +74,14 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: nonexist_packages = [] for package in packages: try: - # os.system("python3 -m pip index versions {0}".format(package)) - try_to_run(args=["pip", "index", "versions", parse_pip_requirement(package)], timeout=5) - exist_packages.append(package) + package_name = parse_pip_requirement(package) + if package_name != "learnware": + try_to_run(args=["pip", "index", "versions", package_name], timeout=5) + exist_packages.append(package) + continue except Exception as e: logger.error(e) - nonexist_packages.append(package) + nonexist_packages.append(package) return exist_packages, nonexist_packages @@ -105,7 +107,13 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str] command = f"conda env create --name env_test --file {test_yaml_file} --dry-run --json" result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - output = json.loads(result.stdout.strip()).get("bad_deps", []) + stdout = result.stdout.strip() + last_bracket = stdout.rfind("\n{") + if last_bracket != -1: + stdout = stdout[last_bracket:] + pass + print(stdout) + output = json.loads(stdout).get("bad_deps", []) if len(output) > 0: exist_packages = [] diff --git a/learnware/client/utils.py b/learnware/client/utils.py index 6a15d65..0e11ae0 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -25,8 +25,9 @@ def system_execute(args, timeout=None, env=None, stdout=subprocess.DEVNULL, stde try: com_process.check_returncode() except subprocess.CalledProcessError as err: - logger.warning(f"System Execute Error: {com_process.stderr.decode()}") - raise err + errmsg = com_process.stderr.decode() + logger.warning(f"System Execute Error: {errmsg}") + raise Exception(errmsg) def remove_enviroment(conda_env): diff --git a/learnware/market/base.py b/learnware/market/base.py index d061d7f..12837fa 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -84,7 +84,7 @@ class LearnwareMarket: ) for name in checker_names: checker = self.learnware_checker[name] - check_status = checker(pending_learnware) + check_status, message = checker(pending_learnware) final_status = max(final_status, check_status) if check_status == BaseChecker.INVALID_LEARNWARE: @@ -447,7 +447,7 @@ class BaseChecker: def reset(self, organizer): self.learnware_organizer = organizer - def __call__(self, learnware: Learnware) -> int: + def __call__(self, learnware: Learnware) -> Tuple[int, str]: """Check the utility of a learnware Parameters @@ -456,11 +456,15 @@ class BaseChecker: Returns ------- - int - A flag indicating whether the learnware can be accepted. - - The INVALID_LEARNWARE denotes the learnware does not pass the check - - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency - - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction + Tuple[int, str]: + flag and message of learnware check result + - int + A flag indicating whether the learnware can be accepted. + - The INVALID_LEARNWARE denotes the learnware does not pass the check + - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency + - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction + - str + A message indicating the reason of learnware check result """ raise NotImplementedError("'__call__' method is not implemented in BaseChecker") diff --git a/learnware/market/classes.py b/learnware/market/classes.py index 9c99555..0c15309 100644 --- a/learnware/market/classes.py +++ b/learnware/market/classes.py @@ -16,9 +16,11 @@ class CondaChecker(BaseChecker): try: with LearnwaresContainer(learnware, ignore_error=False) as env_container: learnwares = env_container.get_learnwares_with_container() - check_status = self.inner_checker(learnwares[0]) + check_status, message = self.inner_checker(learnwares[0]) except Exception as e: traceback.print_exc() - logger.warning(f"Conda Checker failed due to installed learnware failed and {e}") - return BaseChecker.INVALID_LEARNWARE - return check_status + message = f"Conda Checker failed due to installed learnware failed and {e}" + logger.warning(message) + message += "\n" + traceback.format_exc() + return BaseChecker.INVALID_LEARNWARE, message + return check_status, message diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 6419b98..54aa0e5 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -3,6 +3,7 @@ import numpy as np import torch import random import string +import traceback from ..base import BaseChecker from ..utils import parse_specification_type @@ -50,11 +51,11 @@ class EasySemanticChecker(BaseChecker): assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" assert isinstance(v, str), "Description must be string" - return EasySemanticChecker.NONUSABLE_LEARNWARE + return EasySemanticChecker.NONUSABLE_LEARNWARE, 'EasySemanticChecker Success' except AssertionError as err: logger.warning(f"semantic_specification is not valid due to {err}!") - return EasySemanticChecker.INVALID_LEARNWARE + return EasySemanticChecker.INVALID_LEARNWARE, traceback.format_exc() def __call__(self, learnware): semantic_spec = learnware.get_specification().get_semantic_spec() @@ -88,7 +89,7 @@ class EasyStatChecker(BaseChecker): except Exception as e: traceback.print_exc() logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") - return self.INVALID_LEARNWARE + return self.INVALID_LEARNWARE, traceback.format_exc() try: learnware_model = learnware.get_model() # Check input shape @@ -97,19 +98,22 @@ class EasyStatChecker(BaseChecker): if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( int(semantic_spec["Input"]["Dimension"]), ): - logger.warning("input shapes of model and semantic specifications are different") - return self.INVALID_LEARNWARE + message = "input shapes of model and semantic specifications are different" + logger.warning(message) + return self.INVALID_LEARNWARE, message spec_type = parse_specification_type(learnware.get_specification().stat_spec) if spec_type is None: - logger.warning(f"No valid specification is found in stat spec {spec_type}") - return self.INVALID_LEARNWARE + message = f"No valid specification is found in stat spec {spec_type}" + logger.warning(message) + return self.INVALID_LEARNWARE, message if spec_type == "RKMETableSpecification": stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) if stat_spec.get_z().shape[1:] != input_shape: - logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") - return self.INVALID_LEARNWARE + message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification." + logger.warning(message) + return self.INVALID_LEARNWARE, message inputs = np.random.randn(10, *input_shape) elif spec_type == "RKMETextSpecification": inputs = EasyStatChecker._generate_random_text_list(10) @@ -122,16 +126,19 @@ class EasyStatChecker(BaseChecker): try: outputs = learnware.predict(inputs) except Exception: - logger.warning(f"learnware {learnware} prediction method is not valid!") - return self.INVALID_LEARNWARE + message = f"The learnware {learnware.id} prediction is not avaliable!" + logger.warning(message) + message += '\r\n' + traceback.format_exc() + return self.INVALID_LEARNWARE, message if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): # Check output type if isinstance(outputs, torch.Tensor): outputs = outputs.detach().cpu().numpy() if not isinstance(outputs, np.ndarray): - logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!") - return self.INVALID_LEARNWARE + message = f"The learnware {learnware.id} output must be np.ndarray or torch.Tensor!" + logger.warning(message) + return self.INVALID_LEARNWARE, message if outputs.ndim == 1: outputs = outputs.reshape(-1, 1) @@ -139,13 +146,13 @@ class EasyStatChecker(BaseChecker): if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != ( int(semantic_spec["Output"]["Dimension"]), ): - logger.warning( - f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" - ) - return self.INVALID_LEARNWARE + message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" + logger.warning(message) + return self.INVALID_LEARNWARE, message except Exception as e: - logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.") - return self.INVALID_LEARNWARE + message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." + logger.warning(message) + return self.INVALID_LEARNWARE, message - return self.USABLE_LEARWARE + return self.USABLE_LEARWARE, "EasyStatChecker Success" diff --git a/learnware/market/easy/database_ops.py b/learnware/market/easy/database_ops.py index c9fb3de..a0b163c 100644 --- a/learnware/market/easy/database_ops.py +++ b/learnware/market/easy/database_ops.py @@ -167,6 +167,29 @@ class DatabaseOperations(object): pass pass + def get_learnware_info(self, id: str): + with self.engine.connect() as conn: + r = conn.execute( + text("SELECT semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware WHERE id=:id;"), + dict(id=id), + ) + row = r.fetchone() + if row is None: + return None + else: + semantic_spec = json.loads(row[0]) + zip_path = row[1] + folder_path = row[2] + use_flag = int(row[3]) + return { + "semantic_spec": semantic_spec, + "zip_path": zip_path, + "folder_path": folder_path, + "use_flag": use_flag, + } + pass + pass + def load_market(self): with self.engine.connect() as conn: cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) diff --git a/learnware/market/easy/organizer.py b/learnware/market/easy/organizer.py index 9337841..27b974d 100644 --- a/learnware/market/easy/organizer.py +++ b/learnware/market/easy/organizer.py @@ -3,7 +3,7 @@ import copy import zipfile import tempfile from shutil import copyfile, rmtree -from typing import Tuple, List, Union +from typing import Tuple, List, Union, Dict from .database_ops import DatabaseOperations from ..base import BaseOrganizer, BaseChecker @@ -392,5 +392,23 @@ class EasyOrganizer(BaseOrganizer): self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) pass + def get_learnware_info_from_storage(self, learnware_id: str) -> Dict: + """return learnware zip path and semantic_specification from storage + + Parameters + ---------- + learnware_id : str + learnware id + + Returns + ------- + Dict + - semantic_spec: semantic_specification + - zip_path: zip_path + - folder_path: folder_path + - use_flag: use_flag + """ + return self.dbops.get_learnware_info(learnware_id) + def __len__(self): return len(self.learnware_list) diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index 4421f91..5bf57fd 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -18,7 +18,7 @@ from tqdm import tqdm from . import cnn_gp from ..base import RegularStatsSpecification -from ..table.rkme import solve_qp, choose_device, setup_seed +from ..table.rkme import rkme_solve_qp, choose_device, setup_seed class RKMEImageSpecification(RegularStatsSpecification): @@ -97,6 +97,9 @@ class RKMEImageSpecification(RegularStatsSpecification): ------- """ + if len(X.shape) != 4: + raise ValueError("X should be in shape of [N, C, H, W]. ") + if ( X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH ) and not resize: @@ -175,7 +178,7 @@ class RKMEImageSpecification(RegularStatsSpecification): C = torch.sum(C, dim=1) / x_features.shape[0] if nonnegative_beta: - beta = solve_qp(K.double(), C.double()).to(self.device) + beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device) else: beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index 0e6fca6..a5c5297 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -29,6 +29,13 @@ class TestCheckLearnware(unittest.TestCase): self.client.download_learnware(learnware_id, self.zip_path) LearnwareClient.check_learnware(self.zip_path) + def test_check_learnware_dependency(self): + learnware_id = "00000147" + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + self.zip_path = os.path.join(tempdir, "test.zip") + self.client.download_learnware(learnware_id, self.zip_path) + LearnwareClient.check_learnware(self.zip_path) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_learnware_client/test_load_conda.py b/tests/test_learnware_client/test_load_conda.py index cd77f12..d343201 100644 --- a/tests/test_learnware_client/test_load_conda.py +++ b/tests/test_learnware_client/test_load_conda.py @@ -70,6 +70,18 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) + def test_load_single_learnware_by_id_pip(self): + learnware_id = "00000147" + learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env") + input_array = np.random.random(size=(20, 23)) + print(learnware.predict(input_array)) + + def test_load_single_learnware_by_id_conda(self): + learnware_id = "00000148" + learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env") + input_array = np.random.random(size=(20, 204)) + print(learnware.predict(input_array)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_learnware_client/test_load_docker.py b/tests/test_learnware_client/test_load_docker.py index 8b2cf6f..ac16afd 100644 --- a/tests/test_learnware_client/test_load_docker.py +++ b/tests/test_learnware_client/test_load_docker.py @@ -48,6 +48,18 @@ class TestLearnwareLoad(unittest.TestCase): learnware_list[0].get_model()._destroy_docker_container(docker_container) + def test_load_single_learnware_by_id_pip(self): + learnware_id = "00000147" + learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") + input_array = np.random.random(size=(20, 23)) + print(learnware.predict(input_array)) + + def test_load_single_learnware_by_id_conda(self): + learnware_id = "00000148" + learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") + input_array = np.random.random(size=(20, 204)) + print(learnware.predict(input_array)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_learnware_client/test_reuse.py b/tests/test_learnware_client/test_reuse.py index 1e5d2c3..b6b4485 100644 --- a/tests/test_learnware_client/test_reuse.py +++ b/tests/test_learnware_client/test_reuse.py @@ -4,7 +4,7 @@ import numpy as np from learnware.learnware import get_learnware_from_dirpath from learnware.client.container import LearnwaresContainer from learnware.reuse import AveragingReuser -from learnware.test.module import get_semantic_specification +from learnware.tests.module import get_semantic_specification if __name__ == "__main__": semantic_specification = get_semantic_specification()