You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_search.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import os
  2. import unittest
  3. import tempfile
  4. import logging
  5. import learnware
  6. learnware.init(logging_level=logging.WARNING)
  7. from learnware.learnware import Learnware
  8. from learnware.client import LearnwareClient
  9. from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker
  10. from learnware.config import C
  class TestSearch(unittest.TestCase):
      """Smoke tests for learnware search on a locally rebuilt market.

      ``setUpClass`` rebuilds a market (``market_id="search_test"``,
      ``name="hetero"``) and, when the client reports a connection, fills it
      with a fixed list of learnwares fetched through the client. Each test
      then loads one more learnware by id, searches the market with its
      statistical specification and prints the results. Nothing is asserted
      about the results; the tests only check that the calls do not raise.
      """

      # Shared by every test; note that it is created when the class body runs.
      client = LearnwareClient()

      @classmethod
      def setUpClass(cls):
          """Rebuild the test market and populate it when the client is connected."""
          cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
          if cls.client.is_connected():
              cls._build_learnware_market()

      @classmethod
      def _build_learnware_market(cls):
          """Download a fixed set of learnwares and add each one to ``cls.market``.

          A learnware whose download or loading fails is reported and skipped,
          so one failure never causes a missing or stale ``semantic_spec`` to
          be passed to ``add_learnware``.
          """
          table_learnware_ids = ["00001951", "00001980", "00001987"]
          image_learnware_ids = ["00000851", "00000858", "00000841"]
          text_learnware_ids = ["00000652", "00000637"]
          learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids

          # The zip files only need to exist while they are being added.
          with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
              for learnware_id in learnware_ids:
                  learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
                  try:
                      cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
                      semantic_spec = (
                          cls.client.load_learnware(learnware_path=learnware_zippath)
                          .get_specification()
                          .get_semantic_spec()
                      )
                  except Exception:
                      # Best effort: report which id failed and move on instead
                      # of falling through with an unbound/stale semantic_spec.
                      print(f"'{learnware_id}' is passed due to the network problem.")
                      continue

                  cls.market.add_learnware(
                      learnware_zippath,
                      learnware_id=learnware_id,
                      semantic_spec=semantic_spec,
                      checker_names=["EasySemanticChecker"],
                  )

      def _skip_test(self):
          """Return True (after printing a notice) when the client is not connected."""
          if not self.client.is_connected():
              print("Client can not connect!")
              return True
          return False

      def _search_with_learnware(self, learnware_id, test_name):
          """Load ``learnware_id``, search the market with its stat spec, print results.

          Returns without searching when the client is not connected or when
          loading the learnware raises; ``test_name`` is only used in the
          message printed in the latter case.
          """
          if self._skip_test():
              return

          try:
              learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
          except Exception:
              print(f"'{test_name}' is passed due to the network problem.")
              # Without this return the code below would hit an unbound name.
              return

          user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
          search_result = self.market.search_learnware(user_info)
          print("Single Search Results:", search_result.get_single_results())
          print("Multiple Search Results:", search_result.get_multiple_results())

      def test_image_search(self):
          """Search the market using learnware "00000619" as the query."""
          self._search_with_learnware("00000619", "test_image_search")

      def test_text_search(self):
          """Search the market using learnware "00000653" as the query."""
          self._search_with_learnware("00000653", "test_text_search")

      def test_table_search(self):
          """Search the market using learnware "00001950" as the query."""
          self._search_with_learnware("00001950", "test_table_search")
  def suite():
      """Return a TestSuite holding the image, text and table search tests, in that order."""
      search_cases = unittest.TestSuite()
      for method_name in ("test_image_search", "test_text_search", "test_table_search"):
          search_cases.addTest(TestSearch(method_name))
      return search_cases
  if __name__ == "__main__":
      # Run the hand-built suite with the default text runner when executed as a script.
      unittest.TextTestRunner().run(suite())