Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
3d14196107
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 84 additions and 34 deletions
  1. +84
    -34
      abducer/abducer_base.py

+ 84
- 34
abducer/abducer_base.py View File

@@ -26,9 +26,16 @@ import time




class AbducerBase(abc.ABC): class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True):
def __init__(
self,
kb,
dist_func="confidence",
zoopt=False,
multiple_predictions=False,
cache=True,
):
self.kb = kb self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence'
assert dist_func == "hamming" or dist_func == "confidence"
self.dist_func = dist_func self.dist_func = dist_func
self.zoopt = zoopt self.zoopt = zoopt
self.multiple_predictions = multiple_predictions self.multiple_predictions = multiple_predictions
@@ -39,11 +46,18 @@ class AbducerBase(abc.ABC):
self.cache_candidates = {} self.cache_candidates = {}


def _get_cost_list(self, pred_res, pred_res_prob, candidates): def _get_cost_list(self, pred_res, pred_res_prob, candidates):
if self.dist_func == 'hamming':
if self.dist_func == "hamming":
return hamming_dist(pred_res, candidates) return hamming_dist(pred_res, candidates)
elif self.dist_func == 'confidence':
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
return confidence_dist(pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates])
elif self.dist_func == "confidence":
mapping = dict(
zip(
self.kb.pseudo_label_list,
list(range(len(self.kb.pseudo_label_list))),
)
)
return confidence_dist(
pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]
)


def _get_one_candidate(self, pred_res, pred_res_prob, candidates): def _get_one_candidate(self, pred_res, pred_res_prob, candidates):
if len(candidates) == 0: if len(candidates) == 0:
@@ -61,8 +75,12 @@ class AbducerBase(abc.ABC):
all_address_flag = reform_idx(solution, pred_res) all_address_flag = reform_idx(solution, pred_res)
score = 0 score = 0
for idx in enumerate(len(pred_res)): for idx in enumerate(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True)
address_idx = [
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
]
candidate = self.kb.address_by_idx(
[pred_res[idx]], key[idx], address_idx, True
)
if len(candidate) > 0: if len(candidate) > 0:
score += 1 score += 1
return score return score
@@ -70,7 +88,9 @@ class AbducerBase(abc.ABC):
def _zoopt_address_score(self, pred_res, key, sol): def _zoopt_address_score(self, pred_res, key, sol):
if not self.multiple_predictions: if not self.multiple_predictions:
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0]
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
candidates = self.kb.address_by_idx(
pred_res, key, address_idx, self.multiple_predictions
)
return 1 if len(candidates) > 0 else 0 return 1 if len(candidates) > 0 else 0
else: else:
return self._zoopt_score_multiple(pred_res, key, sol.get_x()) return self._zoopt_score_multiple(pred_res, key, sol.get_x())
@@ -98,7 +118,11 @@ class AbducerBase(abc.ABC):
pred_res = flatten(pred_res) pred_res = flatten(pred_res)
key = tuple(key) key = tuple(key)
if (tuple(pred_res), key) in self.cache_min_address_num: if (tuple(pred_res), key) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
address_num = min(
max_address_num,
self.cache_min_address_num[(tuple(pred_res), key)]
+ require_more_address,
)
if (tuple(pred_res), key, address_num) in self.cache_candidates: if (tuple(pred_res), key, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
if self.zoopt: if self.zoopt:
@@ -127,12 +151,18 @@ class AbducerBase(abc.ABC):
if self.zoopt: if self.zoopt:
solution = self.zoopt_get_solution(pred_res, key, max_address_num) solution = self.zoopt_get_solution(pred_res, key, max_address_num)
address_idx = [idx for idx, i in enumerate(solution) if i != 0] address_idx = [idx for idx, i in enumerate(solution) if i != 0]
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
candidates = self.kb.address_by_idx(
pred_res, key, address_idx, self.multiple_predictions
)
address_num = int(solution.sum()) address_num = int(solution.sum())
min_address_num = address_num min_address_num = address_num
else: else:
candidates, min_address_num, address_num = self.kb.abduce_candidates( candidates, min_address_num, address_num = self.kb.abduce_candidates(
pred_res, key, max_address_num, require_more_address, self.multiple_predictions
pred_res,
key,
max_address_num,
require_more_address,
self.multiple_predictions,
) )


candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
@@ -147,23 +177,31 @@ class AbducerBase(abc.ABC):


def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0):
if self.multiple_predictions: if self.multiple_predictions:
return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
return self.abduce(
(Z["cls"], Z["prob"], Y), max_address_num, require_more_address
)
else: else:
return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
return [
self.abduce((z, prob, y), max_address_num, require_more_address)
for z, prob, y in zip(Z["cls"], Z["prob"], Y)
]


def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address) return self.batch_abduce(Z, Y, max_address_num, require_more_address)







if __name__ == '__main__':
prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
if __name__ == "__main__":
prob1 = [
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
prob2 = [
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]


kb = add_KB() kb = add_KB()
abd = AbducerBase(kb, 'confidence')
abd = AbducerBase(kb, "confidence")
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -177,7 +215,7 @@ if __name__ == '__main__':
print() print()


kb = add_prolog_KB() kb = add_prolog_KB()
abd = AbducerBase(kb, 'confidence')
abd = AbducerBase(kb, "confidence")
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -191,7 +229,7 @@ if __name__ == '__main__':
print() print()


kb = add_prolog_KB() kb = add_prolog_KB()
abd = AbducerBase(kb, 'confidence', zoopt=True)
abd = AbducerBase(kb, "confidence", zoopt=True)
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -205,24 +243,38 @@ if __name__ == '__main__':
print() print()


kb = HWF_KB(len_list=[1, 3, 5]) kb = HWF_KB(len_list=[1, 3, 5])
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
abd = AbducerBase(kb, "hamming")
res = abd.abduce(
(["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0
)
print(res) print(res)
res = abd.abduce((['5', '+', '2'], None, 64), max_address_num=3, require_more_address=0)
res = abd.abduce(
(["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0
)
print(res) print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
res = abd.abduce(
(["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0
)
print(res) print(res)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
res = abd.abduce(
(["5", "8", "8", "8", "8"], None, 3.17),
max_address_num=5,
require_more_address=3,
)
print(res) print(res)
print() print()


kb = HED_prolog_KB() kb = HED_prolog_KB()
abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) abd = AbducerBase(kb, zoopt=True, multiple_predictions=True)
consist_exs = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]]
consist_exs2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules
inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]]
consist_exs2 = [
[1, "+", 0, "=", 0],
[1, "+", 1, "=", 0],
[0, "+", 1, "=", 1, 1],
] # not consistent with rules
inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
# inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']]
rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])']
rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"]


print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs))
print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules))
@@ -236,5 +288,3 @@ if __name__ == '__main__':


abduced_rules = abd.abduce_rules(consist_exs) abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules) print(abduced_rules)


Loading…
Cancel
Save