diff --git a/learnware/client/learnware_client.py b/learnware/client/learnware_client.py index 47e93f0..d3d79c9 100644 --- a/learnware/client/learnware_client.py +++ b/learnware/client/learnware_client.py @@ -1,5 +1,5 @@ import os -import numpy as np +import uuid import yaml import json import atexit @@ -7,6 +7,7 @@ import zipfile import hashlib import requests import tempfile +import numpy as np from enum import Enum from tqdm import tqdm from typing import Union, List @@ -309,31 +310,44 @@ class LearnwareClient: return semantic_conf[key.value]["Values"] - def load_learnware(self, learnware_file: Union[str, List[str]], load_option: str = "conda_env"): - """Load learnware + def load_learnware(self, learnware_path: Union[str, List[str]] = None, learnware_id: Union[str, List[str]] = None, runnable_option: str = None): + """Load learnware by learnware zip file or learnware id (zip file has higher priority) Parameters ---------- - learnware_file : Union[str, List[str]] + learnware_path : Union[str, List[str]] learnware zip path or learnware zip path list - load_option : str - the option for loading learnwares - - "normal": load learnware without installing environment - - "conda_env": load learnware with installing conda virtual environment + learnware_id : Union[str, List[str]] + learnware id or learnware id list + runnable_option : str + the option for instantiating learnwares + - "normal": instantiate learnware without installing environment + - "conda_env": instantiate learnware with installing conda virtual environment Returns ------- Learnware The contructed learnware object or object list """ - if load_option not in ["normal", "conda_env"]: - raise ValueError(f"load_option must be one of ['normal', 'conda_env'], but got {load_option}") + if runnable_option is not None and runnable_option not in ["normal", "conda_env"]: + raise logger.warning(f"runnable_option must be one of ['normal', 'conda_env'], but got {runnable_option}") + + if learnware_path is None and learnware_id is None: + raise ValueError("Requires one of learnware_path or learnware_id") - def _get_learnware_obj(learnware_zippath): + def _get_learnware_by_id(_learnware_id): self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) tempdir = self.tempdir_list[-1].name + zip_path = os.path.join(tempdir, f"{str(uuid.uuid4())}.zip") + self.download_learnware(_learnware_id, zip_path) + return zip_path, _get_learnware_by_path(zip_path, tempdir=tempdir) + + def _get_learnware_by_path(_learnware_zippath, tempdir=None): + if tempdir is None: + self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_")) + tempdir = self.tempdir_list[-1].name - with zipfile.ZipFile(learnware_zippath, "r") as z_file: + with zipfile.ZipFile(_learnware_zippath, "r") as z_file: z_file.extractall(tempdir) yaml_file = C.learnware_folder_config["yaml_file"] @@ -354,22 +368,36 @@ class LearnwareClient: semantic_specification = json.load(fin) return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) - - if isinstance(learnware_file, str): - zip_paths = [learnware_file] - elif isinstance(learnware_file, list): - zip_paths = learnware_file learnware_list = [] - for zip_path in zip_paths: - learnware_obj = _get_learnware_obj(zip_path) - if load_option == "normal": - learnware_obj.instantiate_model() - learnware_list.append(learnware_obj) + zip_paths = [] + if learnware_path is not None: + if isinstance(learnware_path, str): + zip_paths = [learnware_path] + elif isinstance(learnware_path, list): + zip_paths = learnware_path + + for zip_path in zip_paths: + learnware_obj = _get_learnware_by_path(zip_path) + learnware_list.append(learnware_obj) + elif learnware_id is not None: + if isinstance(learnware_id, str): + id_list = [learnware_id] + elif isinstance(learnware_id, list): + id_list = learnware_id + + for idx in id_list: + zip_path, learnware_obj = _get_learnware_by_id(idx) + zip_paths.append(zip_path) + learnware_list.append(learnware_obj) - if load_option == "conda_env": - env_container = LearnwaresContainer(learnware_list, zip_paths) - learnware_list = env_container.get_learnware_list_with_container() + if runnable_option is not None: + if runnable_option == "normal": + for i in range(len(learnware_list)): + learnware_list[i].instantiate_model() + elif runnable_option == "conda_env": + env_container = LearnwaresContainer(learnware_list, zip_paths) + learnware_list = env_container.get_learnware_list_with_container() if len(learnware_list) == 1: return learnware_list[0] diff --git a/tests/test_client/test_load.py b/tests/test_client/test_load.py index 73be704..1706e00 100644 --- a/tests/test_client/test_load.py +++ b/tests/test_client/test_load.py @@ -20,16 +20,15 @@ class TestLearnwareLoad(unittest.TestCase): self.client = LearnwareClient() self.client.login(email, token) - learnware_ids = ["00000084", "00000154", "00000155"] - zip_paths = ["1.zip", "2.zip", "3.zip"] root = os.path.dirname(__file__) - for i in range(len(learnware_ids)): - zip_paths[i] = os.path.join(root, zip_paths[i]) - self.client.download_learnware(learnware_ids[i], zip_paths[i]) - self.zip_paths = zip_paths + self.learnware_ids = ["00000084", "00000154", "00000155"] + self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] - def test_single_learnware(self): - learnware_list = [self.client.load_learnware(zippath, load_option="conda_env") for zippath in self.zip_paths] + def test_load_single_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = [self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") for zippath in self.zip_paths] reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array)) @@ -37,8 +36,29 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) - def test_multi_learnware(self): - learnware_list = self.client.load_learnware(self.zip_paths, load_option="conda_env") + def test_load_multi_learnware_by_zippath(self): + for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths): + self.client.download_learnware(learnware_id, zip_path) + + learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env") + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_single_learnware_by_id(self): + learnware_list = [self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids] + reuser = AveragingReuser(learnware_list, mode="vote_by_label") + input_array = np.random.random(size=(20, 13)) + print(reuser.predict(input_array)) + + for learnware in learnware_list: + print(learnware.id, learnware.predict(input_array)) + + def test_load_multi_learnware_by_id(self): + learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env") reuser = AveragingReuser(learnware_list, mode="vote_by_label") input_array = np.random.random(size=(20, 13)) print(reuser.predict(input_array))