diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 93a3fa3..5f69127 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,7 +85,9 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[split:,], + train_xs[ + split:, + ], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/client/container.py b/learnware/client/container.py index 979daf1..50fe0ed 100644 --- a/learnware/client/container.py +++ b/learnware/client/container.py @@ -1,5 +1,6 @@ import os import pickle +import atexit import tempfile import shortuuid from concurrent.futures import ProcessPoolExecutor @@ -127,6 +128,12 @@ class LearnwaresContainer: for _learnware, _zippath in zip(learnware_list, learnware_zippaths) ] + model_list = [_learnware.get_model() for _learnware in self.learnware_list] + with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: + executor.map(self._initialize_model_container, model_list) + + atexit.register(self.cleanup) + @staticmethod def _initialize_model_container(model: ModelEnvContainer): model.init_env_and_metadata() @@ -135,16 +142,9 @@ class LearnwaresContainer: def _destroy_model_container(model: ModelEnvContainer): model.remove_env() - def __enter__(self): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._initialize_model_container, model_list) - return self - - def __exit__(self, type, value, trace): - model_list = [_learnware.get_model() for _learnware in self.learnware_list] - with ProcessPoolExecutor(max_workers=max(os.cpu_count() // 2, 1)) as executor: - executor.map(self._destroy_model_container, model_list) - def get_learnware_list_with_container(self): return self.learnware_list + + def cleanup(self): + for _learnware in self.learnware_list: + self._destroy_model_container(_learnware.get_model()) diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 2982cdf..8f07408 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -1,14 +1,16 @@ import os -import numpy as np +import uuid import yaml import json +import atexit import zipfile import hashlib import requests import tempfile +import numpy as np from enum import Enum from tqdm import tqdm -from typing import List +from typing import Union, List from ..config import C from .. import learnware @@ -68,10 +70,10 @@ class LearnwareClient: self.host = C.backend_host else: self.host = host - pass self.chunk_size = 1024 * 1024 - pass + self.tempdir_list = [] + atexit.register(self.cleanup) def login(self, email, token): url = f"{self.host}/auth/login_by_token" @@ -308,44 +310,104 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] - def load_learnware(self, learnware_file: str, load_model: bool = True): - with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - with zipfile.ZipFile(learnware_file, "r") as z_file: + def load_learnware( + self, + learnware_path: Union[str, List[str]] = None, + learnware_id: Union[str, List[str]] = None, + runnable_option: str = None, + ): + """Load learnware by learnware zip file or learnware id (zip file has higher priority) + + Parameters + ---------- + learnware_path : Union[str, List[str]] + learnware zip path or learnware zip path list + learnware_id : Union[str, List[str]] + learnware id or learnware id list + runnable_option : str + the option for instantiating learnwares + - "normal": instantiate learnware without installing environment + - "conda_env": instantiate learnware with installing conda virtual environment + + Returns + ------- + Learnware + The contructed learnware object or object list + """ + if runnable_option is not None and runnable_option not in ["normal", "conda_env"]: + raise logger.warning(f"runnable_option must be one of ['normal', 'conda_env'], but got {runnable_option}") + + if learnware_path is None and learnware_id is None: + raise ValueError("Requires one of learnware_path or learnware_id") + + def _get_learnware_by_id(_learnware_id): + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name + zip_path = os.path.join(tempdir, f"{str(uuid.uuid4())}.zip") + self.download_learnware(_learnware_id, zip_path) + return zip_path, _get_learnware_by_path(zip_path, tempdir=tempdir) + + def _get_learnware_by_path(_learnware_zippath, tempdir=None): + if tempdir is None: + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name + + with zipfile.ZipFile(_learnware_zippath, "r") as z_file: z_file.extractall(tempdir) - pass yaml_file = C.learnware_folder_config["yaml_file"] - with open(os.path.join(tempdir, yaml_file), "r") as fin: learnware_info = yaml.safe_load(fin) - pass learnware_id = learnware_info.get("id") if learnware_id is None: learnware_id = "test_id" - pass semantic_specification = learnware_info.get("semantic_specification") if semantic_specification is None: semantic_specification = {} - pass else: semantic_file = semantic_specification.get("file_name") with open(os.path.join(tempdir, semantic_file), "r") as fin: semantic_specification = json.load(fin) - pass - pass - learnware_obj = learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - - if load_model: - learnware_obj.instantiate_model() - pass - - return learnware_obj - pass - pass + return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) + + learnware_list = [] + zip_paths = [] + if learnware_path is not None: + if isinstance(learnware_path, str): + zip_paths = [learnware_path] + elif isinstance(learnware_path, list): + zip_paths = learnware_path + + for zip_path in zip_paths: + learnware_obj = _get_learnware_by_path(zip_path) + learnware_list.append(learnware_obj) + elif learnware_id is not None: + if isinstance(learnware_id, str): + id_list = [learnware_id] + elif isinstance(learnware_id, list): + id_list = learnware_id + + for idx in id_list: + zip_path, learnware_obj = _get_learnware_by_id(idx) + zip_paths.append(zip_path) + learnware_list.append(learnware_obj) + + if runnable_option is not None: + if runnable_option == "normal": + for i in range(len(learnware_list)): + learnware_list[i].instantiate_model() + elif runnable_option == "conda_env": + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() + + if len(learnware_list) == 1: + return learnware_list[0] + else: + return learnware_list def system(self, command): retcd = os.system(command) @@ -392,9 +454,9 @@ class LearnwareClient: package_utils.filter_nonexist_pip_packages_file(requirements_path, requirements_path_filter) if conda_env is not None: - self.system(f"conda create --name {conda_env}") + self.system(f"conda create -y --name {conda_env} python=3.8") self.system( - f"conda run --no-capture-output python3 -m pip install -r {requirements_path_filter}" + f"conda run --name {conda_env} --no-capture-output python3 -m pip install -r {requirements_path_filter}" ) else: self.system(f"python3 -m pip install -r {requirements_path_filter}") @@ -432,17 +494,6 @@ class LearnwareClient: logger.info("test ok") pass - def reuse_learnware( - self, - input_array: np.ndarray, - learnware_list: List[Learnware], - learnware_zippaths: List[str], - reuser: BaseReuser, - ): - logger.info(f"reuse learnare list {learnware_list} with reuser {reuser}") - with LearnwaresContainer(learnware_list, learnware_zippaths) as env_container: - learnware_list = env_container.get_learnware_list_with_container() - reuser.reset(learnware_list=learnware_list) - result = reuser.predict(input_array) - - return result + def cleanup(self): + for tempdir in self.tempdir_list: + tempdir.cleanup() diff --git a/learnware/client/utils.py b/learnware/client/utils.py index b6d9c8a..a48fe45 100644 --- a/learnware/client/utils.py +++ b/learnware/client/utils.py @@ -61,12 +61,14 @@ def install_environment(zip_path, conda_env): requirements_file=requirements_path, output_file=requirements_path_filter ) logger.info(f"create empty conda env [{conda_env}]") - system_execute(args=["conda", "create", "--name", f"{conda_env}", "python=3.8"]) + system_execute(args=["conda", "create", "-y", "--name", f"{conda_env}", "python=3.8"]) logger.info(f"install pip requirements for conda env [{conda_env}]") system_execute( args=[ "conda", "run", + "-n", + f"{conda_env}", "--no-capture-output", "python3", "-m", @@ -80,4 +82,17 @@ def install_environment(zip_path, conda_env): raise Exception("Environment.yaml or requirements.txt not found in the learnware zip file.") logger.info(f"install learnware package for conda env [{conda_env}]") - system_execute(args=["conda", "run", "--no-capture-output", "python3", "-m", "pip", "install", "learnware"]) + system_execute( + args=[ + "conda", + "run", + "-n", + f"{conda_env}", + "--no-capture-output", + "python3", + "-m", + "pip", + "install", + "learnware", + ] + ) diff --git a/tests/test_learnware_upload/test_upload.py b/tests/test_client/test_learnware.py similarity index 100% rename from tests/test_learnware_upload/test_upload.py rename to tests/test_client/test_learnware.py diff --git a/tests/test_client/test_load.py b/tests/test_client/test_load.py new file mode 100644 index 0000000..67981dc --- /dev/null +++ b/tests/test_client/test_load.py @@ -0,0 +1,75 @@ +import os +import unittest +import zipfile +import numpy as np + +import learnware +from learnware.learnware import get_learnware_from_dirpath +from learnware.client import LearnwareClient +from learnware.client.container import ModelEnvContainer, LearnwaresContainer +from learnware.learnware.reuse import AveragingReuser + + +class TestLearnwareLoad(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUpClass() + email = "liujd@lamda.nju.edu.cn" + token = "f7e647146a314c6e8b4e2e1079c4bca4" + + self.client = LearnwareClient() + self.client.login(email, token) + + root = os.path.dirname(__file__) + self.learnware_ids = ["00000084", "00000154", "00000155"] + self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] + + def test_load_single_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = [ + self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") + for zippath in self.zip_paths + ] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_multi_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_single_learnware_by_id(self): + learnware_list = [ + self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids + ] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_multi_learnware_by_id(self): + learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_client/test_reuse.py b/tests/test_client/test_reuse.py new file mode 100644 index 0000000..5e84f5d --- /dev/null +++ b/tests/test_client/test_reuse.py @@ -0,0 +1,42 @@ +import zipfile +import numpy as np + +from learnware.learnware import get_learnware_from_dirpath, Learnware +from learnware.market import EasyMarket +from learnware.client.container import ModelEnvContainer, LearnwaresContainer +from learnware.learnware.reuse import AveragingReuser + +if __name__ == "__main__": + semantic_specification = dict() + semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} + semantic_specification["Task"] = {"Type": "Class", "Values": ["Ranking"]} + semantic_specification["Library"] = {"Type": "Class", "Values": ["Scikit-learn"]} + semantic_specification["Scenario"] = {"Type": "Tag", "Values": "Financial"} + semantic_specification["Name"] = {"Type": "String", "Values": "test"} + semantic_specification["Description"] = {"Type": "String", "Values": "test"} + + zip_paths = [ + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic.zip", + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic.zip", + ] + dir_paths = [ + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic", + "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic", + ] + + learnware_list = [] + for id, (zip_path, dir_path) in enumerate(zip(zip_paths, dir_paths)): + with zipfile.ZipFile(zip_path, "r") as z_file: + z_file.extractall(dir_path) + + learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) + learnware_list.append(learnware) + + with LearnwaresContainer(learnware_list, zip_paths) as env_container: + + learnware_list = env_container.get_learnware_list_with_container() + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.randint(0, 3, size=(20, 9)) + print(reuser.predict(input_array).argmax(axis=1)) + for id, ind_learner in enumerate(learnware_list): + print(f"learner_{id}", reuser.predict(input_array).argmax(axis=1))