
utils.py 3.7 kB

import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from itertools import chain, permutations

from .plog import INFO


def flatten(l):
    """Flatten a list of lists into a single list; leave flat lists unchanged."""
    if not isinstance(l[0], (list, tuple)):
        return l
    return list(chain.from_iterable(l))


def reform_idx(flatten_pred_res, save_pred_res):
    """Regroup a flattened prediction list into the nested shape of save_pred_res."""
    if not isinstance(save_pred_res[0], (list, tuple)):
        return flatten_pred_res
    re = []
    i = 0
    for e in save_pred_res:
        re.append(flatten_pred_res[i:i + len(e)])
        i += len(e)
    return re


def hamming_dist(A, B):
    """Hamming distance between sequence A and each candidate sequence in B."""
    A = np.array(A, dtype='<U')
    B = np.array(B, dtype='<U')
    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=len(B))
    return np.sum(A != B, axis=1)


def confidence_dist(A, B):
    """Confidence distance: 1 minus the product of the predicted probabilities
    of each candidate label sequence in B under the probability matrix A
    (shape: sequence length x number of classes)."""
    B = np.array(B)
    A = np.clip(A, 1e-9, 1)
    A = np.expand_dims(A, axis=0)
    A = A.repeat(axis=0, repeats=len(B))
    rows = np.array(range(len(B)))
    rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
    cols = np.array(range(len(B[0])))
    cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
    return 1 - np.prod(A[rows, cols, B], axis=1)


def block_sample(X, Z, Y, sample_num, epoch_idx):
    """Select the epoch_idx-th block of sample_num examples from X, Z and Y."""
    part_num = len(X) // sample_num
    if part_num == 0:
        part_num = 1
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
    X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    return X, Z, Y


def gen_mappings(chars, symbs):
    """Generate all one-to-one mappings from chars to symbs."""
    n_char = len(chars)
    n_symbs = len(symbs)
    if n_char != n_symbs:
        print("Characters and symbols sizes don't match.")
        return
    mappings = []  # returned mappings
    perms = permutations(symbs)
    for p in perms:
        mappings.append(dict(zip(chars, list(p))))
    return mappings


def mapping_res(original_pred_res, m):
    """Apply mapping m to every symbol of every predicted formula."""
    return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
    """Apply the inverse of mapping m to every symbol of every formula."""
    remapping = {}
    for key, value in m.items():
        remapping[value] = key
    return [[remapping[symbol] for symbol in formula] for formula in pred_res]


def check_equal(a, b, max_err=0):
    """Recursively compare a and b; numbers are considered equal within max_err."""
    if isinstance(a, (int, float)) and isinstance(b, (int, float)):
        return abs(a - b) <= max_err
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            return False
        for i in range(len(a)):
            if not check_equal(a[i], b[i], max_err):
                return False
        return True
    else:
        return a == b


def extract_feature(img):
    """Downsample a symbol image by 2x average pooling and flatten it."""
    extractor = nn.AvgPool2d(2, stride=2)
    feature_map = np.array(extractor(torch.Tensor(img)))
    return feature_map.reshape((-1,))
    # Unreachable alternative kept from the original file: row/column sums as features.
    # return np.concatenate(
    #     (np.squeeze(np.sum(img, axis=1)), np.squeeze(np.sum(img, axis=2))), axis=0
    # )


def reduce_dimension(data):
    """Replace every symbol image in data (in place) with its extracted feature."""
    for truth_value in [0, 1]:
        for equation_len in range(5, 27):
            equations = data[truth_value][equation_len]
            reduced_equations = [
                [extract_feature(symbol_img) for symbol_img in equation]
                for equation in equations
            ]
            data[truth_value][equation_len] = reduced_equations


def to_hashable(l):
    """Convert a (possibly nested) list into hashable tuples."""
    if type(l) is not list:
        return l
    if type(l[0]) is not list:
        return tuple(l)
    return tuple(tuple(sublist) for sublist in l)


def hashable_to_list(t):
    """Inverse of to_hashable: convert (possibly nested) tuples back into lists."""
    if type(t) is not tuple:
        return t
    if type(t[0]) is not tuple:
        return list(t)
    return [list(subtuple) for subtuple in t]
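A minimal usage sketch of the helpers above. It assumes the module is importable as part of its package (the file itself relies on a relative import of plog), and all input values below are illustrative, not taken from the source.

# Hypothetical usage sketch; import path and data are assumptions for illustration.
import numpy as np
from utils import (flatten, reform_idx, hamming_dist,
                   confidence_dist, gen_mappings, mapping_res)

# Nested per-example predictions -> one flat list, and back again.
nested = [["1", "+", "1"], ["2", "=", "2"]]
flat = flatten(nested)                # ['1', '+', '1', '2', '=', '2']
regrouped = reform_idx(flat, nested)  # [['1', '+', '1'], ['2', '=', '2']]

# Hamming distance from one predicted sequence to each abduced candidate.
pred = ["1", "+", "1"]
candidates = [["1", "+", "1"], ["1", "+", "0"], ["0", "+", "0"]]
print(hamming_dist(pred, candidates))  # [0 1 2]

# Confidence distance: each row of probs gives per-symbol class probabilities,
# and each candidate is a sequence of class indices; smaller = more consistent.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.6, 0.3, 0.1]])
idx_candidates = [[0, 1, 0], [0, 1, 1]]
print(confidence_dist(probs, idx_candidates))  # first candidate has the smaller distance

# Enumerate all bijections between model symbols and logical symbols,
# then relabel predictions under one of them.
mappings = gen_mappings(["a", "b", "c"], ["0", "1", "+"])
relabeled = mapping_res([["a", "b", "a"]], mappings[0])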

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.