You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. import pickle as pk
  13. import numpy as np
  14. from utils.plog import INFO, DEBUG, clocker
  15. @clocker
  16. def block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx):
  17. part_num = (len(X_bak) // sample_num)
  18. if part_num == 0:
  19. part_num = 1
  20. seg_idx = epoch_idx % part_num
  21. INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak))
  22. X = X_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  23. Y = Y_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  24. C = C_bak[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  25. return X, Y, C
  26. def get_taglist(self, Y):
  27. tmp = [[str(x) for x in label] for label in Y]
  28. tmp = sorted(list(set(tmp)))
  29. return tmp
  30. @clocker
  31. def result_statistics(pseudo_Y, Y, abduced_Y):
  32. abd_err_num = 0
  33. abd_char_num = 0
  34. abd_char_acc = 0
  35. abd_failed = 0
  36. word_err_num = 0
  37. ori_char_num = 0
  38. ori_char_acc = 0
  39. for tidx, (pseudo_y, y, abduced_y) in enumerate(zip(pseudo_Y, Y, abduced_Y)):
  40. pseudo_y = pseudo_y
  41. if sum(abduced_y != y) != 0:
  42. abd_err_num += 1
  43. if abduced_y is not None:
  44. abd_char_num += len(y)
  45. abd_char_acc += sum(abduced_y == y)
  46. else:
  47. abd_failed += 1
  48. ori_char_num += len(pseudo_y)
  49. ori_char_acc += sum(pseudo_y == y)
  50. if abduced_y is not None and sum(y != pseudo_y) == 0 and sum(pseudo_y != abduced_y) > 0:
  51. INFO(pseudo_y, y, abduced_y)
  52. pk.dump((pseudo_y, y, abduced_y), open("bug.pk", "wb"))
  53. if sum(pseudo_y != y) != 0:
  54. word_err_num += 1
  55. INFO("")
  56. INFO("Abd word level accuracy:", 1 - word_err_num / len(pseudo_Y))
  57. INFO("Abd char level accuracy:", abd_char_acc / abd_char_num)
  58. INFO("Ori char level accuracy:", ori_char_acc / ori_char_num)
  59. INFO("")
  60. result = {"total_word" : len(pseudo_Y), "accuracy_word" : len(pseudo_Y) - word_err_num,
  61. "total_abd_char": abd_char_num, "accuracy_abd_char" : abd_char_acc,
  62. "total_ori_char": ori_char_num, "accuracy_ori_char" : ori_char_acc,
  63. "total_abd_failed": abd_failed}
  64. return result
  65. @clocker
  66. def filter_data(X, abduced_Y):
  67. finetune_Y = []
  68. finetune_X = []
  69. for abduced_x, abduced_y in zip(X, abduced_Y):
  70. if abduced_y is not None:
  71. finetune_X.append(abduced_x)
  72. finetune_Y.append(abduced_y)
  73. return finetune_X, finetune_Y
  74. @clocker
  75. def is_all_sublabel_exist(labels, std_label_list):
  76. if not labels:
  77. return False
  78. labels = np.array(labels).T
  79. for idx, (std_label, label) in enumerate(zip(std_label_list, labels)):
  80. std_num = len(set(std_label))
  81. sublabel_num = len(set(label))
  82. if std_num != sublabel_num:
  83. INFO(f"sublabel {idx} should have {std_num} class, but data only have {sublabel_num} class", screen=True)
  84. return False
  85. return True
def pretrain(model, X, Y):
    """Placeholder for a supervised pre-training step; currently a no-op."""
    pass
  88. def train(model, abducer, X, Y, C = None, epochs = 10, sample_num = -1, verbose = -1, check_sublabel = True):
  89. # Set default parameters
  90. if sample_num == -1:
  91. sample_num = len(X)
  92. if verbose < 1:
  93. verbose = epochs
  94. if C is None:
  95. C = [None] * len(X)
  96. # Set function running time recorder
  97. valid_func = clocker(model.valid)
  98. predict_func = clocker(model.predict)
  99. train_func = clocker(model.train)
  100. abduce_func = clocker(abducer.batch_abduce)
  101. X_bak = X
  102. Y_bak = Y
  103. C_bak = C
  104. # Abductive learning train process
  105. res = {}
  106. for epoch_idx in range(epochs):
  107. X, Y, C = block_sample(X_bak, Y_bak, C_bak, sample_num, epoch_idx)
  108. preds_res = predict_func(X)
  109. abduced_Y = abduce_func(preds_res, C)
  110. finetune_X, finetune_Y = filter_data(X, abduced_Y)
  111. score, score_list = valid_func(X, Y)
  112. if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
  113. res = result_statistics(preds_res["cls"], Y, abduced_Y)
  114. INFO(res)
  115. if check_sublabel and (not is_all_sublabel_exist(finetune_Y, model.label_lists)):
  116. INFO("There is some sub label missing", len(finetune_Y))
  117. break
  118. if len(finetune_X) > 0:
  119. train_func(finetune_X, finetune_Y)#, n_epoch = 10)
  120. else:
  121. INFO("lack of data, all abduced failed", len(finetune_X))
  122. return res
  123. #return ret
if __name__ == "__main__":
    # Module is import-only for now; no standalone entry-point behaviour.
    pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.