| @@ -381,14 +381,14 @@ class LearnwareClient: | |||
| @staticmethod | |||
| def _check_semantic_specification(semantic_spec): | |||
| return EasySemanticChecker.check_semantic_spec(semantic_spec) != BaseChecker.INVALID_LEARNWARE | |||
| return EasySemanticChecker.check_semantic_spec(semantic_spec)[0] != BaseChecker.INVALID_LEARNWARE | |||
| @staticmethod | |||
| def _check_stat_specification(learnware): | |||
| from ..market import CondaChecker | |||
| stat_checker = CondaChecker(inner_checker=EasyStatChecker()) | |||
| return stat_checker(learnware) != BaseChecker.INVALID_LEARNWARE | |||
| return stat_checker(learnware)[0] != BaseChecker.INVALID_LEARNWARE | |||
| @staticmethod | |||
| def check_learnware(learnware_zip_path, semantic_specification=None): | |||
| @@ -74,12 +74,14 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| nonexist_packages = [] | |||
| for package in packages: | |||
| try: | |||
| # os.system("python3 -m pip index versions {0}".format(package)) | |||
| try_to_run(args=["pip", "index", "versions", parse_pip_requirement(package)], timeout=5) | |||
| exist_packages.append(package) | |||
| package_name = parse_pip_requirement(package) | |||
| if package_name != "learnware": | |||
| try_to_run(args=["pip", "index", "versions", package_name], timeout=5) | |||
| exist_packages.append(package) | |||
| continue | |||
| except Exception as e: | |||
| logger.error(e) | |||
| nonexist_packages.append(package) | |||
| nonexist_packages.append(package) | |||
| return exist_packages, nonexist_packages | |||
| @@ -105,7 +107,13 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str] | |||
| command = f"conda env create --name env_test --file {test_yaml_file} --dry-run --json" | |||
| result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | |||
| output = json.loads(result.stdout.strip()).get("bad_deps", []) | |||
| stdout = result.stdout.strip() | |||
| last_bracket = stdout.rfind("\n{") | |||
| if last_bracket != -1: | |||
| stdout = stdout[last_bracket:] | |||
| pass | |||
| print(stdout) | |||
| output = json.loads(stdout).get("bad_deps", []) | |||
| if len(output) > 0: | |||
| exist_packages = [] | |||
| @@ -25,8 +25,9 @@ def system_execute(args, timeout=None, env=None, stdout=subprocess.DEVNULL, stde | |||
| try: | |||
| com_process.check_returncode() | |||
| except subprocess.CalledProcessError as err: | |||
| logger.warning(f"System Execute Error: {com_process.stderr.decode()}") | |||
| raise err | |||
| errmsg = com_process.stderr.decode() | |||
| logger.warning(f"System Execute Error: {errmsg}") | |||
| raise Exception(errmsg) | |||
| def remove_enviroment(conda_env): | |||
| @@ -84,7 +84,7 @@ class LearnwareMarket: | |||
| ) | |||
| for name in checker_names: | |||
| checker = self.learnware_checker[name] | |||
| check_status = checker(pending_learnware) | |||
| check_status, message = checker(pending_learnware) | |||
| final_status = max(final_status, check_status) | |||
| if check_status == BaseChecker.INVALID_LEARNWARE: | |||
| @@ -447,7 +447,7 @@ class BaseChecker: | |||
| def reset(self, organizer): | |||
| self.learnware_organizer = organizer | |||
| def __call__(self, learnware: Learnware) -> int: | |||
| def __call__(self, learnware: Learnware) -> Tuple[int, str]: | |||
| """Check the utility of a learnware | |||
| Parameters | |||
| @@ -456,11 +456,15 @@ class BaseChecker: | |||
| Returns | |||
| ------- | |||
| int | |||
| A flag indicating whether the learnware can be accepted. | |||
| - The INVALID_LEARNWARE denotes the learnware does not pass the check | |||
| - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency | |||
| - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction | |||
| Tuple[int, str]: | |||
| flag and message of learnware check result | |||
| - int | |||
| A flag indicating whether the learnware can be accepted. | |||
| - The INVALID_LEARNWARE denotes the learnware does not pass the check | |||
| - The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency | |||
| - The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction | |||
| - str | |||
| A message indicating the reason of learnware check result | |||
| """ | |||
| raise NotImplementedError("'__call__' method is not implemented in BaseChecker") | |||
| @@ -16,9 +16,11 @@ class CondaChecker(BaseChecker): | |||
| try: | |||
| with LearnwaresContainer(learnware, ignore_error=False) as env_container: | |||
| learnwares = env_container.get_learnwares_with_container() | |||
| check_status = self.inner_checker(learnwares[0]) | |||
| check_status, message = self.inner_checker(learnwares[0]) | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| logger.warning(f"Conda Checker failed due to installed learnware failed and {e}") | |||
| return BaseChecker.INVALID_LEARNWARE | |||
| return check_status | |||
| message = f"Conda Checker failed due to installed learnware failed and {e}" | |||
| logger.warning(message) | |||
| message += "\n" + traceback.format_exc() | |||
| return BaseChecker.INVALID_LEARNWARE, message | |||
| return check_status, message | |||
| @@ -3,6 +3,7 @@ import numpy as np | |||
| import torch | |||
| import random | |||
| import string | |||
| import traceback | |||
| from ..base import BaseChecker | |||
| from ..utils import parse_specification_type | |||
| @@ -50,11 +51,11 @@ class EasySemanticChecker(BaseChecker): | |||
| assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" | |||
| assert isinstance(v, str), "Description must be string" | |||
| return EasySemanticChecker.NONUSABLE_LEARNWARE | |||
| return EasySemanticChecker.NONUSABLE_LEARNWARE, 'EasySemanticChecker Success' | |||
| except AssertionError as err: | |||
| logger.warning(f"semantic_specification is not valid due to {err}!") | |||
| return EasySemanticChecker.INVALID_LEARNWARE | |||
| return EasySemanticChecker.INVALID_LEARNWARE, traceback.format_exc() | |||
| def __call__(self, learnware): | |||
| semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| @@ -88,7 +89,7 @@ class EasyStatChecker(BaseChecker): | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | |||
| return self.INVALID_LEARNWARE | |||
| return self.INVALID_LEARNWARE, traceback.format_exc() | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # Check input shape | |||
| @@ -97,19 +98,22 @@ class EasyStatChecker(BaseChecker): | |||
| if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | |||
| int(semantic_spec["Input"]["Dimension"]), | |||
| ): | |||
| logger.warning("input shapes of model and semantic specifications are different") | |||
| return self.INVALID_LEARNWARE | |||
| message = "input shapes of model and semantic specifications are different" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| spec_type = parse_specification_type(learnware.get_specification().stat_spec) | |||
| if spec_type is None: | |||
| logger.warning(f"No valid specification is found in stat spec {spec_type}") | |||
| return self.INVALID_LEARNWARE | |||
| message = f"No valid specification is found in stat spec {spec_type}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| if spec_type == "RKMETableSpecification": | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| if stat_spec.get_z().shape[1:] != input_shape: | |||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") | |||
| return self.INVALID_LEARNWARE | |||
| message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification." | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| inputs = np.random.randn(10, *input_shape) | |||
| elif spec_type == "RKMETextSpecification": | |||
| inputs = EasyStatChecker._generate_random_text_list(10) | |||
| @@ -122,16 +126,19 @@ class EasyStatChecker(BaseChecker): | |||
| try: | |||
| outputs = learnware.predict(inputs) | |||
| except Exception: | |||
| logger.warning(f"learnware {learnware} prediction method is not valid!") | |||
| return self.INVALID_LEARNWARE | |||
| message = f"The learnware {learnware.id} prediction is not avaliable!" | |||
| logger.warning(message) | |||
| message += '\r\n' + traceback.format_exc() | |||
| return self.INVALID_LEARNWARE, message | |||
| if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"): | |||
| # Check output type | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = outputs.detach().cpu().numpy() | |||
| if not isinstance(outputs, np.ndarray): | |||
| logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!") | |||
| return self.INVALID_LEARNWARE | |||
| message = f"The learnware {learnware.id} output must be np.ndarray or torch.Tensor!" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| if outputs.ndim == 1: | |||
| outputs = outputs.reshape(-1, 1) | |||
| @@ -139,13 +146,13 @@ class EasyStatChecker(BaseChecker): | |||
| if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != ( | |||
| int(semantic_spec["Output"]["Dimension"]), | |||
| ): | |||
| logger.warning( | |||
| f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||
| ) | |||
| return self.INVALID_LEARNWARE | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| except Exception as e: | |||
| logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.") | |||
| return self.INVALID_LEARNWARE | |||
| message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| return self.USABLE_LEARWARE | |||
| return self.USABLE_LEARWARE, "EasyStatChecker Success" | |||
| @@ -167,6 +167,29 @@ class DatabaseOperations(object): | |||
| pass | |||
| pass | |||
| def get_learnware_info(self, id: str): | |||
| with self.engine.connect() as conn: | |||
| r = conn.execute( | |||
| text("SELECT semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware WHERE id=:id;"), | |||
| dict(id=id), | |||
| ) | |||
| row = r.fetchone() | |||
| if row is None: | |||
| return None | |||
| else: | |||
| semantic_spec = json.loads(row[0]) | |||
| zip_path = row[1] | |||
| folder_path = row[2] | |||
| use_flag = int(row[3]) | |||
| return { | |||
| "semantic_spec": semantic_spec, | |||
| "zip_path": zip_path, | |||
| "folder_path": folder_path, | |||
| "use_flag": use_flag, | |||
| } | |||
| pass | |||
| pass | |||
| def load_market(self): | |||
| with self.engine.connect() as conn: | |||
| cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) | |||
| @@ -3,7 +3,7 @@ import copy | |||
| import zipfile | |||
| import tempfile | |||
| from shutil import copyfile, rmtree | |||
| from typing import Tuple, List, Union | |||
| from typing import Tuple, List, Union, Dict | |||
| from .database_ops import DatabaseOperations | |||
| from ..base import BaseOrganizer, BaseChecker | |||
| @@ -392,5 +392,23 @@ class EasyOrganizer(BaseOrganizer): | |||
| self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) | |||
| pass | |||
| def get_learnware_info_from_storage(self, learnware_id: str) -> Dict: | |||
| """return learnware zip path and semantic_specification from storage | |||
| Parameters | |||
| ---------- | |||
| learnware_id : str | |||
| learnware id | |||
| Returns | |||
| ------- | |||
| Dict | |||
| - semantic_spec: semantic_specification | |||
| - zip_path: zip_path | |||
| - folder_path: folder_path | |||
| - use_flag: use_flag | |||
| """ | |||
| return self.dbops.get_learnware_info(learnware_id) | |||
| def __len__(self): | |||
| return len(self.learnware_list) | |||
| @@ -18,7 +18,7 @@ from tqdm import tqdm | |||
| from . import cnn_gp | |||
| from ..base import RegularStatsSpecification | |||
| from ..table.rkme import solve_qp, choose_device, setup_seed | |||
| from ..table.rkme import rkme_solve_qp, choose_device, setup_seed | |||
| class RKMEImageSpecification(RegularStatsSpecification): | |||
| @@ -97,6 +97,9 @@ class RKMEImageSpecification(RegularStatsSpecification): | |||
| ------- | |||
| """ | |||
| if len(X.shape) != 4: | |||
| raise ValueError("X should be in shape of [N, C, H, W]. ") | |||
| if ( | |||
| X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH | |||
| ) and not resize: | |||
| @@ -175,7 +178,7 @@ class RKMEImageSpecification(RegularStatsSpecification): | |||
| C = torch.sum(C, dim=1) / x_features.shape[0] | |||
| if nonnegative_beta: | |||
| beta = solve_qp(K.double(), C.double()).to(self.device) | |||
| beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device) | |||
| else: | |||
| beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C | |||
| @@ -29,6 +29,13 @@ class TestCheckLearnware(unittest.TestCase): | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| def test_check_learnware_dependency(self): | |||
| learnware_id = "00000147" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -70,6 +70,18 @@ class TestLearnwareLoad(unittest.TestCase): | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_pip(self): | |||
| learnware_id = "00000147" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env") | |||
| input_array = np.random.random(size=(20, 23)) | |||
| print(learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_conda(self): | |||
| learnware_id = "00000148" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env") | |||
| input_array = np.random.random(size=(20, 204)) | |||
| print(learnware.predict(input_array)) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -48,6 +48,18 @@ class TestLearnwareLoad(unittest.TestCase): | |||
| learnware_list[0].get_model()._destroy_docker_container(docker_container) | |||
| def test_load_single_learnware_by_id_pip(self): | |||
| learnware_id = "00000147" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") | |||
| input_array = np.random.random(size=(20, 23)) | |||
| print(learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_conda(self): | |||
| learnware_id = "00000148" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") | |||
| input_array = np.random.random(size=(20, 204)) | |||
| print(learnware.predict(input_array)) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -4,7 +4,7 @@ import numpy as np | |||
| from learnware.learnware import get_learnware_from_dirpath | |||
| from learnware.client.container import LearnwaresContainer | |||
| from learnware.reuse import AveragingReuser | |||
| from learnware.test.module import get_semantic_specification | |||
| from learnware.tests.module import get_semantic_specification | |||
| if __name__ == "__main__": | |||
| semantic_specification = get_semantic_specification() | |||