You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

framework_hed.py 13 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. # ================================================================#
  12. import pickle as pk
  13. import torch
  14. import torch.nn as nn
  15. import numpy as np
  16. import os
  17. from utils.plog import INFO, DEBUG, clocker
  18. from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
  19. from models.nn import MLP, SymbolNetAutoencoder
  20. from models.basic_model import BasicModel, BasicDataset
  21. from datasets.hed.get_hed import get_pretrain_data
  22. def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
  23. result = {}
  24. if char_acc_flag:
  25. char_acc_num = 0
  26. char_num = 0
  27. for pred_z, z in zip(pred_Z, Z):
  28. char_num += len(z)
  29. for zidx in range(len(z)):
  30. if pred_z[zidx] == z[zidx]:
  31. char_acc_num += 1
  32. char_acc = char_acc_num / char_num
  33. result["Character level accuracy"] = char_acc
  34. abl_acc_num = 0
  35. for pred_z, y in zip(pred_Z, Y):
  36. if logic_forward(pred_z) == y:
  37. abl_acc_num += 1
  38. abl_acc = abl_acc_num / len(Y)
  39. result["ABL accuracy"] = abl_acc
  40. return result
  41. def filter_data(X, abduced_Z):
  42. finetune_Z = []
  43. finetune_X = []
  44. for abduced_x, abduced_z in zip(X, abduced_Z):
  45. if abduced_z is not []:
  46. finetune_X.append(abduced_x)
  47. finetune_Z.append(abduced_z)
  48. return finetune_X, finetune_Z
  49. def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbose=-1):
  50. train_X, train_Z, train_Y = train_data
  51. test_X, test_Z, test_Y = test_data
  52. # Set default parameters
  53. if sample_num == -1:
  54. sample_num = len(train_X)
  55. if verbose < 1:
  56. verbose = epochs
  57. char_acc_flag = 1
  58. if train_Z == None:
  59. char_acc_flag = 0
  60. train_Z = [None] * len(train_X)
  61. predict_func = clocker(model.predict)
  62. train_func = clocker(model.train)
  63. abduce_func = clocker(abducer.batch_abduce)
  64. for epoch_idx in range(epochs):
  65. X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx)
  66. preds_res = predict_func(X)
  67. # input()
  68. abduced_Z = abduce_func(preds_res, Y)
  69. if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
  70. res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
  71. INFO('epoch: ', epoch_idx + 1, ' ', res)
  72. finetune_X, finetune_Z = filter_data(X, abduced_Z)
  73. if len(finetune_X) > 0:
  74. # model.valid(finetune_X, finetune_Z)
  75. train_func(finetune_X, finetune_Z)
  76. else:
  77. INFO("lack of data, all abduced failed", len(finetune_X))
  78. return res
  79. def hed_pretrain(kb, cls, recorder):
  80. cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
  81. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  82. if not os.path.exists("./weights/pretrain_weights.pth"):
  83. INFO("Pretrain Start")
  84. pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11'])
  85. pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
  86. pretrain_data_loader = torch.utils.data.DataLoader(pretrain_data, batch_size=64, shuffle=True)
  87. criterion = nn.MSELoss()
  88. optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)
  89. pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
  90. pretrain_model.fit(pretrain_data_loader)
  91. torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
  92. cls.load_state_dict(cls_autoencoder.base_model.state_dict())
  93. else:
  94. cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
  95. def _get_char_acc(model, X, consistent_pred_res, mapping):
  96. original_pred_res = model.predict(X)['cls']
  97. pred_res = flatten(mapping_res(original_pred_res, mapping))
  98. INFO('Current model\'s output: ', pred_res)
  99. INFO('Abduced labels: ', flatten(consistent_pred_res))
  100. assert len(pred_res) == len(flatten(consistent_pred_res))
  101. return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res)
  102. def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
  103. select_idx = np.random.randint(len(train_X_true), size=select_num)
  104. X = []
  105. for idx in select_idx:
  106. X.append(train_X_true[idx])
  107. original_pred_res = model.predict(X)['cls']
  108. if mapping == None:
  109. mappings = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1])
  110. else:
  111. mappings = [mapping]
  112. consistent_idx = []
  113. consistent_pred_res = []
  114. for m in mappings:
  115. pred_res = mapping_res(original_pred_res, m)
  116. max_abduce_num = 20
  117. solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
  118. all_address_flag = reform_idx(solution, pred_res)
  119. consistent_idx_tmp = []
  120. consistent_pred_res_tmp = []
  121. for idx in range(len(pred_res)):
  122. address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
  123. candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
  124. if len(candidate) > 0:
  125. consistent_idx_tmp.append(idx)
  126. consistent_pred_res_tmp.append(candidate[0][0])
  127. if len(consistent_idx_tmp) > len(consistent_idx):
  128. consistent_idx = consistent_idx_tmp
  129. consistent_pred_res = consistent_pred_res_tmp
  130. if len(mappings) > 1:
  131. mapping = m
  132. if len(consistent_idx) == 0:
  133. return 0, 0, None
  134. if len(mappings) > 1:
  135. INFO('Final mapping is: ', mapping)
  136. INFO('Train pool size is:', len(flatten(consistent_pred_res)))
  137. INFO("Start to use abduced pseudo label to train model...")
  138. model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))
  139. consistent_acc = len(consistent_idx) / select_num
  140. char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
  141. INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
  142. return consistent_acc, char_acc, mapping
  143. def _remove_duplicate_rule(rule_dict):
  144. add_nums_dict = {}
  145. for r in list(rule_dict):
  146. add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
  147. if add_nums in add_nums_dict:
  148. old_r = add_nums_dict[add_nums]
  149. if rule_dict[r] >= rule_dict[old_r]:
  150. rule_dict.pop(old_r)
  151. add_nums_dict[add_nums] = r
  152. else:
  153. rule_dict.pop(r)
  154. else:
  155. add_nums_dict[add_nums] = r
  156. return list(rule_dict)
  157. def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num):
  158. rules = []
  159. for _ in range(samples_num):
  160. while True:
  161. select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
  162. X = []
  163. for idx in select_idx:
  164. X.append(train_X_true[idx])
  165. original_pred_res = model.predict(X)['cls']
  166. pred_res = mapping_res(original_pred_res, mapping)
  167. consistent_idx = []
  168. consistent_pred_res = []
  169. for idx in range(len(pred_res)):
  170. if abducer.kb.logic_forward([pred_res[idx]]):
  171. consistent_idx.append(idx)
  172. consistent_pred_res.append(pred_res[idx])
  173. if len(consistent_pred_res) != 0:
  174. rule = abducer.abduce_rules(consistent_pred_res)
  175. if rule != None:
  176. break
  177. rules.append(rule)
  178. all_rule_dict = {}
  179. for rule in rules:
  180. for r in rule:
  181. all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
  182. rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
  183. rules = _remove_duplicate_rule(rule_dict)
  184. return rules
  185. def _get_consist_rule_acc(model, abducer, mapping, rules, X):
  186. cnt = 0
  187. for x in X:
  188. original_pred_res = model.predict([x])['cls']
  189. pred_res = flatten(mapping_res(original_pred_res, mapping))
  190. if abducer.kb.consist_rule(pred_res, rules):
  191. cnt += 1
  192. return cnt / len(X)
  193. def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8):
  194. train_X = train_data
  195. val_X = val_data
  196. samples_num = 50
  197. samples_per_rule = 3
  198. # Start training / for each length of equations
  199. for equation_len in range(min_len, max_len):
  200. INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1))
  201. train_X_true = train_X[1][equation_len]
  202. train_X_false = train_X[0][equation_len]
  203. val_X_true = val_X[1][equation_len]
  204. val_X_false = val_X[0][equation_len]
  205. train_X_true.extend(train_X[1][equation_len + 1])
  206. train_X_false.extend(train_X[0][equation_len + 1])
  207. val_X_true.extend(val_X[1][equation_len + 1])
  208. val_X_false.extend(val_X[0][equation_len + 1])
  209. condition_cnt = 0
  210. while True:
  211. if equation_len == min_len:
  212. mapping = None
  213. # Abduce and train NN
  214. consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, mapping, train_X_true, select_num)
  215. if consistent_acc == 0:
  216. continue
  217. # Test if we can use mlp to evaluate
  218. if consistent_acc >= 0.9 and char_acc >= 0.9:
  219. condition_cnt += 1
  220. else:
  221. condition_cnt = 0
  222. # The condition has been satisfied continuously five times
  223. if condition_cnt >= 5:
  224. INFO("Now checking if we can go to next course")
  225. rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num)
  226. INFO('Learned rules from data:', rules)
  227. true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true)
  228. false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false)
  229. INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc))
  230. # decide next course or restart
  231. if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1:
  232. torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
  233. break
  234. else:
  235. if equation_len == min_len:
  236. model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
  237. else:
  238. model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1)))
  239. condition_cnt = 0
  240. INFO('Reload Model and retrain')
  241. return model, mapping
  242. def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
  243. train_X = train_data
  244. test_X = test_data
  245. # Calcualte how many equations should be selected in each length
  246. # for each length, there are equation_samples_num[equation_len] rules
  247. print("Now begin to train final mlp model")
  248. equation_samples_num = []
  249. len_cnt = max_len - min_len + 1
  250. samples_num = 50
  251. equation_samples_num += [0] * min_len
  252. if samples_num % len_cnt == 0:
  253. equation_samples_num += [samples_num // len_cnt] * len_cnt
  254. else:
  255. equation_samples_num += [samples_num // len_cnt] * len_cnt
  256. equation_samples_num[-1] += samples_num % len_cnt
  257. assert sum(equation_samples_num) == samples_num
  258. # Abduce rules
  259. rules = []
  260. samples_per_rule = 3
  261. for equation_len in range(min_len, max_len + 1):
  262. equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len])
  263. rules.extend(equation_rules)
  264. rules = list(set(rules))
  265. INFO('Learned rules from data:', rules)
  266. for equation_len in range(5, 27):
  267. true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len])
  268. false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len])
  269. INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc))
  270. if __name__ == "__main__":
  271. pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.