You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import numpy as np
  2. import torch
  3. from get_data import *
  4. import os
  5. import random
  6. from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
  7. from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser
  8. import time
  9. from learnware.market import EasyMarket, BaseUserInfo
  10. from learnware.market import database_ops
  11. from learnware.learnware import Learnware
  12. import learnware.specification as specification
  13. from learnware.logger import get_module_logger
  14. from shutil import copyfile, rmtree
  15. import zipfile
  16. logger = get_module_logger("image_test", level="INFO")
  17. origin_data_root = "./data/origin_data"
  18. processed_data_root = "./data/processed_data"
  19. tmp_dir = "./data/tmp"
  20. learnware_pool_dir = "./data/learnware_pool"
  21. dataset = "cifar10"
  22. n_uploaders = 50
  23. n_users = 20
  24. n_classes = 10
  25. data_root = os.path.join(origin_data_root, dataset)
  26. data_save_root = os.path.join(processed_data_root, dataset)
  27. user_save_root = os.path.join(data_save_root, "user")
  28. uploader_save_root = os.path.join(data_save_root, "uploader")
  29. model_save_root = os.path.join(data_save_root, "uploader_model")
  30. os.makedirs(data_root, exist_ok=True)
  31. os.makedirs(user_save_root, exist_ok=True)
  32. os.makedirs(uploader_save_root, exist_ok=True)
  33. os.makedirs(model_save_root, exist_ok=True)
  34. semantic_specs = [
  35. {
  36. "Data": {"Values": ["Tabular"], "Type": "Class"},
  37. "Task": {"Values": ["Classification"], "Type": "Class"},
  38. "Library": {"Values": ["Pytorch"], "Type": "Class"},
  39. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  40. "Description": {"Values": "", "Type": "String"},
  41. "Name": {"Values": "learnware_1", "Type": "String"},
  42. }
  43. ]
  44. user_semantic = {
  45. "Data": {"Values": ["Tabular"], "Type": "Class"},
  46. "Task": {"Values": ["Classification"], "Type": "Class"},
  47. "Library": {"Values": ["Pytorch"], "Type": "Class"},
  48. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  49. "Description": {"Values": "", "Type": "String"},
  50. "Name": {"Values": "", "Type": "String"},
  51. }
  52. def prepare_data():
  53. if dataset == "cifar10":
  54. X_train, y_train, X_test, y_test = get_cifar10(data_root)
  55. elif dataset == "mnist":
  56. X_train, y_train, X_test, y_test = get_mnist(data_root)
  57. else:
  58. return
  59. generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
  60. generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
  61. def prepare_model():
  62. dataloader = ImageDataLoader(data_save_root, train=True)
  63. for i in range(n_uploaders):
  64. logger.info("Train on uploader: %d" % (i))
  65. X, y = dataloader.get_idx_data(i)
  66. model = train(X, y, out_classes=n_classes)
  67. model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
  68. torch.save(model.state_dict(), model_save_path)
  69. logger.info("Model saved to '%s'" % (model_save_path))
  70. def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
  71. os.makedirs(save_root, exist_ok=True)
  72. tmp_spec_path = os.path.join(save_root, "rkme.json")
  73. tmp_model_path = os.path.join(save_root, "conv_model.pth")
  74. tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
  75. tmp_init_path = os.path.join(save_root, "__init__.py")
  76. tmp_model_file_path = os.path.join(save_root, "model.py")
  77. mmodel_file_path = "./example_files/model.py"
  78. X = np.load(data_path)
  79. st = time.time()
  80. user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0)
  81. ed = time.time()
  82. logger.info("Stat spec generated in %.3f s" % (ed - st))
  83. user_spec.save(tmp_spec_path)
  84. copyfile(model_path, tmp_model_path)
  85. copyfile(yaml_path, tmp_yaml_path)
  86. copyfile(init_file_path, tmp_init_path)
  87. copyfile(mmodel_file_path, tmp_model_file_path)
  88. zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
  89. with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
  90. zip_obj.write(tmp_spec_path, "rkme.json")
  91. zip_obj.write(tmp_model_path, "conv_model.pth")
  92. zip_obj.write(tmp_yaml_path, "learnware.yaml")
  93. zip_obj.write(tmp_init_path, "__init__.py")
  94. zip_obj.write(tmp_model_file_path, "model.py")
  95. rmtree(save_root)
  96. logger.info("New Learnware Saved to %s" % (zip_file_name))
  97. return zip_file_name
  98. def prepare_market():
  99. image_market = EasyMarket(market_id="cifar10", rebuild=True)
  100. try:
  101. rmtree(learnware_pool_dir)
  102. except:
  103. pass
  104. os.makedirs(learnware_pool_dir, exist_ok=True)
  105. for i in range(n_uploaders):
  106. data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
  107. model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
  108. init_file_path = "./example_files/example_init.py"
  109. yaml_file_path = "./example_files/example_yaml.yaml"
  110. new_learnware_path = prepare_learnware(
  111. data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
  112. )
  113. semantic_spec = semantic_specs[0]
  114. semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
  115. semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
  116. image_market.add_learnware(new_learnware_path, semantic_spec)
  117. logger.info("Total Item: %d" % (len(image_market)))
  118. curr_inds = image_market._get_ids()
  119. logger.info("Available ids: " + str(curr_inds))
  120. def test_search(gamma=0.1, load_market=True):
  121. if load_market:
  122. image_market = EasyMarket(market_id="cifar10")
  123. else:
  124. prepare_market()
  125. image_market = EasyMarket(market_id="cifar10")
  126. logger.info("Number of items in the market: %d" % len(image_market))
  127. select_list = []
  128. avg_list = []
  129. improve_list = []
  130. job_selector_score_list = []
  131. ensemble_score_list = []
  132. for i in range(n_users):
  133. user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
  134. user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
  135. user_data = np.load(user_data_path)
  136. user_label = np.load(user_label_path)
  137. user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=gamma, cuda_idx=0)
  138. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec})
  139. logger.info("Searching Market for user: %d" % (i))
  140. sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware(
  141. user_info
  142. )
  143. l = len(sorted_score_list)
  144. acc_list = []
  145. for idx in range(l):
  146. learnware = single_learnware_list[idx]
  147. score = sorted_score_list[idx]
  148. pred_y = learnware.predict(user_data)
  149. acc = eval_prediction(pred_y, user_label)
  150. acc_list.append(acc)
  151. logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))
  152. # test reuse (job selector)
  153. reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
  154. reuse_predict = reuse_baseline.predict(user_data=user_data)
  155. reuse_score = eval_prediction(reuse_predict, user_label)
  156. job_selector_score_list.append(reuse_score)
  157. print(f"mixture reuse loss: {reuse_score}")
  158. # test reuse (ensemble)
  159. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote")
  160. ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
  161. ensemble_score = eval_prediction(ensemble_predict_y, user_label)
  162. ensemble_score_list.append(ensemble_score)
  163. print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")
  164. select_list.append(acc_list[0])
  165. avg_list.append(np.mean(acc_list))
  166. improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))
  167. logger.info(
  168. "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
  169. % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
  170. )
  171. logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
  172. logger.info(
  173. "Average Job Selector Reuse Performance: %.3f +/- %.3f"
  174. % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
  175. )
  176. logger.info(
  177. "Ensemble Reuse Performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  178. )
  179. if __name__ == "__main__":
  180. prepare_data()
  181. prepare_model()
  182. test_search(load_market=False)

基于学件范式,全流程地支持学件上传、检测、组织、查搜、部署和复用等功能。同时,该仓库作为北冥坞系统的引擎,支撑北冥坞系统的核心功能。