You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

workflow.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import os
  2. import fire
  3. import time
  4. import random
  5. import pickle
  6. import tempfile
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. from sklearn.metrics import accuracy_score
  10. from sklearn.naive_bayes import MultinomialNB
  11. from sklearn.feature_extraction.text import TfidfVectorizer
  12. from learnware.client import LearnwareClient
  13. from learnware.logger import get_module_logger
  14. from learnware.specification import RKMETextSpecification
  15. from learnware.tests.benchmarks import LearnwareBenchmark
  16. from learnware.market import instantiate_learnware_market, BaseUserInfo
  17. from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
  18. from config import text_benchmark_config
  19. logger = get_module_logger("text_workflow", level="INFO")
  20. class TextDatasetWorkflow:
  21. @staticmethod
  22. def _train_model(X, y):
  23. vectorizer = TfidfVectorizer(stop_words="english")
  24. X_tfidf = vectorizer.fit_transform(X)
  25. clf = MultinomialNB(alpha=0.1)
  26. clf.fit(X_tfidf, y)
  27. return vectorizer, clf
  28. @staticmethod
  29. def _eval_prediction(pred_y, target_y):
  30. if not isinstance(pred_y, np.ndarray):
  31. pred_y = pred_y.detach().cpu().numpy()
  32. pred_y = np.array(pred_y) if len(pred_y.shape) == 1 else np.argmax(pred_y, 1)
  33. target_y = np.array(target_y)
  34. return accuracy_score(target_y, pred_y)
  35. def _plot_labeled_peformance_curves(self, all_user_curves_data):
  36. plt.figure(figsize=(10, 6))
  37. plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
  38. styles = [
  39. {"color": "navy", "linestyle": "-", "marker": "o"},
  40. {"color": "magenta", "linestyle": "-.", "marker": "d"},
  41. ]
  42. labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
  43. user_array, pruning_array = all_user_curves_data
  44. for array, style, label in zip([user_array, pruning_array], styles, labels):
  45. mean_curve = np.array([item[0] for item in array])
  46. std_curve = np.array([item[1] for item in array])
  47. plt.plot(mean_curve, **style, label=label)
  48. plt.fill_between(
  49. range(len(mean_curve)),
  50. mean_curve - std_curve,
  51. mean_curve + std_curve,
  52. color=style["color"],
  53. alpha=0.2,
  54. )
  55. plt.xlabel("Amout of Labeled User Data", fontsize=14)
  56. plt.ylabel("1 - Accuracy", fontsize=14)
  57. plt.title(f"Results on Text Experimental Scenario", fontsize=16)
  58. plt.legend(fontsize=14)
  59. plt.tight_layout()
  60. plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.svg"), bbox_inches="tight", dpi=700)
  61. def _prepare_market(self, rebuild=False):
  62. client = LearnwareClient()
  63. self.text_benchmark = LearnwareBenchmark().get_benchmark(text_benchmark_config)
  64. self.text_market = instantiate_learnware_market(market_id=self.text_benchmark.name, rebuild=rebuild)
  65. self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0])
  66. self.user_semantic["Name"]["Values"] = ""
  67. if len(self.text_market) == 0 or rebuild == True:
  68. for learnware_id in self.text_benchmark.learnware_ids:
  69. with tempfile.TemporaryDirectory(prefix="text_benchmark_") as tempdir:
  70. zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
  71. for i in range(20):
  72. try:
  73. semantic_spec = client.get_semantic_specification(learnware_id)
  74. client.download_learnware(learnware_id, zip_path)
  75. self.text_market.add_learnware(zip_path, semantic_spec)
  76. break
  77. except:
  78. time.sleep(1)
  79. continue
  80. logger.info("Total Item: %d" % (len(self.text_market)))
  81. def unlabeled_text_example(self, rebuild=False):
  82. self._prepare_market(rebuild)
  83. select_list = []
  84. avg_list = []
  85. best_list = []
  86. improve_list = []
  87. job_selector_score_list = []
  88. ensemble_score_list = []
  89. all_learnwares = self.text_market.get_learnwares()
  90. for i in range(text_benchmark_config.user_num):
  91. user_data, user_label = self.text_benchmark.get_test_data(user_ids=i)
  92. user_stat_spec = RKMETextSpecification()
  93. user_stat_spec.generate_stat_spec_from_data(X=user_data)
  94. user_info = BaseUserInfo(
  95. semantic_spec=self.user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}
  96. )
  97. logger.info("Searching Market for user: %d" % (i))
  98. search_result = self.text_market.search_learnware(user_info)
  99. single_result = search_result.get_single_results()
  100. multiple_result = search_result.get_multiple_results()
  101. print(f"search result of user{i}:")
  102. print(
  103. f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
  104. )
  105. acc_list = []
  106. for idx in range(len(all_learnwares)):
  107. learnware = all_learnwares[idx]
  108. pred_y = learnware.predict(user_data)
  109. acc = self._eval_prediction(pred_y, user_label)
  110. acc_list.append(acc)
  111. learnware = single_result[0].learnware
  112. pred_y = learnware.predict(user_data)
  113. best_acc = self._eval_prediction(pred_y, user_label)
  114. best_list.append(np.max(acc_list))
  115. select_list.append(best_acc)
  116. avg_list.append(np.mean(acc_list))
  117. improve_list.append((best_acc - np.mean(acc_list)) / np.mean(acc_list))
  118. print(f"market mean accuracy: {np.mean(acc_list)}, market best accuracy: {np.max(acc_list)}")
  119. print(
  120. f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {best_acc}"
  121. )
  122. if len(multiple_result) > 0:
  123. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  124. print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
  125. mixture_learnware_list = multiple_result[0].learnwares
  126. else:
  127. mixture_learnware_list = [single_result[0].learnware]
  128. # test reuse (job selector)
  129. reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
  130. reuse_predict = reuse_baseline.predict(user_data=user_data)
  131. reuse_score = self._eval_prediction(reuse_predict, user_label)
  132. job_selector_score_list.append(reuse_score)
  133. print(f"mixture reuse accuracy (job selector): {reuse_score}")
  134. # test reuse (ensemble)
  135. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
  136. ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
  137. ensemble_score = self._eval_prediction(ensemble_predict_y, user_label)
  138. ensemble_score_list.append(ensemble_score)
  139. print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")
  140. logger.info(
  141. "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
  142. % (
  143. np.mean(select_list),
  144. np.std(select_list),
  145. np.mean(avg_list),
  146. np.std(avg_list),
  147. np.mean(best_list),
  148. np.std(best_list),
  149. )
  150. )
  151. logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
  152. logger.info(
  153. "Average Job Selector Reuse Performance: %.3f +/- %.3f"
  154. % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
  155. )
  156. logger.info(
  157. "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
  158. % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  159. )
  160. def labeled_text_example(self, rebuild=False, skip_test=False):
  161. self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
  162. self.repeated_list = [10, 10, 10, 3, 3, 3]
  163. self.root_path = os.path.dirname(os.path.abspath(__file__))
  164. self.fig_path = os.path.join(self.root_path, "figs")
  165. self.curve_path = os.path.join(self.root_path, "curves")
  166. if not skip_test:
  167. self._prepare_market(rebuild)
  168. os.makedirs(self.fig_path, exist_ok=True)
  169. os.makedirs(self.curve_path, exist_ok=True)
  170. for i in range(text_benchmark_config.user_num):
  171. user_model_score_mat = []
  172. pruning_score_mat = []
  173. single_score_mat = []
  174. test_x, test_y = self.text_benchmark.get_test_data(user_ids=i)
  175. test_y = np.array(test_y)
  176. train_x, train_y = self.text_benchmark.get_train_data(user_ids=i)
  177. train_y = np.array(train_y)
  178. user_stat_spec = RKMETextSpecification()
  179. user_stat_spec.generate_stat_spec_from_data(X=test_x)
  180. user_info = BaseUserInfo(
  181. semantic_spec=self.user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}
  182. )
  183. logger.info(f"Searching Market for user_{i}")
  184. search_result = self.text_market.search_learnware(user_info)
  185. single_result = search_result.get_single_results()
  186. multiple_result = search_result.get_multiple_results()
  187. learnware = single_result[0].learnware
  188. pred_y = learnware.predict(test_x)
  189. best_acc = self._eval_prediction(pred_y, test_y)
  190. print(f"search result of user_{i}:")
  191. print(
  192. f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}, single model acc: {best_acc}"
  193. )
  194. if len(multiple_result) > 0:
  195. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  196. print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
  197. mixture_learnware_list = multiple_result[0].learnwares
  198. else:
  199. mixture_learnware_list = [single_result[0].learnware]
  200. for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
  201. user_model_score_list, reuse_pruning_score_list = [], []
  202. if n_label > len(train_x):
  203. n_label = len(train_x)
  204. for _ in range(repeated):
  205. x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
  206. x_train = list(x_train)
  207. y_train = np.array(list(y_train))
  208. modelv, modell = self._train_model(x_train, y_train)
  209. user_model_predict_y = modell.predict(modelv.transform(test_x))
  210. user_model_score = self._eval_prediction(user_model_predict_y, test_y)
  211. user_model_score_list.append(user_model_score)
  212. reuse_pruning = EnsemblePruningReuser(
  213. learnware_list=mixture_learnware_list, mode="classification"
  214. )
  215. reuse_pruning.fit(x_train, y_train)
  216. reuse_pruning_predict_y = reuse_pruning.predict(user_data=test_x)
  217. reuse_pruning_score = self._eval_prediction(reuse_pruning_predict_y, test_y)
  218. reuse_pruning_score_list.append(reuse_pruning_score)
  219. single_score_mat.append([best_acc] * repeated)
  220. user_model_score_mat.append(user_model_score_list)
  221. pruning_score_mat.append(reuse_pruning_score_list)
  222. print(
  223. f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
  224. )
  225. logger.info(f"Saving Curves for User_{i}")
  226. user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
  227. with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "wb") as f:
  228. pickle.dump(user_curves_data, f)
  229. pruning_curves_data, user_model_curves_data = [], []
  230. total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
  231. total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
  232. for user_idx in range(text_benchmark_config.user_num):
  233. with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
  234. user_curves_data = pickle.load(f)
  235. (single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
  236. for i in range(len(self.n_labeled_list)):
  237. total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i])
  238. total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i])
  239. for i in range(len(self.n_labeled_list)):
  240. total_user_model_score_mat[i] /= text_benchmark_config.user_num
  241. total_pruning_score_mat[i] /= text_benchmark_config.user_num
  242. user_model_curves_data.append(
  243. (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
  244. )
  245. pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))
  246. self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
  247. if __name__ == "__main__":
  248. fire.Fire(TextDatasetWorkflow)