You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

framework.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. import pickle as pk
  13. import numpy as np
  14. from utils.plog import INFO, DEBUG, clocker
  15. def block_sample(X, Z, Y, sample_num, epoch_idx):
  16. part_num = (len(X) // sample_num)
  17. if part_num == 0:
  18. part_num = 1
  19. seg_idx = epoch_idx % part_num
  20. INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
  21. X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  22. Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  23. Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)]
  24. return X, Z, Y
  25. def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
  26. result = {}
  27. if char_acc_flag:
  28. char_acc_num = 0
  29. char_num = 0
  30. for pred_z, z in zip(pred_Z, Z):
  31. char_num += len(z)
  32. for zidx in range(len(z)):
  33. if(pred_z[zidx] == z[zidx]):
  34. char_acc_num += 1
  35. char_acc = char_acc_num / char_num
  36. result["Character level accuracy"] = char_acc
  37. abl_acc_num = 0
  38. for pred_z, y in zip(pred_Z, Y):
  39. if(logic_forward(pred_z) == y):
  40. abl_acc_num += 1
  41. abl_acc = abl_acc_num / len(Y)
  42. result["ABL accuracy"] = abl_acc
  43. return result
  44. def filter_data(X, abduced_Z):
  45. finetune_Z = []
  46. finetune_X = []
  47. for abduced_x, abduced_z in zip(X, abduced_Z):
  48. if abduced_z is not []:
  49. finetune_X.append(abduced_x)
  50. finetune_Z.append(abduced_z)
  51. return finetune_X, finetune_Z
  52. def pretrain(model, X, Z):
  53. pass
  54. def train(model, abducer, train_data, test_data, epochs = 50, sample_num = -1, verbose = -1):
  55. train_X, train_Z, train_Y = train_data
  56. test_X, test_Z, test_Y = test_data
  57. # Set default parameters
  58. if sample_num == -1:
  59. sample_num = len(train_X)
  60. if verbose < 1:
  61. verbose = epochs
  62. char_acc_flag = 1
  63. if train_Z == None:
  64. char_acc_flag = 0
  65. train_Z = [None] * len(X)
  66. predict_func = clocker(model.predict)
  67. train_func = clocker(model.train)
  68. abduce_func = clocker(abducer.batch_abduce)
  69. # Abductive learning train process
  70. for epoch_idx in range(epochs):
  71. X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx)
  72. preds_res = predict_func(X)
  73. abduced_Z = abduce_func(preds_res, Y)
  74. if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
  75. res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
  76. INFO('epoch: ', epoch_idx + 1, ' ', res)
  77. finetune_X, finetune_Z = filter_data(X, abduced_Z)
  78. if len(finetune_X) > 0:
  79. # model.valid(finetune_X, finetune_Z)
  80. train_func(finetune_X, finetune_Z)
  81. else:
  82. INFO("lack of data, all abduced failed", len(finetune_X))
  83. return res
  84. if __name__ == "__main__":
  85. pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.