You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_load.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
import os
import tempfile
import unittest
import zipfile

import numpy as np

import learnware
from learnware.learnware import get_learnware_from_dirpath
from learnware.client import LearnwareClient
from learnware.client.container import ModelCondaContainer, LearnwaresContainer
from learnware.learnware.reuse import AveragingReuser
  10. class TestLearnwareLoad(unittest.TestCase):
  11. def setUp(self):
  12. unittest.TestCase.setUpClass()
  13. email = "liujd@lamda.nju.edu.cn"
  14. token = "f7e647146a314c6e8b4e2e1079c4bca4"
  15. self.client = LearnwareClient()
  16. self.client.login(email, token)
  17. root = os.path.dirname(__file__)
  18. self.learnware_ids = ["00000084", "00000154", "00000155"]
  19. self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]
  20. def test_load_single_learnware_by_zippath(self):
  21. for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  22. self.client.download_learnware(learnware_id, zip_path)
  23. learnware_list = [
  24. self.client.load_learnware(learnware_path=zippath, runnable_option="conda_env")
  25. for zippath in self.zip_paths
  26. ]
  27. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  28. input_array = np.random.random(size=(20, 13))
  29. print(reuser.predict(input_array))
  30. for learnware in learnware_list:
  31. print(learnware.id, learnware.predict(input_array))
  32. def test_load_multi_learnware_by_zippath(self):
  33. for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  34. self.client.download_learnware(learnware_id, zip_path)
  35. learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda_env")
  36. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  37. input_array = np.random.random(size=(20, 13))
  38. print(reuser.predict(input_array))
  39. for learnware in learnware_list:
  40. print(learnware.id, learnware.predict(input_array))
  41. def test_load_single_learnware_by_id(self):
  42. learnware_list = [
  43. self.client.load_learnware(learnware_id=idx, runnable_option="conda_env") for idx in self.learnware_ids
  44. ]
  45. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  46. input_array = np.random.random(size=(20, 13))
  47. print(reuser.predict(input_array))
  48. for learnware in learnware_list:
  49. print(learnware.id, learnware.predict(input_array))
  50. def test_load_multi_learnware_by_id(self):
  51. learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda_env")
  52. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  53. input_array = np.random.random(size=(20, 13))
  54. print(reuser.predict(input_array))
  55. for learnware in learnware_list:
  56. print(learnware.id, learnware.predict(input_array))
  57. if __name__ == "__main__":
  58. unittest.main()