Browse Source

Merge branch 'main' of https://github.com/Learnware-LAMDA/Learnware into xiey/dev

tags/v0.3.2
nju-xy 2 years ago
parent
commit
009ef6cd69
13 changed files with 141 additions and 44 deletions
  1. +2
    -2
      learnware/client/learnware_client.py
  2. +13
    -5
      learnware/client/package_utils.py
  3. +3
    -2
      learnware/client/utils.py
  4. +11
    -7
      learnware/market/base.py
  5. +6
    -4
      learnware/market/classes.py
  6. +27
    -20
      learnware/market/easy/checker.py
  7. +23
    -0
      learnware/market/easy/database_ops.py
  8. +19
    -1
      learnware/market/easy/organizer.py
  9. +5
    -2
      learnware/specification/regular/image/rkme.py
  10. +7
    -0
      tests/test_learnware_client/test_check_learnware.py
  11. +12
    -0
      tests/test_learnware_client/test_load_conda.py
  12. +12
    -0
      tests/test_learnware_client/test_load_docker.py
  13. +1
    -1
      tests/test_learnware_client/test_reuse.py

+ 2
- 2
learnware/client/learnware_client.py View File

@@ -381,14 +381,14 @@ class LearnwareClient:

@staticmethod
def _check_semantic_specification(semantic_spec):
return EasySemanticChecker.check_semantic_spec(semantic_spec) != BaseChecker.INVALID_LEARNWARE
return EasySemanticChecker.check_semantic_spec(semantic_spec)[0] != BaseChecker.INVALID_LEARNWARE

@staticmethod
def _check_stat_specification(learnware):
from ..market import CondaChecker

stat_checker = CondaChecker(inner_checker=EasyStatChecker())
return stat_checker(learnware) != BaseChecker.INVALID_LEARNWARE
return stat_checker(learnware)[0] != BaseChecker.INVALID_LEARNWARE

@staticmethod
def check_learnware(learnware_zip_path, semantic_specification=None):


+ 13
- 5
learnware/client/package_utils.py View File

@@ -74,12 +74,14 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
nonexist_packages = []
for package in packages:
try:
# os.system("python3 -m pip index versions {0}".format(package))
try_to_run(args=["pip", "index", "versions", parse_pip_requirement(package)], timeout=5)
exist_packages.append(package)
package_name = parse_pip_requirement(package)
if package_name != "learnware":
try_to_run(args=["pip", "index", "versions", package_name], timeout=5)
exist_packages.append(package)
continue
except Exception as e:
logger.error(e)
nonexist_packages.append(package)
nonexist_packages.append(package)

return exist_packages, nonexist_packages

@@ -105,7 +107,13 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str]

command = f"conda env create --name env_test --file {test_yaml_file} --dry-run --json"
result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
output = json.loads(result.stdout.strip()).get("bad_deps", [])
stdout = result.stdout.strip()
last_bracket = stdout.rfind("\n{")
if last_bracket != -1:
stdout = stdout[last_bracket:]
pass
print(stdout)
output = json.loads(stdout).get("bad_deps", [])

if len(output) > 0:
exist_packages = []


+ 3
- 2
learnware/client/utils.py View File

@@ -25,8 +25,9 @@ def system_execute(args, timeout=None, env=None, stdout=subprocess.DEVNULL, stde
try:
com_process.check_returncode()
except subprocess.CalledProcessError as err:
logger.warning(f"System Execute Error: {com_process.stderr.decode()}")
raise err
errmsg = com_process.stderr.decode()
logger.warning(f"System Execute Error: {errmsg}")
raise Exception(errmsg)


def remove_enviroment(conda_env):


+ 11
- 7
learnware/market/base.py View File

@@ -84,7 +84,7 @@ class LearnwareMarket:
)
for name in checker_names:
checker = self.learnware_checker[name]
check_status = checker(pending_learnware)
check_status, message = checker(pending_learnware)
final_status = max(final_status, check_status)

if check_status == BaseChecker.INVALID_LEARNWARE:
@@ -447,7 +447,7 @@ class BaseChecker:
def reset(self, organizer):
self.learnware_organizer = organizer

def __call__(self, learnware: Learnware) -> int:
def __call__(self, learnware: Learnware) -> Tuple[int, str]:
"""Check the utility of a learnware

Parameters
@@ -456,11 +456,15 @@ class BaseChecker:

Returns
-------
int
A flag indicating whether the learnware can be accepted.
- The INVALID_LEARNWARE denotes the learnware does not pass the check
- The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction
Tuple[int, str]:
flag and message of learnware check result
- int
A flag indicating whether the learnware can be accepted.
- The INVALID_LEARNWARE denotes the learnware does not pass the check
- The NOPREDICTION_LEARNWARE denotes the learnware pass the check but cannot make prediction due to some env dependency
- The NOPREDICTION_LEARNWARE denotes the leanrware pass the check and can make prediction
- str
A message indicating the reason of learnware check result
"""

raise NotImplementedError("'__call__' method is not implemented in BaseChecker")

+ 6
- 4
learnware/market/classes.py View File

@@ -16,9 +16,11 @@ class CondaChecker(BaseChecker):
try:
with LearnwaresContainer(learnware, ignore_error=False) as env_container:
learnwares = env_container.get_learnwares_with_container()
check_status = self.inner_checker(learnwares[0])
check_status, message = self.inner_checker(learnwares[0])
except Exception as e:
traceback.print_exc()
logger.warning(f"Conda Checker failed due to installed learnware failed and {e}")
return BaseChecker.INVALID_LEARNWARE
return check_status
message = f"Conda Checker failed due to installed learnware failed and {e}"
logger.warning(message)
message += "\n" + traceback.format_exc()
return BaseChecker.INVALID_LEARNWARE, message
return check_status, message

+ 27
- 20
learnware/market/easy/checker.py View File

@@ -3,6 +3,7 @@ import numpy as np
import torch
import random
import string
import traceback

from ..base import BaseChecker
from ..utils import parse_specification_type
@@ -50,11 +51,11 @@ class EasySemanticChecker(BaseChecker):
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})"
assert isinstance(v, str), "Description must be string"

return EasySemanticChecker.NONUSABLE_LEARNWARE
return EasySemanticChecker.NONUSABLE_LEARNWARE, 'EasySemanticChecker Success'

except AssertionError as err:
logger.warning(f"semantic_specification is not valid due to {err}!")
return EasySemanticChecker.INVALID_LEARNWARE
return EasySemanticChecker.INVALID_LEARNWARE, traceback.format_exc()

def __call__(self, learnware):
semantic_spec = learnware.get_specification().get_semantic_spec()
@@ -88,7 +89,7 @@ class EasyStatChecker(BaseChecker):
except Exception as e:
traceback.print_exc()
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
return self.INVALID_LEARNWARE
return self.INVALID_LEARNWARE, traceback.format_exc()
try:
learnware_model = learnware.get_model()
# Check input shape
@@ -97,19 +98,22 @@ class EasyStatChecker(BaseChecker):
if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (
int(semantic_spec["Input"]["Dimension"]),
):
logger.warning("input shapes of model and semantic specifications are different")
return self.INVALID_LEARNWARE
message = "input shapes of model and semantic specifications are different"
logger.warning(message)
return self.INVALID_LEARNWARE, message

spec_type = parse_specification_type(learnware.get_specification().stat_spec)
if spec_type is None:
logger.warning(f"No valid specification is found in stat spec {spec_type}")
return self.INVALID_LEARNWARE
message = f"No valid specification is found in stat spec {spec_type}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")
return self.INVALID_LEARNWARE
message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification."
logger.warning(message)
return self.INVALID_LEARNWARE, message
inputs = np.random.randn(10, *input_shape)
elif spec_type == "RKMETextSpecification":
inputs = EasyStatChecker._generate_random_text_list(10)
@@ -122,16 +126,19 @@ class EasyStatChecker(BaseChecker):
try:
outputs = learnware.predict(inputs)
except Exception:
logger.warning(f"learnware {learnware} prediction method is not valid!")
return self.INVALID_LEARNWARE
message = f"The learnware {learnware.id} prediction is not avaliable!"
logger.warning(message)
message += '\r\n' + traceback.format_exc()
return self.INVALID_LEARNWARE, message

if semantic_spec["Task"]["Values"][0] in ("Classification", "Regression"):
# Check output type
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if not isinstance(outputs, np.ndarray):
logger.warning(f"The learnware [{learnware.id}] output must be np.ndarray or torch.Tensor!")
return self.INVALID_LEARNWARE
message = f"The learnware {learnware.id} output must be np.ndarray or torch.Tensor!"
logger.warning(message)
return self.INVALID_LEARNWARE, message

if outputs.ndim == 1:
outputs = outputs.reshape(-1, 1)
@@ -139,13 +146,13 @@ class EasyStatChecker(BaseChecker):
if outputs[0].shape != learnware_model.output_shape or learnware_model.output_shape != (
int(semantic_spec["Output"]["Dimension"]),
):
logger.warning(
f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
)
return self.INVALID_LEARNWARE
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

except Exception as e:
logger.warning(f"The learnware [{learnware.id}] prediction is not avaliable! Due to {repr(e)}.")
return self.INVALID_LEARNWARE
message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}."
logger.warning(message)
return self.INVALID_LEARNWARE, message

return self.USABLE_LEARWARE
return self.USABLE_LEARWARE, "EasyStatChecker Success"

+ 23
- 0
learnware/market/easy/database_ops.py View File

@@ -167,6 +167,29 @@ class DatabaseOperations(object):
pass
pass

def get_learnware_info(self, id: str):
with self.engine.connect() as conn:
r = conn.execute(
text("SELECT semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware WHERE id=:id;"),
dict(id=id),
)
row = r.fetchone()
if row is None:
return None
else:
semantic_spec = json.loads(row[0])
zip_path = row[1]
folder_path = row[2]
use_flag = int(row[3])
return {
"semantic_spec": semantic_spec,
"zip_path": zip_path,
"folder_path": folder_path,
"use_flag": use_flag,
}
pass
pass

def load_market(self):
with self.engine.connect() as conn:
cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;"))


+ 19
- 1
learnware/market/easy/organizer.py View File

@@ -3,7 +3,7 @@ import copy
import zipfile
import tempfile
from shutil import copyfile, rmtree
from typing import Tuple, List, Union
from typing import Tuple, List, Union, Dict

from .database_ops import DatabaseOperations
from ..base import BaseOrganizer, BaseChecker
@@ -392,5 +392,23 @@ class EasyOrganizer(BaseOrganizer):
self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id)
pass

def get_learnware_info_from_storage(self, learnware_id: str) -> Dict:
"""return learnware zip path and semantic_specification from storage

Parameters
----------
learnware_id : str
learnware id

Returns
-------
Dict
- semantic_spec: semantic_specification
- zip_path: zip_path
- folder_path: folder_path
- use_flag: use_flag
"""
return self.dbops.get_learnware_info(learnware_id)

def __len__(self):
return len(self.learnware_list)

+ 5
- 2
learnware/specification/regular/image/rkme.py View File

@@ -18,7 +18,7 @@ from tqdm import tqdm

from . import cnn_gp
from ..base import RegularStatsSpecification
from ..table.rkme import solve_qp, choose_device, setup_seed
from ..table.rkme import rkme_solve_qp, choose_device, setup_seed


class RKMEImageSpecification(RegularStatsSpecification):
@@ -97,6 +97,9 @@ class RKMEImageSpecification(RegularStatsSpecification):
-------

"""
if len(X.shape) != 4:
raise ValueError("X should be in shape of [N, C, H, W]. ")

if (
X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH
) and not resize:
@@ -175,7 +178,7 @@ class RKMEImageSpecification(RegularStatsSpecification):
C = torch.sum(C, dim=1) / x_features.shape[0]

if nonnegative_beta:
beta = solve_qp(K.double(), C.double()).to(self.device)
beta = rkme_solve_qp(K.double(), C.double())[0].to(self.device)
else:
beta = torch.linalg.inv(K + torch.eye(K.shape[0]).to(self.device) * 1e-5) @ C



+ 7
- 0
tests/test_learnware_client/test_check_learnware.py View File

@@ -29,6 +29,13 @@ class TestCheckLearnware(unittest.TestCase):
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

def test_check_learnware_dependency(self):
learnware_id = "00000147"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)


if __name__ == "__main__":
unittest.main()

+ 12
- 0
tests/test_learnware_client/test_load_conda.py View File

@@ -70,6 +70,18 @@ class TestLearnwareLoad(unittest.TestCase):
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_load_single_learnware_by_id_pip(self):
learnware_id = "00000147"
learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env")
input_array = np.random.random(size=(20, 23))
print(learnware.predict(input_array))

def test_load_single_learnware_by_id_conda(self):
learnware_id = "00000148"
learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda_env")
input_array = np.random.random(size=(20, 204))
print(learnware.predict(input_array))


if __name__ == "__main__":
unittest.main()

+ 12
- 0
tests/test_learnware_client/test_load_docker.py View File

@@ -48,6 +48,18 @@ class TestLearnwareLoad(unittest.TestCase):

learnware_list[0].get_model()._destroy_docker_container(docker_container)

def test_load_single_learnware_by_id_pip(self):
learnware_id = "00000147"
learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker")
input_array = np.random.random(size=(20, 23))
print(learnware.predict(input_array))

def test_load_single_learnware_by_id_conda(self):
learnware_id = "00000148"
learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker")
input_array = np.random.random(size=(20, 204))
print(learnware.predict(input_array))


if __name__ == "__main__":
unittest.main()

+ 1
- 1
tests/test_learnware_client/test_reuse.py View File

@@ -4,7 +4,7 @@ import numpy as np
from learnware.learnware import get_learnware_from_dirpath
from learnware.client.container import LearnwaresContainer
from learnware.reuse import AveragingReuser
from learnware.test.module import get_semantic_specification
from learnware.tests.module import get_semantic_specification

if __name__ == "__main__":
semantic_specification = get_semantic_specification()


Loading…
Cancel
Save