
framework_hed.py 16 kB

# coding: utf-8
# ================================================================#
# Copyright (C) 2021 Freecss. All rights reserved.
#
# File Name    : framework_hed.py
# Author       : freecss
# Email        : karlfreecss@gmail.com
# Created Date : 2021/06/07
# Description  :
#
# ================================================================#
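# This module implements the Abductive Learning (ABL) training framework for
# the HED (handwritten equation decipherment) task: a symbol classifier is
# pretrained with an autoencoder, pseudo-labels are abduced under a candidate
# symbol mapping, the classifier is trained course by course on increasingly
# long equations, and each course is validated with an MLP over abduced rules.
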
import os
import pickle as pk

import numpy as np
import torch
import torch.nn as nn

from utils.plog import INFO, DEBUG, clocker
from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
from models.nn import MLP, SymbolNetAutoencoder
from models.basic_model import BasicModel, BasicDataset
from datasets.hed.get_hed import get_pretrain_data


def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
    """Compute character-level accuracy (when ground-truth symbols are given)
    and ABL accuracy (whether the predicted equation evaluates to the label)."""
    result = {}
    if char_acc_flag:
        char_acc_num = 0
        char_num = 0
        for pred_z, z in zip(pred_Z, Z):
            char_num += len(z)
            for zidx in range(len(z)):
                if pred_z[zidx] == z[zidx]:
                    char_acc_num += 1
        char_acc = char_acc_num / char_num
        result["Character level accuracy"] = char_acc

    abl_acc_num = 0
    for pred_z, y in zip(pred_Z, Y):
        if logic_forward(pred_z) == y:
            abl_acc_num += 1
    abl_acc = abl_acc_num / len(Y)
    result["ABL accuracy"] = abl_acc
    return result


def filter_data(X, abduced_Z):
    """Keep only the examples whose abduction produced a non-empty label sequence."""
    finetune_X = []
    finetune_Z = []
    for abduced_x, abduced_z in zip(X, abduced_Z):
        if abduced_z:
            finetune_X.append(abduced_x)
            finetune_Z.append(abduced_z)
    return finetune_X, finetune_Z


def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbose=-1):
    train_X, train_Z, train_Y = train_data
    test_X, test_Z, test_Y = test_data

    # Set default parameters
    if sample_num == -1:
        sample_num = len(train_X)
    if verbose < 1:
        verbose = epochs

    # Character-level accuracy can only be reported when ground-truth symbols exist
    char_acc_flag = 1
    if train_Z is None:
        char_acc_flag = 0
        train_Z = [None] * len(train_X)

    predict_func = clocker(model.predict)
    train_func = clocker(model.train)
    abduce_func = clocker(abducer.batch_abduce)

    for epoch_idx in range(epochs):
        X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx)
        preds_res = predict_func(X)
        abduced_Z = abduce_func(preds_res, Y)

        if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
            res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
            INFO('epoch: ', epoch_idx + 1, ' ', res)

        finetune_X, finetune_Z = filter_data(X, abduced_Z)
        if len(finetune_X) > 0:
            train_func(finetune_X, finetune_Z)
        else:
            INFO("Lack of data: all abductions failed", len(finetune_X))
    return res


def hed_pretrain(kb, cls, recorder):
    cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if not os.path.exists("./weights/pretrain_weights.pth"):
        INFO("Pretrain Start")
        pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11'])
        pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
        pretrain_data_loader = torch.utils.data.DataLoader(pretrain_data, batch_size=64, shuffle=True)

        criterion = nn.MSELoss()
        optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)
        pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1,
                                    save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
        pretrain_model.fit(pretrain_data_loader)

        torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
        cls.load_state_dict(cls_autoencoder.base_model.state_dict())
    else:
        cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))


def get_char_acc(model, X, consistent_pred_res, mapping):
    original_pred_res = model.predict(X)['cls']
    pred_res = flatten(mapping_res(original_pred_res, mapping))
    abduced_res = flatten(consistent_pred_res)
    INFO("Current model's output: ", pred_res)
    INFO('Abduced labels: ', abduced_res)
    assert len(pred_res) == len(abduced_res)
    return sum(p == a for p, a in zip(pred_res, abduced_res)) / len(pred_res)


def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
    """Sample equations, abduce consistent pseudo-labels (searching over symbol
    mappings if none is fixed yet), and train the model on them."""
    select_idx = np.random.randint(len(train_X_true), size=select_num)
    X = [train_X_true[idx] for idx in select_idx]
    original_pred_res = model.predict(X)['cls']

    if mapping is None:
        # No mapping from pseudo-labels to symbols has been fixed yet: try all of them
        mappings = gen_mappings(['+', '=', 0, 1], ['+', '=', 0, 1])
    else:
        mappings = [mapping]

    consistent_idx = []
    consistent_pred_res = []

    for m in mappings:
        pred_res = mapping_res(original_pred_res, m)

        max_abduce_num = 20
        solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
        all_address_flag = reform_idx(solution, pred_res)

        consistent_idx_tmp = []
        consistent_pred_res_tmp = []
        for idx in range(len(pred_res)):
            address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
            candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
            if len(candidate) > 0:
                consistent_idx_tmp.append(idx)
                consistent_pred_res_tmp.append(candidate[0][0])

        # Keep the mapping that makes the most examples consistent
        if len(consistent_idx_tmp) > len(consistent_idx):
            consistent_idx = consistent_idx_tmp
            consistent_pred_res = consistent_pred_res_tmp
            if len(mappings) > 1:
                mapping = m

    if len(consistent_idx) == 0:
        return 0, 0, None

    if len(mappings) > 1:
        INFO('Final mapping is: ', mapping)
    INFO('Train pool size is:', len(flatten(consistent_pred_res)))
    INFO("Start to use abduced pseudo-labels to train the model...")
    model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))

    consistent_acc = len(consistent_idx) / select_num
    char_acc = get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
    INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
    return consistent_acc, char_acc, mapping


def output_rules(rules):
    """Count how many times each abduced rule occurs."""
    all_rule_dict = {}
    for rule in rules:
        for r in rule:
            all_rule_dict[r] = all_rule_dict.get(r, 0) + 1
    rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items()}
    return rule_dict


def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim):
    """Abduce `logic_output_dim` rules, each from a small sample of consistent equations."""
    rules = []
    for _ in range(logic_output_dim):
        while True:
            select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
            X = [train_X_true[idx] for idx in select_idx]

            original_pred_res = model.predict(X)['cls']
            pred_res = mapping_res(original_pred_res, mapping)

            # Keep only predictions that the knowledge base accepts as valid equations
            consistent_pred_res = []
            for idx in range(len(pred_res)):
                if abducer.kb.logic_forward([pred_res[idx]]):
                    consistent_pred_res.append(pred_res[idx])

            if len(consistent_pred_res) != 0:
                rule = abducer.abduce_rules(consistent_pred_res)
                if rule is not None:
                    break
        rules.append(rule)
    return rules


def get_mlp_vector(model, abducer, mapping, X, rules):
    """Encode one equation as a binary vector: entry i is 1 iff the mapped
    prediction is consistent with rules[i]."""
    original_pred_res = model.predict([X])['cls']
    pred_res = flatten(mapping_res(original_pred_res, mapping))
    return [1 if abducer.kb.consist_rule(pred_res, rule) else 0 for rule in rules]


def get_mlp_data(model, abducer, mapping, X_true, X_false, rules):
    mlp_vectors = []
    mlp_labels = []
    for X in X_true:
        mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
        mlp_labels.append(1)
    for X in X_false:
        mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
        mlp_labels.append(0)
    return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64)


def get_all_mlp_data(model, abducer, mapping, X_true, X_false, rules, min_len, max_len):
    for equation_len in range(min_len, max_len + 1):
        mlp_vectors, mlp_labels = get_mlp_data(model, abducer, mapping, X_true[equation_len], X_false[equation_len], rules)
        if equation_len == min_len:
            all_mlp_vectors = mlp_vectors
            all_mlp_labels = mlp_labels
        else:
            all_mlp_vectors = np.concatenate((all_mlp_vectors, mlp_vectors))
            all_mlp_labels = np.concatenate((all_mlp_labels, mlp_labels))
    return all_mlp_vectors, all_mlp_labels


def validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false):
    mlp_train_vectors, mlp_train_labels = get_mlp_data(model, abducer, mapping, train_X_true, train_X_false, rules)
    mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)
    mlp_val_vectors, mlp_val_labels = get_mlp_data(model, abducer, mapping, val_X_true, val_X_false, rules)
    mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels)

    best_accuracy = 0
    # Try three times to find the best MLP
    for _ in range(3):
        INFO("Training mlp...")
        mlp = MLP(input_dim=logic_output_dim)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)

        mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
        loss = mlp_model.fit(mlp_train_data_loader)
        INFO("mlp training final loss is %f" % loss)

        mlp_val_data_loader = torch.utils.data.DataLoader(mlp_val_data, batch_size=64, shuffle=True)
        accuracy = mlp_model.val(mlp_val_data_loader)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    return best_accuracy


def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8):
    train_X = train_data
    val_X = val_data

    logic_output_dim = 50
    samples_per_rule = 3

    # Curriculum training: one course per equation length
    for equation_len in range(min_len, max_len):
        INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1))
        train_X_true = train_X[1][equation_len]
        train_X_false = train_X[0][equation_len]
        val_X_true = val_X[1][equation_len]
        val_X_false = val_X[0][equation_len]
        train_X_true.extend(train_X[1][equation_len + 1])
        train_X_false.extend(train_X[0][equation_len + 1])
        val_X_true.extend(val_X[1][equation_len + 1])
        val_X_false.extend(val_X[0][equation_len + 1])

        condition_cnt = 0
        while True:
            if equation_len == min_len:
                mapping = None

            # Abduce and train the neural network
            consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, mapping, train_X_true, select_num)
            if consistent_acc == 0:
                continue

            # Test whether the MLP can be used to evaluate this course
            if consistent_acc >= 0.9 and char_acc >= 0.9:
                condition_cnt += 1
            else:
                condition_cnt = 0

            # The condition has held five times in a row
            if condition_cnt >= 5:
                INFO("Now checking if we can go to next course")
                rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim)
                INFO('Learned rules from data:', output_rules(rules))
                best_accuracy = validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false)
                INFO('best_accuracy is %f\n' % (best_accuracy))

                # Decide whether to advance to the next course or to restart this one
                if best_accuracy > 0.88:
                    torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
                    break
                else:
                    if equation_len == min_len:
                        model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
                    else:
                        model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1)))
                    condition_cnt = 0
                    INFO('Reload model and retrain')
    return model, mapping


def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
    train_X = train_data
    test_X = test_data

    print("Now begin to train the final mlp model")
    # Calculate how many equations should be selected for each length:
    # each length contributes select_equation_cnt[equation_len] rules
    logic_output_dim = 50
    len_cnt = max_len - min_len + 1
    select_equation_cnt = [0] * min_len
    select_equation_cnt += [logic_output_dim // len_cnt] * len_cnt
    # Give the remainder to the last length so the counts sum to logic_output_dim
    select_equation_cnt[-1] += logic_output_dim % len_cnt
    assert sum(select_equation_cnt) == logic_output_dim

    # Abduce rules
    rules = []
    samples_per_rule = 3
    for equation_len in range(min_len, max_len + 1):
        equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, select_equation_cnt[equation_len])
        rules.extend(equation_rules)
    INFO('Learned rules from data:', output_rules(rules))

    mlp_train_vectors, mlp_train_labels = get_all_mlp_data(model, abducer, mapping, train_X[1], train_X[0], rules, min_len, max_len)
    mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)

    # Try three times to find the best MLP
    for _ in range(3):
        mlp = MLP(input_dim=logic_output_dim)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)

        mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
        loss = mlp_model.fit(mlp_train_data_loader)
        INFO("mlp training final loss is %f" % loss)

        # Evaluate on test equations of every length
        for equation_len in range(5, 27):
            mlp_test_vectors, mlp_test_labels = get_mlp_data(model, abducer, mapping, test_X[1][equation_len], test_X[0][equation_len], rules)
            mlp_test_data = BasicDataset(mlp_test_vectors, mlp_test_labels)
            mlp_test_data_loader = torch.utils.data.DataLoader(mlp_test_data, batch_size=64, shuffle=True)
            accuracy = mlp_model.val(mlp_test_data_loader)
            INFO("The accuracy of testing length %d equations is: %f" % (equation_len, accuracy))
        INFO("\n")


if __name__ == "__main__":
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
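
Since the `__main__` block in `framework_hed.py` is an empty stub, the wiring below is only a minimal sketch of how its functions appear intended to fit together; `kb`, `model`, `abducer`, `recorder`, and the data splits are placeholders for repo-specific setup that this file does not define.

from framework_hed import hed_pretrain, train_with_rule, hed_test

def run_hed(kb, model, abducer, recorder, train_data, val_data, test_data):
    # 1. Pretrain the symbol classifier via the autoencoder (or load cached weights)
    hed_pretrain(kb, model.cls_list[0].model, recorder)
    # 2. Curriculum training: abduce pseudo-labels and advance course by course
    model, mapping = train_with_rule(model, abducer, train_data, val_data,
                                     select_num=10, min_len=5, max_len=8)
    # 3. Abduce rules with the trained model and evaluate the final MLP on test equations
    hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)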