You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_upload.py 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import json
  3. import unittest
  4. import tempfile
  5. from learnware.client import LearnwareClient
  6. class TestAllLearnware(unittest.TestCase):
  7. def setUp(self):
  8. unittest.TestCase.setUpClass()
  9. dir_path = os.path.dirname(__file__)
  10. config_path = os.path.join(dir_path, "config.json")
  11. if not os.path.exists(config_path):
  12. data = {"email": None, "token": None}
  13. with open(config_path, "w") as file:
  14. json.dump(data, file)
  15. with open(config_path, "r") as file:
  16. data = json.load(file)
  17. email = data["email"]
  18. token = data["token"]
  19. if email is None or token is None:
  20. raise ValueError("Please set email and token in config.json.")
  21. self.client = LearnwareClient()
  22. self.client.login(email, token)
  23. def test_upload(self):
  24. input_description = {
  25. "Dimension": 13,
  26. "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
  27. }
  28. output_description = {
  29. "Dimension": 1,
  30. "Description": {
  31. "0": "the probability of being a cat",
  32. },
  33. }
  34. semantic_spec = self.client.create_semantic_specification(
  35. name="learnware_example",
  36. description="Just a example for uploading a learnware",
  37. data_type="Table",
  38. task_type="Classification",
  39. library_type="Scikit-learn",
  40. scenarios=["Business", "Financial"],
  41. input_description=input_description,
  42. output_description=output_description,
  43. )
  44. assert isinstance(semantic_spec, dict)
  45. download_learnware_id = "00000084"
  46. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  47. zip_path = os.path.join(tempdir, f"test.zip")
  48. self.client.download_learnware(download_learnware_id, zip_path)
  49. learnware_id = self.client.upload_learnware(
  50. learnware_zip_path=zip_path, semantic_specification=semantic_spec
  51. )
  52. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  53. assert learnware_id in uploaded_ids
  54. self.client.delete_learnware(learnware_id)
  55. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  56. assert learnware_id not in uploaded_ids
  57. if __name__ == "__main__":
  58. unittest.main()