You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.3 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import numpy as np
  2. from utils.plog import INFO
  3. from collections import OrderedDict
  4. # for multiple predictions, modify from `learn_add.py`
  5. def flatten(l):
  6. return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l]
  7. # for multiple predictions, modify from `learn_add.py`
  8. def reform_idx(flatten_pred_res, save_pred_res):
  9. re = []
  10. i = 0
  11. for e in save_pred_res:
  12. j = 0
  13. idx = []
  14. while j < len(e):
  15. idx.append(flatten_pred_res[i + j])
  16. j += 1
  17. re.append(idx)
  18. i = i + j
  19. return re
  20. def hamming_dist(A, B):
  21. B = np.array(B)
  22. A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
  23. return np.sum(A != B, axis = 1)
  24. def confidence_dist(A, B):
  25. B = np.array(B)
  26. A = np.clip(A, 1e-9, 1)
  27. A = np.expand_dims(A, axis=0)
  28. A = A.repeat(axis=0, repeats=(len(B)))
  29. rows = np.array(range(len(B)))
  30. rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
  31. cols = np.array(range(len(B[0])))
  32. cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
  33. return 1 - np.prod(A[rows, cols, B], axis = 1)
  34. def block_sample(X, Z, Y, sample_num, epoch_idx):
  35. part_num = len(X) // sample_num
  36. if part_num == 0:
  37. part_num = 1
  38. seg_idx = epoch_idx % part_num
  39. INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
  40. X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
  41. Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
  42. Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]
  43. return X, Z, Y
  44. def gen_mappings(chars, symbs):
  45. n_char = len(chars)
  46. n_symbs = len(symbs)
  47. if n_char != n_symbs:
  48. print('Characters and symbols size dosen\'t match.')
  49. return
  50. from itertools import permutations
  51. mappings = []
  52. # returned mappings
  53. perms = permutations(symbs)
  54. for p in perms:
  55. mappings.append(dict(zip(chars, list(p))))
  56. return mappings
  57. def mapping_res(original_pred_res, m):
  58. return [[m[symbol] for symbol in formula] for formula in original_pred_res]
  59. def remapping_res(pred_res, m):
  60. remapping = {}
  61. for key, value in m.items():
  62. remapping[value] = key
  63. return [[remapping[symbol] for symbol in formula] for formula in pred_res]

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.