You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_upload.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import unittest
  3. import tempfile
  4. from learnware.client import LearnwareClient
  5. class TestAllLearnware(unittest.TestCase):
  6. def setUp(self):
  7. unittest.TestCase.setUpClass()
  8. email = "liujd@lamda.nju.edu.cn"
  9. token = "f7e647146a314c6e8b4e2e1079c4bca4"
  10. self.client = LearnwareClient()
  11. self.client.login(email, token)
  12. def test_upload(self):
  13. input_description = {
  14. "Dimension": 13,
  15. "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
  16. }
  17. output_description = {
  18. "Dimension": 1,
  19. "Description": {
  20. "0": "the probability of being a cat",
  21. },
  22. }
  23. semantic_spec = self.client.create_semantic_specification(
  24. name="learnware_example",
  25. description="Just a example for uploading a learnware",
  26. data_type="Table",
  27. task_type="Classification",
  28. library_type="Scikit-learn",
  29. scenarios=["Business", "Financial"],
  30. input_description=input_description,
  31. output_description=output_description,
  32. )
  33. assert isinstance(semantic_spec, dict)
  34. download_learnware_id = "00000084"
  35. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  36. zip_path = os.path.join(tempdir, f"test.zip")
  37. self.client.download_learnware(download_learnware_id, zip_path)
  38. learnware_id = self.client.upload_learnware(
  39. learnware_zip_path=zip_path, semantic_specification=semantic_spec
  40. )
  41. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  42. assert learnware_id in uploaded_ids
  43. self.client.delete_learnware(learnware_id)
  44. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  45. assert learnware_id not in uploaded_ids
  46. if __name__ == "__main__":
  47. unittest.main()