Browse Source

Merge pull request #18 from Learnware-LAMDA/test_learnware

[ENH] add check_learnware in client
tags/v0.3.2
Gene GitHub 2 years ago
parent
commit
b07eb0a4bb
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 132 deletions
  1. +14
    -98
      learnware/client/learnware_client.py
  2. +18
    -11
      learnware/market/easy.py
  3. +27
    -0
      tests/test_learnware_client/test_check_learnware.py
  4. +0
    -23
      tests/test_learnware_client/test_learnware.py

+ 14
- 98
learnware/client/learnware_client.py View File

@@ -19,7 +19,8 @@ from .container import LearnwaresContainer
from ..market.easy import EasyMarket
from ..logger import get_module_logger
from ..specification import Specification
from ..learnware import BaseReuser, Learnware
from ..learnware import BaseReuser, Learnware, get_learnware_from_dirpath
from ..test import get_semantic_specification

CHUNK_SIZE = 1024 * 1024
logger = get_module_logger(module_name="LearnwareClient")
@@ -264,29 +265,6 @@ class LearnwareClient:
raise Exception("delete failed: " + json.dumps(result))
pass

def check_learnware(self, path, semantic_specification):
if os.path.isfile(path):
with tempfile.TemporaryDirectory() as tempdir:
with zipfile.ZipFile(path, "r") as z_file:
z_file.extractall(tempdir)
pass
return self.check_learnware_folder(tempdir, semantic_specification)
pass
else:
return self.check_learnware_folder(path, semantic_specification)
pass
pass

def check_learnware_folder(self, folder, semantic_specification):
learnware_obj = learnware.get_learnware_from_dirpath("test_id", semantic_specification, folder)

check_result = EasyMarket.check_learnware(learnware_obj)
if check_result == EasyMarket.USABLE_LEARWARE:
return True
else:
return False
pass

def create_semantic_specification(
self, name, description, data_type, task_type, library_type, senarioes, input_description, output_description
):
@@ -409,90 +387,28 @@ class LearnwareClient:
else:
return learnware_list

def system(self, command):
retcd = os.system(command)
if retcd != 0:
raise RuntimeError(f"Command {command} failed with return code {retcd}")
pass

def install_environment(self, zip_path, conda_env=None):
"""Install environment of a learnware

Parameters
----------
zip_path : str
Path of the learnware zip file
conda_env : optional
If it is not None, a new conda environment will be created with the given name;
If it is None, use current environment.

Raises
------
Exception
Lack of the environment configuration file.
"""
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(zip_path, "r") as z_file:
logger.info(f"zip_file namelist: {z_file.namelist}")
if "environment.yaml" in z_file.namelist():
z_file.extract("environment.yaml", tempdir)
yaml_path = os.path.join(tempdir, "environment.yaml")
yaml_path_filter = os.path.join(tempdir, "environment_filter.yaml")
package_utils.filter_nonexist_conda_packages_file(yaml_path, yaml_path_filter)
# create environment
if conda_env is not None:
self.system(f"conda env update --name {conda_env} --file {yaml_path_filter}")
pass
else:
self.system(f"conda env update --file {yaml_path_filter}")
pass
pass
elif "requirements.txt" in z_file.namelist():
z_file.extract("requirements.txt", tempdir)
requirements_path = os.path.join(tempdir, "requirements.txt")
requirements_path_filter = os.path.join(tempdir, "requirements_filter.txt")
package_utils.filter_nonexist_pip_packages_file(requirements_path, requirements_path_filter)

if conda_env is not None:
self.system(f"conda create -y --name {conda_env} python=3.8")
self.system(
f"conda run --name {conda_env} --no-capture-output python3 -m pip install -r {requirements_path_filter}"
)
else:
self.system(f"python3 -m pip install -r {requirements_path_filter}")
pass
pass
else:
raise Exception("Environment.yaml or requirements.txt not found in the learnware zip file.")
pass
pass
pass

def test_learnware(self, zip_path, semantic_specification=None):
@staticmethod
def check_learnware(zip_path, semantic_specification=None):
if semantic_specification is None:
semantic_specification = dict()
pass
semantic_specification = get_semantic_specification()

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)
pass

learnware_obj = learnware.get_learnware_from_dirpath("test_id", semantic_specification, tempdir)
learnware = get_learnware_from_dirpath(
id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir
)

if learnware_obj is None:
if learnware is None:
raise Exception("The learnware is not valid.")

learnware_obj.instantiate_model()

if len(semantic_specification) > 0:
if EasyMarket.check_learnware(learnware_obj) != EasyMarket.USABLE_LEARWARE:
with LearnwaresContainer(learnware, zip_path) as env_container:
learnware = env_container.get_learnwares_with_container()[0]
if EasyMarket.check_learnware(learnware) == EasyMarket.USABLE_LEARWARE:
logger.info("The learnware passed the local test.")
else:
raise Exception("The learnware is not usable.")
pass
pass

logger.info("test ok")
pass

def cleanup(self):
for tempdir in self.tempdir_list:


+ 18
- 11
learnware/market/easy.py View File

@@ -699,6 +699,7 @@ class EasyMarket(BaseMarket):
List[Learnware]
The list of returned learnwares
"""

def _match_semantic_spec_tag(semantic_spec1, semantic_spec2) -> bool:
"""Judge if tags of two semantic specs are consistent

@@ -737,8 +738,8 @@ class EasyMarket(BaseMarket):
elif semantic_spec1[key]["Type"] == "Tag":
if not (set(v1) & set(v2)):
return False
return True
return True
matched_learnware_tag = []
final_result = []
user_semantic_spec = user_info.get_semantic_spec()
@@ -747,15 +748,21 @@ class EasyMarket(BaseMarket):
learnware_semantic_spec = learnware.get_specification().get_semantic_spec()
if _match_semantic_spec_tag(user_semantic_spec, learnware_semantic_spec):
matched_learnware_tag.append(learnware)
if len(matched_learnware_tag) > 0:
if "Name" in user_semantic_spec:
name_user = user_semantic_spec["Name"]["Values"].lower()
if len(name_user) > 0:
# Exact search
name_list = [learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower() for learnware in matched_learnware_tag]
des_list = [learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower() for learnware in matched_learnware_tag]
name_list = [
learnware.get_specification().get_semantic_spec()["Name"]["Values"].lower()
for learnware in matched_learnware_tag
]
des_list = [
learnware.get_specification().get_semantic_spec()["Description"]["Values"].lower()
for learnware in matched_learnware_tag
]

matched_learnware_exact = []
for i in range(len(name_list)):
if name_user in name_list[i] or name_user in des_list[i]:
@@ -771,9 +778,11 @@ class EasyMarket(BaseMarket):
if final_score >= min_score:
matched_learnware_fuzz.append(matched_learnware_tag[i])
fuzz_scores.append(final_score)
# Sort by score
sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[:max_num]
sort_idx = sorted(list(range(len(fuzz_scores))), key=lambda k: fuzz_scores[k], reverse=True)[
:max_num
]
final_result = [matched_learnware_fuzz[idx] for idx in sort_idx]
else:
final_result = matched_learnware_exact
@@ -782,9 +791,7 @@ class EasyMarket(BaseMarket):
else:
final_result = matched_learnware_tag

logger.info(
"semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))
)
logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list)))
return final_result

def search_learnware(


+ 27
- 0
tests/test_learnware_client/test_check_learnware.py View File

@@ -0,0 +1,27 @@
import os
import unittest
import tempfile


from learnware.client import LearnwareClient


class TestCheckLearnware(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUpClass()
email = "liujd@lamda.nju.edu.cn"
token = "f7e647146a314c6e8b4e2e1079c4bca4"

self.client = LearnwareClient()
self.client.login(email, token)
self.learnware_id = "00000154"

def test_check_learnware(self):
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(self.learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)


if __name__ == "__main__":
unittest.main()

+ 0
- 23
tests/test_learnware_client/test_learnware.py View File

@@ -1,23 +0,0 @@
import os
import zipfile
import tempfile
from learnware.learnware import get_learnware_from_dirpath
from learnware.test import get_semantic_specification
from learnware.client.container import LearnwaresContainer
from learnware.market import EasyMarket

if __name__ == "__main__":
semantic_specification = get_semantic_specification()

zip_path = "rf_tic.zip"
with tempfile.TemporaryDirectory(suffix="learnware") as tempdir:
learnware_dirpath = os.path.join(tempdir, "test")
with zipfile.ZipFile(zip_path, "r") as z_file:
z_file.extractall(learnware_dirpath)
learnware = get_learnware_from_dirpath(
id="test", semantic_spec=semantic_specification, learnware_dirpath=learnware_dirpath
)

with LearnwaresContainer(learnware, zip_path) as env_container:
learnware = env_container.get_learnwares_with_container()[0]
print(EasyMarket.check_learnware(learnware))

Loading…
Cancel
Save