You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_load_learnware.py 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import os
  2. import unittest
  3. import argparse
  4. import numpy as np
  5. import learnware
  6. from learnware.learnware import get_learnware_from_dirpath
  7. from learnware.client import LearnwareClient
  8. from learnware.client.container import ModelCondaContainer, LearnwaresContainer
  9. from learnware.reuse import AveragingReuser
  10. from learnware.tests import parametrize
  11. class TestLearnwareLoad(unittest.TestCase):
  12. def __init__(self, method_name='runTest', mode="conda"):
  13. super(TestLearnwareLoad, self).__init__(method_name)
  14. self.runnable_options = []
  15. if mode in {"all", "conda"}:
  16. self.runnable_options.append("conda")
  17. if mode in {"all", "docker"}:
  18. self.runnable_options.append("docker")
  19. def setUp(self):
  20. self.client = LearnwareClient()
  21. root = os.path.dirname(__file__)
  22. self.learnware_ids = ["00000910", "00000899", "00000900"]
  23. self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]
  24. def _test_load_learnware_by_zippath(self, runnable_option):
  25. for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  26. self.client.download_learnware(learnware_id, zip_path)
  27. learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option=runnable_option)
  28. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  29. input_array = np.random.random(size=(20, 13))
  30. print(reuser.predict(input_array))
  31. for learnware in learnware_list:
  32. print(learnware.id, learnware.predict(input_array))
  33. def _test_load_learnware_by_id(self, runnable_option):
  34. learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option)
  35. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  36. input_array = np.random.random(size=(20, 13))
  37. print(reuser.predict(input_array))
  38. for learnware in learnware_list:
  39. print(learnware.id, learnware.predict(input_array))
  40. def test_load_learnware_by_zippath(self):
  41. for runnable_option in self.runnable_options:
  42. self._test_load_learnware_by_zippath(runnable_option=runnable_option)
  43. def test_load_learnware_by_id(self):
  44. for runnable_option in self.runnable_options:
  45. self._test_load_learnware_by_id(runnable_option=runnable_option)
  46. if __name__ == "__main__":
  47. parser = argparse.ArgumentParser()
  48. parser.add_argument("--mode", type=str, required=False, default="conda", help="The mode to load learnware, must be in ['all', 'conda', 'docker']")
  49. args = parser.parse_args()
  50. assert args.mode in {"all", "conda", "docker"}, f"The mode must be in ['all', 'conda', 'docker'], instead of '{args.mode}'"
  51. runner = unittest.TextTestRunner()
  52. runner.run(parametrize(TestLearnwareLoad, mode=args.mode))