You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

framework.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. import pickle as pk
  13. import numpy as np
  14. from utils.plog import INFO, DEBUG, clocker
  15. @clocker
  16. def block_sample(X, Z, Y, sample_num, epoch_idx):
  17. part_num = (len(X) // sample_num)
  18. if part_num == 0:
  19. part_num = 1
  20. seg_idx = epoch_idx % part_num
  21. INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak))
  22. X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  23. Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  24. Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  25. return X, Z, Y
  26. def get_taglist(self, Z):
  27. tmp = [[str(x) for x in label] for label in Z]
  28. tmp = sorted(list(set(tmp)))
  29. return tmp
  30. def get_abl_acc(Y, pseudo_Z, logic_forward):
  31. abl_acc = 0
  32. for y, pseudo_z in zip(Y, pseudo_Z):
  33. if(logic_forward(pseudo_z) == y):
  34. abl_acc += 1
  35. return abl_acc / len(Y)
  36. def get_char_acc(Z, pseudo_Z):
  37. char_acc = 0
  38. char_num = 0
  39. for z, pseudo_z in zip(Z, pseudo_Z):
  40. char_num += len(z)
  41. for zidx in range(len(z)):
  42. if(z[zidx] == pseudo_z[zidx]):
  43. char_acc += 1
  44. return char_acc / char_num
  45. # def result_statistics(pseudo_Y, Y, abduced_Y):
  46. # abd_err_num = 0
  47. # abd_char_num = 0
  48. # abd_char_acc = 0
  49. # abd_failed = 0
  50. # word_err_num = 0
  51. # ori_char_num = 0
  52. # ori_char_acc = 0
  53. # for tidx, (pseudo_y, y, abduced_y) in enumerate(zip(pseudo_Y, Y, abduced_Y)):
  54. # pseudo_y = pseudo_y
  55. # if sum(abduced_y != y) != 0:
  56. # abd_err_num += 1
  57. # if abduced_y is not None:
  58. # abd_char_num += len(y)
  59. # abd_char_acc += sum(abduced_y == y)
  60. # else:
  61. # abd_failed += 1
  62. # ori_char_num += len(pseudo_y)
  63. # ori_char_acc += sum(pseudo_y == y)
  64. # if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0:
  65. # INFO(pseudo_y, y, abduced_y)
  66. # pk.dump((pseudo_y, y, abduced_y), open("bug.pk", "wb"))
  67. # if sum(pseudo_y != y) != 0:
  68. # word_err_num += 1
  69. # INFO("")
  70. # INFO("Abd word level accuracy:", 1 - word_err_num / len(pseudo_Y))
  71. # INFO("Abd char level accuracy:", abd_char_acc / abd_char_num)
  72. # INFO("Ori char level accuracy:", ori_char_acc / ori_char_num)
  73. # INFO("")
  74. # result = {"total_word" : len(pseudo_Y), "accuracy_word" : len(pseudo_Y) - word_err_num,
  75. # "total_abd_char": abd_char_num, "accuracy_abd_char" : abd_char_acc,
  76. # "total_ori_char": ori_char_num, "accuracy_ori_char" : ori_char_acc,
  77. # "total_abd_failed": abd_failed}
  78. # return result
  79. @clocker
  80. def filter_data(X, abduced_Z):
  81. finetune_Z = []
  82. finetune_X = []
  83. for abduced_x, abduced_z in zip(X, abduced_Z):
  84. if abduced_z is not None:
  85. finetune_X.append(abduced_x)
  86. finetune_Z.append(abduced_z)
  87. return finetune_X, finetune_Z
  88. @clocker
  89. def is_all_sublabel_exist(labels, std_label_list):
  90. if not labels:
  91. return False
  92. labels = np.array(labels).T
  93. for idx, (std_label, label) in enumerate(zip(std_label_list, labels)):
  94. std_num = len(set(std_label))
  95. sublabel_num = len(set(label))
  96. if std_num != sublabel_num:
  97. INFO(f"sublabel {idx} should have {std_num} class, but data only have {sublabel_num} class", screen=True)
  98. return False
  99. return True
def pretrain(model, X, Z):
    # Placeholder: pre-training of `model` on (X, Z) is not implemented yet.
    pass
  102. def train(model, abducer, X, Z, Y, epochs = 10, sample_num = -1, verbose = -1):
  103. # Set default parameters
  104. if sample_num == -1:
  105. sample_num = len(X)
  106. if verbose < 1:
  107. verbose = epochs
  108. char_acc_flag = 1
  109. if Z == None:
  110. char_acc_flag = 0
  111. Z = [None] * len(X)
  112. predict_func = clocker(model.predict)
  113. train_func = clocker(model.train)
  114. abduce_func = clocker(abducer.batch_abduce)
  115. epochs = 50
  116. # Abductive learning train process
  117. for epoch_idx in range(epochs):
  118. X, Z, Y = block_sample(X, Z, Y, sample_num, epoch_idx)
  119. preds_res = predict_func(X)
  120. abduced_Z = abduce_func(preds_res, Y)
  121. abl_acc = get_abl_acc(Y, preds_res['cls'], abducer.kb.logic_forward)
  122. if(not char_acc_flag):
  123. ori_char_acc = get_char_acc(Z, preds_res['cls'])
  124. abd_char_acc = get_char_acc(abduced_Z, preds_res['cls'])
  125. print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
  126. finetune_X, finetune_Z = filter_data(X, abduced_Z)
  127. if len(finetune_X) > 0:
  128. train_func(finetune_X, finetune_Z)
  129. else:
  130. INFO("lack of data, all abduced failed", len(finetune_X))
  131. return abl_acc
  132. # def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):
  133. # # Set default parameters
  134. # if sample_num == -1:
  135. # sample_num = len(X)
  136. # if verbose < 1:
  137. # verbose = epochs
  138. # if C is None:
  139. # C = [None] * len(X)
  140. # # Set function running time recorder
  141. # valid_func = clocker(model.valid)
  142. # predict_func = clocker(model.predict)
  143. # train_func = clocker(model.train)
  144. # abduce_func = clocker(abducer.batch_abduce)
  145. # X_bak = X
  146. # Y_bak = Y
  147. # C_bak = C
  148. # # Abductive learning train process
  149. # res = {}
  150. # for epoch_idx in range(epochs):
  151. # X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx)
  152. # preds_res = predict_func(X)
  153. # abduced_Y = abduce_func(preds_res, C)
  154. # finetune_X, finetune_Y = filter_data(X, abduced_Y)
  155. # score, score_list = valid_func(X, Y)
  156. # if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
  157. # res = result_statistics(preds_res["cls"], Y, abduced_Y)
  158. # INFO(res)
  159. # if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)):
  160. # INFO("There is some sub label missing", len(finetune_Y))
  161. # break
  162. # if len(finetune_X) > 0:
  163. # train_func(finetune_X, finetune_Y)#, n_epoch = 10)
  164. # else:
  165. # INFO("lack of data, all abduced failed", len(finetune_X))
  166. # return res
if __name__ == "__main__":
    # No standalone behavior; this module is meant to be imported.
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.