Browse Source

[MNT] update load test

tags/v0.3.2
Gene 2 years ago
parent
commit
f6c91c7caa
2 changed files with 0 additions and 108 deletions
  1. +0
    -33
      tests/test_learnware_client/test_docker.py
  2. +0
    -75
      tests/test_learnware_client/test_load.py

+ 0
- 33
tests/test_learnware_client/test_docker.py View File

@@ -1,33 +0,0 @@
import os
import zipfile
import numpy as np

import learnware
from learnware.client import LearnwareClient
from learnware.client.container import LearnwaresContainer
from learnware.learnware.reuse import AveragingReuser


if __name__ == "__main__":
email = "liujd@lamda.nju.edu.cn"
token = "f7e647146a314c6e8b4e2e1079c4bca4"

client = LearnwareClient()
client.login(email, token)

root = os.path.dirname(__file__)
learnware_ids = ["00000084", "00000154", "00000155"]
zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]

for learnware_id, zip_path in zip(learnware_ids, zip_paths):
client.download_learnware(learnware_id, zip_path)

learnware_list = [client.load_learnware(learnware_path=zippath) for zippath in zip_paths]
with LearnwaresContainer(learnware_list, zip_paths, mode="docker") as env_container:
learnware_list = env_container.get_learnwares_with_container()
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

+ 0
- 75
tests/test_learnware_client/test_load.py View File

@@ -1,75 +0,0 @@
import os
import unittest
import zipfile
import numpy as np

import learnware
from learnware.learnware import get_learnware_from_dirpath
from learnware.client import LearnwareClient
from learnware.client.container import ModelCondaContainer, LearnwaresContainer
from learnware.learnware.reuse import AveragingReuser


class TestLearnwareLoad(unittest.TestCase):
def setUp(self):
unittest.TestCase.setUpClass()
email = "liujd@lamda.nju.edu.cn"
token = "f7e647146a314c6e8b4e2e1079c4bca4"

self.client = LearnwareClient()
self.client.login(email, token)

root = os.path.dirname(__file__)
self.learnware_ids = ["00000084", "00000154", "00000155"]
self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]

def test_load_single_learnware_by_zippath(self):
for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
self.client.download_learnware(learnware_id, zip_path)

learnware_list = [
self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env")
for zippath in self.zip_paths
]
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_load_multi_learnware_by_zippath(self):
for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
self.client.download_learnware(learnware_id, zip_path)

learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env")
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_load_single_learnware_by_id(self):
learnware_list = [
self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids
]
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))

def test_load_multi_learnware_by_id(self):
learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env")
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
input_array = np.random.random(size=(20, 13))
print(reuser.predict(input_array))

for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))


if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save