You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 7.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import os
  2. import fire
  3. import joblib
  4. import zipfile
  5. import numpy as np
  6. from sklearn import svm
  7. from shutil import copyfile, rmtree
  8. import learnware
  9. from learnware.market import EasyMarket, BaseUserInfo
  10. from learnware.market import database_ops
  11. from learnware.learnware import Learnware
  12. import learnware.specification as specification
  13. from learnware.utils import get_module_by_module_path
  14. curr_root = os.path.dirname(os.path.abspath(__file__))
  15. semantic_specs = [
  16. {
  17. "Data": {"Values": ["Tabular"], "Type": "Class"},
  18. "Task": {
  19. "Values": ["Classification"],
  20. "Type": "Class",
  21. },
  22. "Device": {"Values": ["GPU"], "Type": "Tag"},
  23. "Scenario": {"Values": ["Nature"], "Type": "Tag"},
  24. "Description": {"Values": "", "Type": "Description"},
  25. "Name": {"Values": "learnware_1", "Type": "Name"},
  26. },
  27. {
  28. "Data": {"Values": ["Tabular"], "Type": "Class"},
  29. "Task": {
  30. "Values": ["Classification"],
  31. "Type": "Class",
  32. },
  33. "Device": {"Values": ["GPU"], "Type": "Tag"},
  34. "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
  35. "Description": {"Values": "", "Type": "Description"},
  36. "Name": {"Values": "learnware_2", "Type": "Name"},
  37. },
  38. {
  39. "Data": {"Values": ["Tabular"], "Type": "Class"},
  40. "Task": {
  41. "Values": ["Classification"],
  42. "Type": "Class",
  43. },
  44. "Device": {"Values": ["GPU"], "Type": "Tag"},
  45. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  46. "Description": {"Values": "", "Type": "Description"},
  47. "Name": {"Values": "learnware_3", "Type": "Name"},
  48. },
  49. ]
  50. user_senmantic = {
  51. "Data": {"Values": ["Tabular"], "Type": "Class"},
  52. "Task": {
  53. "Values": ["Classification"],
  54. "Type": "Class",
  55. },
  56. "Device": {"Values": ["GPU"], "Type": "Tag"},
  57. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  58. "Description": {"Values": "", "Type": "Description"},
  59. "Name": {"Values": "", "Type": "Name"},
  60. }
  61. class LearnwareMarketWorkflow:
  62. def _init_learnware_market(self):
  63. """initialize learnware market"""
  64. database_ops.clear_learnware_table()
  65. learnware.init()
  66. np.random.seed(2023)
  67. def prepare_learnware_randomly(self, learnware_num=10):
  68. self.zip_path_list = []
  69. for i in range(learnware_num):
  70. dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
  71. os.makedirs(dir_path, exist_ok=True)
  72. print("Preparing Learnware: %d" % (i))
  73. data_X = np.random.randn(5000, 20) * i
  74. data_y = np.random.randn(5000)
  75. data_y = np.where(data_y > 0, 1, 0)
  76. clf = svm.SVC(kernel="linear")
  77. clf.fit(data_X, data_y)
  78. joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))
  79. spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
  80. spec.save(os.path.join(dir_path, "svm.json"))
  81. init_file = os.path.join(dir_path, "__init__.py")
  82. copyfile("example_init.py", init_file) # cp example_init.py init_file
  83. yaml_file = os.path.join(dir_path, "learnware.yaml")
  84. copyfile("example.yaml", yaml_file) # cp example.yaml yaml_file
  85. zip_file = dir_path + ".zip"
  86. # zip -q -r -j zip_file dir_path
  87. with zipfile.ZipFile(zip_file, "w") as zip_obj:
  88. for foldername, subfolders, filenames in os.walk(dir_path):
  89. for filename in filenames:
  90. file_path = os.path.join(foldername, filename)
  91. zip_info = zipfile.ZipInfo(filename)
  92. zip_info.compress_type = zipfile.ZIP_STORED
  93. with open(file_path, "rb") as file:
  94. zip_obj.writestr(zip_info, file.read())
  95. rmtree(dir_path) # rm -r dir_path
  96. self.zip_path_list.append(zip_file)
  97. def test_upload_delete_learnware(self, learnware_num=5, delete=False):
  98. self._init_learnware_market()
  99. self.prepare_learnware_randomly(learnware_num)
  100. easy_market = EasyMarket()
  101. print("Total Item:", len(easy_market))
  102. for idx, zip_path in enumerate(self.zip_path_list):
  103. semantic_spec = semantic_specs[idx % 3]
  104. semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
  105. semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
  106. easy_market.add_learnware(zip_path, semantic_spec)
  107. print("Total Item:", len(easy_market))
  108. curr_inds = easy_market._get_ids()
  109. print("Available ids After Uploading Learnwares:", curr_inds)
  110. if delete:
  111. for learnware_id in curr_inds:
  112. easy_market.delete_learnware(learnware_id)
  113. easy_market.delete_learnware(learnware_id)
  114. curr_inds = easy_market._get_ids()
  115. print("Available ids After Deleting Learnwares:", curr_inds)
  116. return easy_market
  117. def test_search_semantics(self, learnware_num=5):
  118. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  119. print("Total Item:", len(easy_market))
  120. test_folder = os.path.join(curr_root, "test_semantics")
  121. idx, zip_path = 1, self.zip_path_list[1]
  122. unzip_dir = os.path.join(test_folder, f"{idx}")
  123. # unzip -o -q zip_path -d unzip_dir
  124. if os.path.exists(unzip_dir):
  125. rmtree(unzip_dir)
  126. os.makedirs(unzip_dir, exist_ok=True)
  127. with zipfile.ZipFile(zip_path, "r") as zip_obj:
  128. zip_obj.extractall(path=unzip_dir)
  129. user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic)
  130. _, single_learnware_list, _ = easy_market.search_learnware(user_info)
  131. print("User info:", user_info.get_semantic_spec())
  132. print(f"search result of user{idx}:")
  133. for learnware in single_learnware_list:
  134. print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())
  135. rmtree(test_folder) # rm -r test_folder
  136. def test_stat_search(self, learnware_num=5):
  137. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  138. print("Total Item:", len(easy_market))
  139. test_folder = os.path.join(curr_root, "test_stat")
  140. for idx, zip_path in enumerate(self.zip_path_list):
  141. unzip_dir = os.path.join(test_folder, f"{idx}")
  142. # unzip -o -q zip_path -d unzip_dir
  143. if os.path.exists(unzip_dir):
  144. rmtree(unzip_dir)
  145. os.makedirs(unzip_dir, exist_ok=True)
  146. with zipfile.ZipFile(zip_path, "r") as zip_obj:
  147. zip_obj.extractall(path=unzip_dir)
  148. user_spec = specification.rkme.RKMEStatSpecification()
  149. user_spec.load(os.path.join(unzip_dir, "svm.json"))
  150. user_info = BaseUserInfo(
  151. id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
  152. )
  153. sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
  154. print(f"search result of user{idx}:")
  155. for score, learnware in zip(sorted_score_list, single_learnware_list):
  156. print(f"score: {score}, learnware_id: {learnware.id}")
  157. mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
  158. print(f"mixture_learnware: {mixture_id}\n")
  159. rmtree(test_folder) # rm -r test_folder
  160. if __name__ == "__main__":
  161. fire.Fire(LearnwareMarketWorkflow)

基于学件范式，全流程地支持学件上传、检测、组织、查搜、部署和复用等功能。同时，该仓库作为北冥坞系统的引擎，支撑北冥坞系统的核心功能。