You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_db.py 6.3 kB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import os
  2. import joblib
  3. import numpy as np
  4. from sklearn import svm
  5. from learnware.market import EasyMarket, BaseUserInfo
  6. from learnware.market import database_ops
  7. from learnware.learnware import Learnware
  8. import learnware.specification as specification
  9. from learnware.utils import get_module_by_module_path
  # Absolute directory that contains this script; the learnware pool lives beneath it.
  curr_root = os.path.split(os.path.abspath(__file__))[0]
  # Semantic specification templates for uploaded learnwares (test_market picks
  # them round-robin by index % 3). All describe tabular data on GPU; they differ
  # only in task type, scenario tags and name.
  semantic_specs = [
      {
          "Data": {"Values": ["Tabular"], "Type": "Class"},
          "Task": {"Values": [task], "Type": "Class"},
          "Device": {"Values": ["GPU"], "Type": "Tag"},
          "Scenario": {"Values": list(scenarios), "Type": "Tag"},
          "Description": {"Values": "", "Type": "Description"},
          "Name": {"Values": name, "Type": "Name"},
      }
      for task, scenarios, name in (
          ("Classification", ("Nature",), "learnware_1"),
          ("Classification", ("Business", "Nature"), "learnware_2"),
          ("Regression", ("Business",), "learnware_3"),
      )
  ]
  # Semantic query shared by the search tests: tabular classification on GPU in a
  # Business scenario. The variable's spelling is kept because other code uses it.
  user_senmantic = dict(
      Data={"Values": ["Tabular"], "Type": "Class"},
      Task={"Values": ["Classification"], "Type": "Class"},
      Device={"Values": ["GPU"], "Type": "Tag"},
      Scenario={"Values": ["Business"], "Type": "Tag"},
      Description={"Values": "", "Type": "Description"},
      Name={"Values": "learnware", "Type": "Name"},
  )
  def prepare_learnware(learnware_num=10):
      """Build ``learnware_num`` synthetic SVM learnware archives.

      For each index ``i`` a linear SVC is trained on random data. The model
      (``svm.pkl``), its RKME specification (``svm.json``), the example
      ``__init__.py`` and ``learnware.yaml`` are staged in
      ``curr_root/learnware_pool/svm_<i>/``. That directory is then zipped flat
      into ``svm_<i>.zip`` and removed. The NumPy RNG is reseeded on every call,
      so repeated calls generate the same training data.

      Args:
          learnware_num: number of archives to create (indices 0..learnware_num-1).

      Raises:
          OSError: if a template file is missing or a copy/zip/delete fails
              (these failures used to be silently ignored by ``os.system``).
      """
      import shutil
      import zipfile

      np.random.seed(2023)
      pool_root = os.path.join(curr_root, "learnware_pool")
      # Resolve templates next to this script rather than the current working
      # directory, so the script also works when launched from elsewhere.
      init_template = os.path.join(curr_root, "example_init.py")
      yaml_template = os.path.join(curr_root, "example.yaml")
      for i in range(learnware_num):
          dir_path = os.path.join(pool_root, "svm_%d" % (i))
          os.makedirs(dir_path, exist_ok=True)
          print("Preparing Learnware: %d" % (i))
          # NOTE(review): for i == 0 every feature is 0, so learnware 0 is trained on
          # constant inputs; kept unchanged to preserve the generated fixtures.
          data_X = np.random.randn(5000, 20) * i
          data_y = np.random.randn(5000)
          data_y = np.where(data_y > 0, 1, 0)  # random labels, independent of data_X
          clf = svm.SVC(kernel="linear")
          clf.fit(data_X, data_y)
          joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))
          spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
          spec.save(os.path.join(dir_path, "svm.json"))
          # In-process copies/archiving instead of shell strings: paths containing
          # spaces or shell metacharacters work, and failures raise.
          shutil.copy(init_template, os.path.join(dir_path, "__init__.py"))
          shutil.copy(yaml_template, os.path.join(dir_path, "learnware.yaml"))
          zip_file = dir_path + ".zip"
          # Mirrors ``zip -q -r -j``: deflate every file, stored under its bare
          # name (paths junked). Mode "w" rewrites any stale archive.
          with zipfile.ZipFile(zip_file, "w", compression=zipfile.ZIP_DEFLATED) as archive:
              for walk_root, _, file_names in os.walk(dir_path):
                  for file_name in file_names:
                      archive.write(os.path.join(walk_root, file_name), arcname=file_name)
          shutil.rmtree(dir_path)
  def get_zip_path_list():
      """Return the sorted paths of the ``.zip`` archives in ``curr_root/learnware_pool``.

      The result is sorted because ``os.listdir`` order is arbitrary and callers
      pick archives and assign semantic specs by position in this list. Entries
      that are not ``.zip`` files (e.g. leftover staging directories) are skipped.
      Sorting is lexicographic, so ``svm_10.zip`` precedes ``svm_2.zip``.

      Returns:
          list[str]: absolute-or-rooted paths joined onto the pool directory.
      """
      root_path = os.path.join(curr_root, "learnware_pool")
      return [
          os.path.join(root_path, name)
          for name in sorted(os.listdir(root_path))
          if name.endswith(".zip") and os.path.isfile(os.path.join(root_path, name))
      ]
  def test_market():
      """Exercise upload and deletion on an ``EasyMarket`` after clearing its table.

      Clears the learnware table and uploads every archive from
      ``get_zip_path_list``. Each upload gets a semantic spec chosen round-robin
      from ``semantic_specs``, with a per-index Name and Description. Two
      learnwares are then deleted and the remaining ids are printed.
      """
      import copy

      database_ops.clear_learnware_table()
      easy_market = EasyMarket()
      print("Total Item:", len(easy_market))
      zip_path_list = get_zip_path_list()  # the path list for learnware .zip
      for idx, zip_path in enumerate(zip_path_list):
          # Deep-copy the template so the per-learnware edits below do not mutate
          # the shared module-level ``semantic_specs``. Other code reads it, and
          # add_learnware may keep a reference to the dict it is given.
          semantic_spec = copy.deepcopy(semantic_specs[idx % 3])
          semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
          semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
          easy_market.add_learnware(zip_path, semantic_spec)
      print("Total Item:", len(easy_market))
      curr_inds = easy_market._get_ids()
      print("Available ids:", curr_inds)
      # Indexing below assumes at least four learnwares are in the market.
      easy_market.delete_learnware(curr_inds[3])
      easy_market.delete_learnware(curr_inds[2])
      curr_inds = easy_market._get_ids()
      print("Available ids:", curr_inds)
  def test_search_semantics():
      """Run a semantic-only market search and print the matching learnwares.

      Regenerates three learnware archives and unpacks the second one into
      ``./test_stat/1``. It then searches with ``user_senmantic`` only (no
      statistical spec) and prints each single-learnware result. The temporary
      folder is removed even if the search raises.
      """
      import shutil
      import zipfile

      easy_market = EasyMarket()
      print("Total Item:", len(easy_market))
      # Same pool location that prepare_learnware writes to and get_zip_path_list
      # reads from (a cwd-relative "./learnware_pool" would be a different folder).
      root_path = os.path.join(curr_root, "learnware_pool")
      os.makedirs(root_path, exist_ok=True)
      test_learnware_num = 3
      prepare_learnware(test_learnware_num)
      test_folder = "./test_stat"
      zip_path_list = get_zip_path_list()
      idx, zip_path = 1, zip_path_list[1]
      unzip_dir = os.path.join(test_folder, f"{idx}")
      os.makedirs(unzip_dir, exist_ok=True)
      try:
          # In-process extraction (overwrites like ``unzip -o``) instead of a shell string.
          with zipfile.ZipFile(zip_path) as archive:
              archive.extractall(unzip_dir)
          # NOTE(review): user_spec is never passed to the search below; loading it
          # presumably only checks that the archive's spec file is readable.
          user_spec = specification.rkme.RKMEStatSpecification()
          user_spec.load(os.path.join(unzip_dir, "svm.json"))
          user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic)
          _, single_learnware_list, _ = easy_market.search_learnware(user_info)
          print("User info:", user_info.get_semantic_spec())
          print(f"search result of user{idx}:")
          for learnware in single_learnware_list:
              print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())
      finally:
          shutil.rmtree(test_folder, ignore_errors=True)
  def test_stat_search():
      """Search the market once per archive, using its RKME spec as the user's statistics.

      Each archive from ``get_zip_path_list`` is unpacked into ``./test_stat/<idx>``.
      Its ``svm.json`` specification is combined with ``user_senmantic`` into a
      ``BaseUserInfo``, and the scored single results and the mixture result are
      printed. The temporary folder is removed even if a search raises.
      """
      import shutil
      import zipfile

      easy_market = EasyMarket()
      print("Total Item:", len(easy_market))
      test_folder = "./test_stat"
      zip_path_list = get_zip_path_list()
      try:
          for idx, zip_path in enumerate(zip_path_list):
              unzip_dir = os.path.join(test_folder, f"{idx}")
              os.makedirs(unzip_dir, exist_ok=True)
              # In-process extraction: no shell, so unusual paths work and failures raise.
              with zipfile.ZipFile(zip_path) as archive:
                  archive.extractall(unzip_dir)
              user_spec = specification.rkme.RKMEStatSpecification()
              user_spec.load(os.path.join(unzip_dir, "svm.json"))
              user_info = BaseUserInfo(
                  id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
              )
              sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
              print(f"search result of user{idx}:")
              for score, learnware in zip(sorted_score_list, single_learnware_list):
                  print(f"score: {score}, learnware_id: {learnware.id}")
              mixture_id = " ".join(mixed.id for mixed in mixture_learnware_list)
              print(f"mixture_learnware: {mixture_id}\n")
      finally:
          shutil.rmtree(test_folder, ignore_errors=True)
  if __name__ == "__main__":
      # Build a 10-archive pool, then run the market, statistical-search and
      # semantic-search scenarios in that order.
      learnware_num = 10
      prepare_learnware(learnware_num)
      for scenario in (test_market, test_stat_search, test_search_semantics):
          scenario()

基于学件范式，全流程地支持学件上传、检测、组织、查搜、部署和复用等功能。同时，该仓库作为北冥坞系统的引擎，支撑北冥坞系统的核心功能。