You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 9.5 kB

2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import numpy as np
  2. import torch
  3. from get_data import get_sst2
  4. import os
  5. import random
  6. from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction
  7. from learnware.learnware import Learnware
  8. from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
  9. import time
  10. import pickle
  11. from learnware.market import instantiate_learnware_market, BaseUserInfo
  12. from learnware.specification import RKMETextSpecification
  13. from learnware.logger import get_module_logger
  14. from shutil import copyfile, rmtree
  15. import zipfile
  # Logger shared by every stage of this script.
  logger = get_module_logger("text_test", level="INFO")

  # Base locations for raw data, processed data, staging files and learnware zips.
  origin_data_root = "./data/origin_data"
  processed_data_root = "./data/processed_data"
  tmp_dir = "./data/tmp"
  learnware_pool_dir = "./data/learnware_pool"

  # Experiment settings.
  dataset = "sst2"
  n_uploaders = 10
  n_users = 5
  n_classes = 2

  # Dataset-specific directories derived from the base locations above.
  data_root = os.path.join(origin_data_root, dataset)
  data_save_root = os.path.join(processed_data_root, dataset)
  user_save_root = os.path.join(data_save_root, "user")
  uploader_save_root = os.path.join(data_save_root, "uploader")
  model_save_root = os.path.join(data_save_root, "uploader_model")

  # Create the working directories up front (same order as they are declared).
  for _required_dir in (data_root, user_save_root, uploader_save_root, model_save_root):
      os.makedirs(_required_dir, exist_ok=True)
  del _required_dir

  # Description of the two-dimensional model output, shared by all semantic specs.
  output_description = {
      "Dimension": 2,
      "Description": {
          "0": "the probability of being negative",
          "1": "the probability of being positive",
      },
  }


  def _build_semantic_spec(name):
      """Return a fresh semantic specification dict whose "Name" value is *name*.

      Every call creates new nested dicts, so mutating one specification does
      not affect another; only ``output_description`` is shared between them.
      """
      return {
          "Data": {"Values": ["Text"], "Type": "Class"},
          "Task": {"Values": ["Classification"], "Type": "Class"},
          "Library": {"Values": ["PyTorch"], "Type": "Class"},
          "Scenario": {"Values": ["Business"], "Type": "Tag"},
          "Description": {"Values": "", "Type": "String"},
          "Name": {"Values": name, "Type": "String"},
          "Output": output_description,
      }


  # Template specification used when uploading learnwares.
  semantic_specs = [_build_semantic_spec("learnware_1")]

  # Specification describing what a user is searching for.
  user_semantic = _build_semantic_spec("")
  def prepare_data():
      """Load the configured dataset and split it among uploaders and users.

      Only ``dataset == "sst2"`` is supported. The train split returned by
      ``get_sst2(data_root)`` is passed to ``generate_uploader`` (with
      ``uploader_save_root``) and the test split to ``generate_user`` (with
      ``user_save_root``).

      For any other dataset name nothing is prepared; a warning is logged so
      that later stages failing on missing files can be traced back to this.
      """
      if dataset != "sst2":
          logger.warning("Unsupported dataset '%s'; no data was prepared" % (dataset))
          return

      X_train, y_train, X_test, y_test = get_sst2(data_root)
      generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
      generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
  def prepare_model():
      """Train one model per uploader and store its state dict.

      Data for uploader ``k`` comes from ``TextDataLoader.get_idx_data(k)``; the
      trained model's ``state_dict()`` is saved with ``torch.save`` to
      ``<model_save_root>/uploader_<k>.pth``.
      """
      loader = TextDataLoader(data_save_root, train=True)

      for uploader_idx in range(n_uploaders):
          logger.info("Train on uploader: %d" % uploader_idx)

          features, labels = loader.get_idx_data(uploader_idx)
          trained_model = train(features, labels, out_classes=n_classes)

          checkpoint_name = "uploader_%d.pth" % uploader_idx
          checkpoint_path = os.path.join(model_save_root, checkpoint_name)
          torch.save(trained_model.state_dict(), checkpoint_path)

          logger.info("Model saved to '%s'" % checkpoint_path)
  def prepare_learnware(data_path, model_path, init_file_path, yaml_path, env_file_path, save_root, zip_name):
      """Package one uploader's model and statistical spec into a learnware zip.

      The pickled data at ``data_path`` is used to generate an
      ``RKMETextSpecification``, which is saved as ``rkme.json``. That file and
      copies of the model, yaml, init and requirements files are staged in
      ``save_root`` and then zipped into ``<learnware_pool_dir>/<zip_name>.zip``.

      Args:
          data_path: Pickle file holding the data the specification is built from.
          model_path: Model file, stored in the zip as ``model.pth``.
          init_file_path: Python file, stored in the zip as ``__init__.py``.
          yaml_path: YAML file, stored in the zip as ``learnware.yaml``.
          env_file_path: Requirements file, stored as ``requirements.txt``.
          save_root: Staging directory; created if missing and always removed
              before this function returns or raises.
          zip_name: Base name (without extension) of the resulting zip file.

      Returns:
          Path of the zip file that was written.
      """
      os.makedirs(save_root, exist_ok=True)
      try:
          # NOTE(review): pickle.load can execute arbitrary code; this is only
          # acceptable for locally generated files -- confirm data_path is trusted.
          with open(data_path, "rb") as f:
              X = pickle.load(f)

          st = time.time()
          stat_spec = RKMETextSpecification()
          stat_spec.generate_stat_spec_from_data(X=X)
          ed = time.time()
          logger.info("Stat spec generated in %.3f s" % (ed - st))

          # Archive member name -> staged file path, in the order they are zipped.
          spec_path = os.path.join(save_root, "rkme.json")
          stat_spec.save(spec_path)
          staged_files = {"rkme.json": spec_path}

          sources = (
              ("model.pth", model_path),
              ("learnware.yaml", yaml_path),
              ("__init__.py", init_file_path),
              ("requirements.txt", env_file_path),
          )
          for member_name, source_path in sources:
              staged_path = os.path.join(save_root, member_name)
              copyfile(source_path, staged_path)
              staged_files[member_name] = staged_path

          zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
          with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
              for member_name, staged_path in staged_files.items():
                  zip_obj.write(staged_path, member_name)
      finally:
          # Remove the staging directory even when a step above fails, so a
          # failed run does not leave stale files behind for the next one.
          rmtree(save_root)

      logger.info("New Learnware Saved to %s" % (zip_file_name))
      return zip_file_name
  def prepare_market():
      """Rebuild the "sst2" market and add one learnware per uploader.

      The learnware pool directory is emptied first. For every uploader a zip
      is produced with ``prepare_learnware`` and added to the market with a
      semantic specification whose name and description carry the uploader
      index. The final market size is logged.
      """
      text_market = instantiate_learnware_market(market_id="sst2", rebuild=True)

      # Best-effort removal of a previous pool; a missing directory is fine.
      rmtree(learnware_pool_dir, ignore_errors=True)
      os.makedirs(learnware_pool_dir, exist_ok=True)

      # These example files are the same for every learnware.
      init_file_path = "./example_files/example_init.py"
      yaml_file_path = "./example_files/example_yaml.yaml"
      env_file_path = "./example_files/requirements.txt"

      template = semantic_specs[0]
      for i in range(n_uploaders):
          data_path = os.path.join(uploader_save_root, "uploader_%d_X.pkl" % (i))
          model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))

          new_learnware_path = prepare_learnware(
              data_path, model_path, init_file_path, yaml_file_path, env_file_path, tmp_dir, "%s_%d" % (dataset, i)
          )

          # Build a fresh spec per learnware instead of mutating the shared
          # template, so learnwares cannot end up sharing one mutable dict.
          # NOTE(review): whether add_learnware copies its argument is not
          # visible here -- this avoids depending on it.
          semantic_spec = {
              **template,
              "Name": {**template["Name"], "Values": "learnware_%d" % (i)},
              "Description": {**template["Description"], "Values": "test_learnware_number_%d" % (i)},
          }
          text_market.add_learnware(new_learnware_path, semantic_spec)

      logger.info("Total Item: %d" % (len(text_market)))
  def test_search(gamma=0.1, load_market=True):
      """Search the "sst2" market for every user and report search/reuse scores.

      For each of the ``n_users`` users this loads the pickled data and labels,
      builds an ``RKMETextSpecification`` from the data and searches the market.
      Every returned single learnware is scored with ``eval_prediction``, and
      three reusers built on the returned mixture list are scored the same way
      (job selector, averaging ensemble, ensemble pruning). Aggregate mean and
      standard deviation are logged at the end.

      Args:
          gamma: Not used by this function; kept for backward compatibility.
          load_market: When True the existing market is loaded as is; when
              False ``prepare_market()`` is called first to rebuild it.
      """
      if not load_market:
          prepare_market()
      text_market = instantiate_learnware_market(market_id="sst2")
      logger.info("Number of items in the market: %d" % len(text_market))

      select_list = []
      avg_list = []
      improve_list = []
      job_selector_score_list = []
      ensemble_score_list = []
      pruning_score_list = []

      for i in range(n_users):
          user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i))
          user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i))
          # NOTE(review): pickle.load can execute arbitrary code; only acceptable
          # for locally generated files -- confirm these paths are trusted.
          with open(user_data_path, "rb") as f:
              user_data = pickle.load(f)
          with open(user_label_path, "rb") as f:
              user_label = pickle.load(f)

          user_stat_spec = RKMETextSpecification()
          user_stat_spec.generate_stat_spec_from_data(X=user_data)
          user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec})

          logger.info("Searching Market for user: %d" % (i))
          sorted_score_list, single_learnware_list, _mixture_score, mixture_learnware_list = (
              text_market.search_learnware(user_info)
          )

          # Without any search result there is nothing to score; skip this user
          # instead of failing on acc_list[0] below.
          if len(sorted_score_list) == 0:
              logger.warning("No learnware found for user: %d" % (i))
              continue

          acc_list = []
          for idx, (score, learnware) in enumerate(zip(sorted_score_list, single_learnware_list)):
              pred_y = learnware.predict(user_data)
              acc = eval_prediction(pred_y, user_label)
              acc_list.append(acc)
              logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))

          # test reuse (job selector)
          reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
          reuse_predict = reuse_baseline.predict(user_data=user_data)
          reuse_score = eval_prediction(reuse_predict, user_label)
          job_selector_score_list.append(reuse_score)
          # The score comes from eval_prediction like the accuracies below, so it
          # is labelled as accuracy (the previous label said "loss").
          print(f"mixture reuse accuracy (job selector): {reuse_score}")

          # test reuse (ensemble)
          reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
          ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
          ensemble_score = eval_prediction(ensemble_predict_y, user_label)
          ensemble_score_list.append(ensemble_score)
          print(f"mixture reuse accuracy (ensemble): {ensemble_score}")

          # test reuse (ensemblePruning)
          reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list)
          pruning_predict_y = reuse_pruning.predict(user_data=user_data)
          pruning_score = eval_prediction(pruning_predict_y, user_label)
          pruning_score_list.append(pruning_score)
          print(f"mixture reuse accuracy (ensemble Pruning): {pruning_score}\n")

          # The top-ranked learnware is the "selected" one; compare it with the
          # mean over all returned learnwares.
          mean_acc = np.mean(acc_list)
          select_list.append(acc_list[0])
          avg_list.append(mean_acc)
          improve_list.append((acc_list[0] - mean_acc) / mean_acc)

      logger.info(
          "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
          % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
      )
      logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
      logger.info(
          "Average Job Selector Reuse Performance: %.3f +/- %.3f"
          % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
      )
      logger.info(
          "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
          % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
      )
      logger.info(
          "Selective Ensemble Reuse Performance: %.3f +/- %.3f"
          % (np.mean(pruning_score_list), np.std(pruning_score_list))
      )
  def main():
      """Run the whole workflow: prepare data, train models, then search."""
      prepare_data()
      prepare_model()
      # load_market=False makes test_search rebuild the market before searching.
      test_search(load_market=False)


  if __name__ == "__main__":
      main()