You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.1 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435
  1. import numpy as np
  2. # for multiple predictions, modify from `learn_add.py`
  3. def _flatten(l):
  4. return [item for sublist in l for item in _flatten(sublist)] if isinstance(l, list) else [l]
  5. # for multiple predictions, modify from `learn_add.py`
  6. def _reform_ids(flatten_pred_res, save_pred_res):
  7. re = []
  8. i = 0
  9. for e in save_pred_res:
  10. j = 0
  11. ids = []
  12. while j < len(e):
  13. ids.append(flatten_pred_res[i + j])
  14. j += 1
  15. re.append(ids)
  16. i = i + j
  17. return re
  18. def _hamming_dist(A, B):
  19. B = np.array(B)
  20. A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
  21. return np.sum(A != B, axis = 1)
  22. def _confidence_dist(A, B):
  23. B = np.array(B)
  24. A = np.clip(A, 1e-9, 1)
  25. A = np.expand_dims(A, axis=0)
  26. A = A.repeat(axis=0, repeats=(len(B)))
  27. rows = np.array(range(len(B)))
  28. rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
  29. cols = np.array(range(len(B[0])))
  30. cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
  31. return 1 - np.prod(A[rows, cols, B], axis = 1)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.