You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. import torch
  3. from get_data import *
  4. import os
  5. import random
  6. from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
  7. from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser
  8. import time
  9. from learnware.market import EasyMarket, BaseUserInfo
  10. from learnware.market import database_ops
  11. from learnware.learnware import Learnware
  12. import learnware.specification as specification
  13. from learnware.logger import get_module_logger
  14. from shutil import copyfile, rmtree
  15. import zipfile
  16. logger = get_module_logger("image_test", level="INFO")
  17. origin_data_root = "./data/origin_data"
  18. processed_data_root = "./data/processed_data"
  19. tmp_dir = "./data/tmp"
  20. learnware_pool_dir = "./data/learnware_pool"
  21. dataset = "cifar10"
  22. n_uploaders = 50
  23. n_users = 20
  24. n_classes = 10
  25. data_root = os.path.join(origin_data_root, dataset)
  26. data_save_root = os.path.join(processed_data_root, dataset)
  27. user_save_root = os.path.join(data_save_root, "user")
  28. uploader_save_root = os.path.join(data_save_root, "uploader")
  29. model_save_root = os.path.join(data_save_root, "uploader_model")
  30. os.makedirs(data_root, exist_ok=True)
  31. os.makedirs(user_save_root, exist_ok=True)
  32. os.makedirs(uploader_save_root, exist_ok=True)
  33. os.makedirs(model_save_root, exist_ok=True)
  34. semantic_specs = [
  35. {
  36. "Data": {"Values": ["Tabular"], "Type": "Class"},
  37. "Task": {"Values": ["Classification"], "Type": "Class"},
  38. "Device": {"Values": ["GPU"], "Type": "Tag"},
  39. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  40. "Description": {"Values": "", "Type": "String"},
  41. "Name": {"Values": "learnware_1", "Type": "String"},
  42. }
  43. ]
  44. user_senmantic = {
  45. "Data": {"Values": ["Tabular"], "Type": "Class"},
  46. "Task": {"Values": ["Classification"], "Type": "Class"},
  47. "Device": {"Values": ["GPU"], "Type": "Tag"},
  48. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  49. "Description": {"Values": "", "Type": "String"},
  50. "Name": {"Values": "", "Type": "String"},
  51. }
  52. def prepare_data():
  53. if dataset == "cifar10":
  54. X_train, y_train, X_test, y_test = get_cifar10(data_root)
  55. elif dataset == "mnist":
  56. X_train, y_train, X_test, y_test = get_mnist(data_root)
  57. else:
  58. return
  59. generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
  60. generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
  61. def prepare_model():
  62. dataloader = ImageDataLoader(data_save_root, train=True)
  63. for i in range(n_uploaders):
  64. logger.info("Train on uploader: %d" % (i))
  65. X, y = dataloader.get_idx_data(i)
  66. model = train(X, y, out_classes=n_classes)
  67. model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
  68. torch.save(model.state_dict(), model_save_path)
  69. logger.info("Model saved to '%s'" % (model_save_path))
  70. def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
  71. os.makedirs(save_root, exist_ok=True)
  72. tmp_spec_path = os.path.join(save_root, "rkme.json")
  73. tmp_model_path = os.path.join(save_root, "conv_model.pth")
  74. tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
  75. tmp_init_path = os.path.join(save_root, "__init__.py")
  76. tmp_model_file_path = os.path.join(save_root, "model.py")
  77. mmodel_file_path = "./example_files/model.py"
  78. X = np.load(data_path)
  79. st = time.time()
  80. user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0)
  81. ed = time.time()
  82. logger.info("Stat spec generated in %.3f s" % (ed - st))
  83. user_spec.save(tmp_spec_path)
  84. copyfile(model_path, tmp_model_path)
  85. copyfile(yaml_path, tmp_yaml_path)
  86. copyfile(init_file_path, tmp_init_path)
  87. copyfile(mmodel_file_path, tmp_model_file_path)
  88. zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
  89. with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
  90. zip_obj.write(tmp_spec_path, "rkme.json")
  91. zip_obj.write(tmp_model_path, "conv_model.pth")
  92. zip_obj.write(tmp_yaml_path, "learnware.yaml")
  93. zip_obj.write(tmp_init_path, "__init__.py")
  94. zip_obj.write(tmp_model_file_path, "model.py")
  95. rmtree(save_root)
  96. logger.info("New Learnware Saved to %s" % (zip_file_name))
  97. return zip_file_name
  98. def prepare_market():
  99. image_market = EasyMarket(rebuild=True)
  100. rmtree(learnware_pool_dir)
  101. os.makedirs(learnware_pool_dir, exist_ok=True)
  102. for i in range(n_uploaders):
  103. data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
  104. model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
  105. init_file_path = "./example_files/example_init.py"
  106. yaml_file_path = "./example_files/example_yaml.yaml"
  107. new_learnware_path = prepare_learnware(
  108. data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
  109. )
  110. semantic_spec = semantic_specs[0]
  111. semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
  112. semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
  113. image_market.add_learnware(new_learnware_path, semantic_spec)
  114. logger.info("Total Item: %d" % (len(image_market)))
  115. curr_inds = image_market._get_ids()
  116. logger.info("Available ids: " + str(curr_inds))
  117. def test_search(gamma=0.1, load_market=True):
  118. if load_market:
  119. image_market = EasyMarket(market_id="image")
  120. else:
  121. prepare_market()
  122. image_market = EasyMarket(market_id="image")
  123. logger.info("Number of items in the market: %d" % len(image_market))
  124. select_list = []
  125. avg_list = []
  126. improve_list = []
  127. job_selector_score_list = []
  128. ensemble_score_list = []
  129. for i in range(n_users):
  130. user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
  131. user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
  132. user_data = np.load(user_data_path)
  133. user_label = np.load(user_label_path)
  134. user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0)
  135. user_info = BaseUserInfo(
  136. id=f"user_{i}", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_stat_spec}
  137. )
  138. logger.info("Searching Market for user: %d" % (i))
  139. sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware(
  140. user_info
  141. )
  142. l = len(sorted_score_list)
  143. acc_list = []
  144. for idx in range(l):
  145. learnware = single_learnware_list[idx]
  146. score = sorted_score_list[idx]
  147. pred_y = learnware.predict(user_data)
  148. acc = eval_prediction(pred_y, user_label)
  149. acc_list.append(acc)
  150. logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))
  151. # test reuse
  152. reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list)
  153. reuse_predict = reuse_baseline.predict(user_data=user_data)
  154. reuse_score = eval_prediction(reuse_predict, user_label)
  155. job_selector_score_list.append(reuse_score)
  156. print(f"mixture reuse loss: {reuse_score}\n")
  157. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote")
  158. ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
  159. ensemble_score = eval_prediction(ensemble_predict_y, user_label)
  160. ensemble_score_list.append(ensemble_score)
  161. print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")
  162. select_list.append(acc_list[0])
  163. avg_list.append(np.mean(acc_list))
  164. improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))
  165. logger.info(
  166. "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
  167. % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
  168. )
  169. logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
  170. logger.info(
  171. "Average Job Selector Reuse Performance: %.3f +/- %.3f"
  172. % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
  173. )
  174. logger.info(
  175. "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  176. )
  177. if __name__ == "__main__":
  178. # prepare_data()
  179. # prepare_model()
  180. test_search(load_market=False)

基于学件范式,全流程地支持学件上传、检测、组织、查搜、部署和复用等功能。同时,该仓库作为北冥坞系统的引擎,支撑北冥坞系统的核心功能。