Browse Source

[MNT] add runnable_option in load_learnware

tags/v0.3.2
Gene 2 years ago
parent
commit
6ca46e2bf2
2 changed files with 83 additions and 35 deletions
  1. +53
    -25
      learnware/client/learnware_client.py
  2. +30
    -10
      tests/test_client/test_load.py

+ 53
- 25
learnware/client/learnware_client.py View File

@@ -1,5 +1,5 @@
import os
import numpy as np
import uuid
import yaml
import json
import atexit
@@ -7,6 +7,7 @@ import zipfile
import hashlib
import requests
import tempfile
import numpy as np
from enum import Enum
from tqdm import tqdm
from typing import Union, List
@@ -309,31 +310,44 @@ class LearnwareClient:

return semantic_conf[key.value]["Values"]

def load_learnware(self, learnware_file: Union[str, List[str]], load_option: str = "conda_env"):
"""Load learnware
def load_learnware(self, learnware_path: Union[str, List[str]] = None, learnware_id: Union[str, List[str]] = None, runnable_option: str = None):
"""Load learnware by learnware zip file or learnware id (zip file has higher priority)

Parameters
----------
learnware_file : Union[str, List[str]]
learnware_path : Union[str, List[str]]
learnware zip path or learnware zip path list
load_option : str
the option for loading learnwares
- "normal": load learnware without installing environment
- "conda_env": load learnware with installing conda virtual environment
learnware_id : Union[str, List[str]]
learnware id or learnware id list
runnable_option : str
the option for instantiating learnwares
- "normal": instantiate learnware without installing environment
- "conda_env": instantiate learnware with installing conda virtual environment

Returns
-------
Learnware
The contructed learnware object or object list
"""
if load_option not in ["normal", "conda_env"]:
raise ValueError(f"load_option must be one of ['normal', 'conda_env'], but got {load_option}")
if runnable_option is not None and runnable_option not in ["normal", "conda_env"]:
raise logger.warning(f"runnable_option must be one of ['normal', 'conda_env'], but got {runnable_option}")

if learnware_path is None and learnware_id is None:
raise ValueError("Requires one of learnware_path or learnware_id")
def _get_learnware_obj(learnware_zippath):
def _get_learnware_by_id(_learnware_id):
self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_"))
tempdir = self.tempdir_list[-1].name
zip_path = os.path.join(tempdir, f"{str(uuid.uuid4())}.zip")
self.download_learnware(_learnware_id, zip_path)
return zip_path, _get_learnware_by_path(zip_path, tempdir=tempdir)
def _get_learnware_by_path(_learnware_zippath, tempdir=None):
if tempdir is None:
self.tempdir_list.append(tempfile.TemporaryDirectory(prefix="learnware_"))
tempdir = self.tempdir_list[-1].name

with zipfile.ZipFile(learnware_zippath, "r") as z_file:
with zipfile.ZipFile(_learnware_zippath, "r") as z_file:
z_file.extractall(tempdir)

yaml_file = C.learnware_folder_config["yaml_file"]
@@ -354,22 +368,36 @@ class LearnwareClient:
semantic_specification = json.load(fin)

return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir)

if isinstance(learnware_file, str):
zip_paths = [learnware_file]
elif isinstance(learnware_file, list):
zip_paths = learnware_file
learnware_list = []
for zip_path in zip_paths:
learnware_obj = _get_learnware_obj(zip_path)
if load_option == "normal":
learnware_obj.instantiate_model()
learnware_list.append(learnware_obj)
zip_paths = []
if learnware_path is not None:
if isinstance(learnware_path, str):
zip_paths = [learnware_path]
elif isinstance(learnware_path, list):
zip_paths = learnware_path
for zip_path in zip_paths:
learnware_obj = _get_learnware_by_path(zip_path)
learnware_list.append(learnware_obj)
elif learnware_id is not None:
if isinstance(learnware_id, str):
id_list = [learnware_id]
elif isinstance(learnware_id, list):
id_list = learnware_id
for idx in id_list:
zip_path, learnware_obj = _get_learnware_by_id(idx)
zip_paths.append(zip_path)
learnware_list.append(learnware_obj)
if load_option == "conda_env":
env_container = LearnwaresContainer(learnware_list, zip_paths)
learnware_list = env_container.get_learnware_list_with_container()
if runnable_option is not None:
if runnable_option == "normal":
for i in range(len(learnware_list)):
learnware_list[i].instantiate_model()
elif runnable_option == "conda_env":
env_container = LearnwaresContainer(learnware_list, zip_paths)
learnware_list = env_container.get_learnware_list_with_container()
if len(learnware_list) == 1:
return learnware_list[0]


+ 30
- 10
tests/test_client/test_load.py View File

@@ -20,16 +20,15 @@ class TestLearnwareLoad(unittest.TestCase):
self.client = LearnwareClient()
self.client.login(email, token)

learnware_ids = ["00000084", "00000154", "00000155"]
zip_paths = ["1.zip", "2.zip", "3.zip"]
root = os.path.dirname(__file__)
for i in range(len(learnware_ids)):
zip_paths[i] = os.path.join(root, zip_paths[i])
self.client.download_learnware(learnware_ids[i], zip_paths[i])
self.zip_paths = zip_paths
self.learnware_ids = ["00000084", "00000154", "00000155"]
self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]

def test_single_learnware(self):
learnware_list = [self.client.load_learnware(zippath, load_option="conda_env") for zippath in self.zip_paths]
def test_load_single_learnware_by_zippath(self):
for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths):
self.client.download_learnware(learnware_id, zip_path)
learnware_list = [self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env") for zippath in self.zip_paths]
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))
@@ -37,8 +36,29 @@ class TestLearnwareLoad(unittest.TestCase):
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_multi_learnware(self):
learnware_list = self.client.load_learnware(self.zip_paths, load_option="conda_env")
def test_load_multi_learnware_by_zippath(self):
for (learnware_id, zip_path) in zip(self.learnware_ids, self.zip_paths):
self.client.download_learnware(learnware_id, zip_path)
learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env")
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))
def test_load_single_learnware_by_id(self):
learnware_list = [self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids]
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_load_multi_learnware_by_id(self):
learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env")
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))


Loading…
Cancel
Save