You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

framework.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. # ================================================================#
  12. from .utils.plog import INFO, clocker
  13. from .utils.utils import block_sample, float_parameter
  14. def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
  15. result = {}
  16. if char_acc_flag:
  17. char_acc_num = 0
  18. char_num = 0
  19. for pred_z, z in zip(pred_Z, Z):
  20. char_num += len(z)
  21. for zidx in range(len(z)):
  22. if pred_z[zidx] == z[zidx]:
  23. char_acc_num += 1
  24. char_acc = char_acc_num / char_num
  25. result["Character level accuracy"] = char_acc
  26. abl_acc_num = 0
  27. for pred_z, y in zip(pred_Z, Y):
  28. if logic_forward(pred_z) == y:
  29. abl_acc_num += 1
  30. abl_acc = abl_acc_num / len(Y)
  31. result["ABL accuracy"] = abl_acc
  32. return result
  33. def filter_data(X, abduced_Z):
  34. finetune_Z = []
  35. finetune_X = []
  36. for x, abduced_z in zip(X, abduced_Z):
  37. if len(abduced_z) > 0:
  38. finetune_X.append(x)
  39. finetune_Z.append(abduced_z)
  40. return finetune_X, finetune_Z
  41. def train(model, abducer, train_data, epochs=50, sample=-1, verbose=-1):
  42. train_X, train_Z, train_Y = train_data
  43. # Set default parameters
  44. sample_num = float_parameter(sample, len(train_X))
  45. part_num = (len(train_X) - 1) // sample_num + 1
  46. if verbose < 1:
  47. verbose = epochs
  48. char_acc_flag = 1
  49. if train_Z == None:
  50. char_acc_flag = 0
  51. train_Z = [None] * len(train_X)
  52. predict_func = clocker(model.predict)
  53. train_func = clocker(model.train)
  54. abduce_func = clocker(abducer.batch_abduce)
  55. for epoch in range(epochs):
  56. for seg_idx in range(part_num):
  57. X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, seg_idx)
  58. INFO("epoch:", epoch + 1, ", seg_idx:", seg_idx + 1, "/", part_num, ", data num:", len(X))
  59. preds_res = predict_func(X)
  60. abduced_Z = abduce_func(preds_res, Y)
  61. ## TODO: change verbose
  62. if ((seg_idx + 1) % verbose == 0) or (seg_idx == epochs - 1):
  63. pseudo_label = [[abducer.mapping[label] for label in formula] for formula in preds_res['label']]
  64. res = result_statistics(pseudo_label, Z, Y, abducer.kb.logic_forward, char_acc_flag)
  65. INFO("seg: ", seg_idx + 1, " ", res)
  66. finetune_X, finetune_Z = filter_data(X, abduced_Z)
  67. finetune_Z = [[abducer.remapping[symbol] for symbol in formula] for formula in finetune_Z]
  68. if len(finetune_X) > 0:
  69. # model.valid(finetune_X, finetune_Z)
  70. train_func(finetune_X, finetune_Z)
  71. else:
  72. INFO("lack of data, all abduced failed", len(finetune_X))
  73. return model
  74. ## TODO: test
  75. def test(model, abducer, test_data):
  76. test_X, test_Z, test_Y = test_data
  77. predict_func = clocker(model.predict)
  78. preds_res = predict_func(test_X)
  79. char_acc_flag = 1
  80. if test_Z == None:
  81. char_acc_flag = 0
  82. test_Z = [None] * len(test_X)
  83. res = result_statistics(preds_res["cls"], test_Z, test_Y, abducer.kb.logic_forward, char_acc_flag)
  84. INFO(res)
  85. if __name__ == "__main__":
  86. pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.