
main.py

import numpy as np
import torch
import get_data
import os
import random
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
import time
from learnware.market import EasyMarket, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware
import learnware.specification as specification
from learnware.logger import get_module_logger
from shutil import copyfile, rmtree
import zipfile

logger = get_module_logger("image_test", level="INFO")

origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "cifar10"
n_uploaders = 50
n_users = 10
n_classes = 10

data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)
semantic_specs = [
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {
            "Values": ["Classification"],
            "Type": "Class",
        },
        "Device": {"Values": ["GPU"], "Type": "Tag"},
        "Scenario": {"Values": ["Nature"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
    },
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {
            "Values": ["Classification"],
            "Type": "Class",
        },
        "Device": {"Values": ["GPU"], "Type": "Tag"},
        "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_2", "Type": "String"},
    },
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {
            "Values": ["Classification"],
            "Type": "Class",
        },
        "Device": {"Values": ["GPU"], "Type": "Tag"},
        "Scenario": {"Values": ["Business"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_3", "Type": "String"},
    },
]

user_semantic = {
    "Data": {"Values": ["Tabular"], "Type": "Class"},
    "Task": {
        "Values": ["Classification"],
        "Type": "Class",
    },
    "Device": {"Values": ["GPU"], "Type": "Tag"},
    "Scenario": {"Values": ["Business"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}


def prepare_data():
    # Split the raw dataset: training data goes to simulated uploaders, test data to simulated users.
    if dataset == "cifar10":
        X_train, y_train, X_test, y_test = get_data.get_cifar10(data_root)
    elif dataset == "mnist":
        X_train, y_train, X_test, y_test = get_data.get_mnist(data_root)
    else:
        return
    generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
    generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)


def prepare_model():
    # Train one small model per uploader and save its weights.
    dataloader = ImageDataLoader(data_save_root, train=True)
    for i in range(n_uploaders):
        logger.info("Train on uploader: %d" % (i))
        X, y = dataloader.get_idx_data(i)
        model = train(X, y, out_classes=n_classes)
        model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        torch.save(model.state_dict(), model_save_path)
        logger.info("Model saved to '%s'" % (model_save_path))


def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
    # Package one uploader's RKME statistical specification, model weights,
    # learnware.yaml, and __init__.py into a learnware zip file.
    os.makedirs(save_root, exist_ok=True)
    tmp_spec_path = os.path.join(save_root, "rkme.json")
    tmp_model_path = os.path.join(save_root, "conv_model.pth")
    tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
    tmp_init_path = os.path.join(save_root, "__init__.py")
    X = np.load(data_path)
    st = time.time()
    user_spec = specification.utils.generate_rkme_spec(X=X, gamma=0.1, cuda_idx=0)
    ed = time.time()
    logger.info("Stat spec generated in %.3f s" % (ed - st))
    user_spec.save(tmp_spec_path)
    copyfile(model_path, tmp_model_path)
    copyfile(yaml_path, tmp_yaml_path)
    copyfile(init_file_path, tmp_init_path)
    zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
    with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
        zip_obj.write(tmp_spec_path, "rkme.json")
        zip_obj.write(tmp_model_path, "conv_model.pth")
        zip_obj.write(tmp_yaml_path, "learnware.yaml")
        zip_obj.write(tmp_init_path, "__init__.py")
    rmtree(save_root)
    logger.info("New Learnware Saved to %s" % (zip_file_name))
    return zip_file_name


def prepare_market():
    # Rebuild the market: package every uploader's model as a learnware and add it.
    image_market = EasyMarket(rebuild=True)
    rmtree(learnware_pool_dir)
    os.makedirs(learnware_pool_dir, exist_ok=True)
    for i in range(n_uploaders):
        data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
        model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        init_file_path = "./example_init.py"
        yaml_file_path = "./example_yaml.yaml"
        new_learnware_path = prepare_learnware(
            data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
        )
        semantic_spec = semantic_specs[i % 3]
        semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
        semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
        image_market.add_learnware(new_learnware_path, semantic_spec)
    logger.info("Total Item: %d" % len(image_market))
    curr_inds = image_market._get_ids()
    logger.info("Available ids: %s" % (curr_inds,))


def test_search(load_market=True):
    # Load an existing market (or rebuild it first), then search it with each user's
    # semantic and RKME statistical specification and evaluate the returned learnwares.
    if load_market:
        image_market = EasyMarket()
    else:
        prepare_market()
        image_market = EasyMarket()
    logger.info("Number of items in the market: %d" % len(image_market))
    for i in range(n_users):
        user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
        user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
        user_data = np.load(user_data_path)
        user_label = np.load(user_label_path)
        user_stat_spec = specification.utils.generate_rkme_spec(X=user_data, gamma=0.1, cuda_idx=0)
        user_info = BaseUserInfo(
            id=f"user_{i}", semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}
        )
        logger.info("Searching Market for user: %d" % (i))
        sorted_score_list, single_learnware_list, mixture_learnware_list = image_market.search_learnware(user_info)
        n_matched = len(sorted_score_list)
        for idx in range(min(n_matched, 10)):
            learnware = single_learnware_list[idx]
            score = sorted_score_list[idx]
            pred_y = learnware.predict(user_data)
            acc = eval_prediction(pred_y, user_label)
            logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))


if __name__ == "__main__":
    # prepare_data()
    # prepare_model()
    test_search(False)
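
A sketch of the intended end-to-end run of this example: on a first run, the two calls that are commented out in the `__main__` block above need to be enabled so that the data splits and uploader models exist before the market is built and searched.

    # download the dataset and split it among simulated uploaders and users
    prepare_data()
    # train one model per uploader
    prepare_model()
    # rebuild the market from the uploader models, then run each user's search
    test_search(load_market=False)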

Built on the learnware paradigm, this repository supports the full workflow of learnware upload, inspection, organization, search, deployment, and reuse. It also serves as the engine of the Beimingwu system, powering the system's core functionality.