Browse Source

Update abduce for HED

pull/3/head
troyyyyy 2 years ago
parent
commit
f9dee56b54
3 changed files with 38 additions and 43 deletions
  1. +19
    -20
      abl/abducer/abducer_base.py
  2. +11
    -21
      abl/abducer/kb.py
  3. +8
    -2
      abl/utils/utils.py

+ 19
- 20
abl/abducer/abducer_base.py View File

@@ -13,7 +13,7 @@
import abc
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, nested_length

class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False):
@@ -52,14 +52,14 @@ class AbducerBase(abc.ABC):
return len(pred_res)
def _zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
# if not self.multiple_predictions:
return self._zoopt_address_score_single(sol.get_x(), pred_res, pred_res_prob, key)
# else:
# all_address_flag = reform_idx(sol.get_x(), pred_res)
# score = 0
# for idx in range(len(pred_res)):
# score += self._zoopt_address_score_single(all_address_flag[idx], pred_res[idx], pred_res_prob[idx], key)
# return score
all_address_flag = reform_idx(sol.get_x(), pred_res)
if nested_length(pred_res) == 1:
return self._zoopt_address_score_single(all_address_flag[idx], pred_res, pred_res_prob, key)
else:
score = 0
for idx in range(nested_length(pred_res)):
score += self._zoopt_address_score_single(all_address_flag[idx], [pred_res[idx]], [pred_res_prob[idx]], [key[idx]])
return score
def _constrain_address_num(self, solution, max_address_num):
x = solution.get_x()
@@ -78,19 +78,18 @@ class AbducerBase(abc.ABC):
return solution
def address_by_idx(self, pred_res, key, address_idx):
# print(pred_res, address_idx)
return self.kb.address_by_idx(pred_res, key, address_idx)

def abduce(self, data, max_address=-1, require_more_address=0):
pred_res, pred_res_prob, key = data
# if max_address_num == -1:
# max_address_num = len(flatten(pred_res))
assert(type(max_address) in (int, float))
if max_address == -1:
max_address_num = len(pred_res)
max_address_num = len(flatten(pred_res))
elif type(max_address) == float:
assert(max_address >= 0 and max_address <= 1)
max_address_num = round(len(pred_res) * max_address)
max_address_num = round(len(flatten(pred_res)) * max_address)
else:
assert(max_address >= 0)
max_address_num = max_address
@@ -267,11 +266,11 @@ if __name__ == '__main__':
print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print()

# res = abd.abduce((consist_exs, [None] * len(consist_exs), [None] * len(consist_exs)))
# print(res)
# res = abd.batch_abduce((inconsist_exs, [None] * len(consist_exs), [None] * len(inconsist_exs)))
# print(res)
# print()
res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)))
print(res)
res = abd.abduce((inconsist_exs, [[[None]]] * len(consist_exs), [None] * len(inconsist_exs)))
print(res)
print()

# abduced_rules = abd.batch_abduce_rules(consist_exs)
# print(abduced_rules)
abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)

+ 11
- 21
abl/abducer/kb.py View File

@@ -109,16 +109,10 @@ class KBBase(ABC):
def address_by_idx(self, pred_res, key, address_idx):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
# if multiple_predictions:
# save_pred_res = pred_res
# pred_res = flatten(pred_res)

for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
# if multiple_predictions:
# candidate = reform_idx(candidate, save_pred_res)
if check_equal(self.logic_forward(candidate), key, self.max_err):
candidates.append(candidate)
return candidates
@@ -139,7 +133,7 @@ class KBBase(ABC):
key = hashable_to_list(key)
candidates = []
for address_num in range(len(flatten(pred_res)) + 1):
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
candidates.append(pred_res)
@@ -202,24 +196,22 @@ class prolog_KB(KBBase):
return False
return result
def _address_pred_res(self, pred_res, address_idx, multiple_predictions):
def _address_pred_res(self, pred_res, address_idx):
import re
address_pred_res = pred_res.copy()
if multiple_predictions:
address_pred_res = flatten(address_pred_res)
address_pred_res = flatten(address_pred_res)
for idx in address_idx:
address_pred_res[idx] = 'P' + str(idx)
if multiple_predictions:
address_pred_res = reform_idx(address_pred_res, pred_res)
address_pred_res = reform_idx(address_pred_res, pred_res)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
def get_query_string(self, pred_res, key, address_idx, multiple_predictions):
def get_query_string(self, pred_res, key, address_idx):
query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions)
query_string += self._address_pred_res(pred_res, address_idx)
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
return query_string
@@ -227,19 +219,17 @@ class prolog_KB(KBBase):
def _find_candidate_GKB(self, pred_res, key):
pass
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
def address_by_idx(self, pred_res, key, address_idx):
candidates = []
query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions)
if multiple_predictions:
save_pred_res = pred_res
pred_res = flatten(pred_res)
query_string = self.get_query_string(pred_res, key, address_idx)
save_pred_res = pred_res
pred_res = flatten(pred_res)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res)
candidate = reform_idx(candidate, save_pred_res)
candidates.append(candidate)
return candidates



+ 8
- 2
abl/utils/utils.py View File

@@ -3,14 +3,20 @@ from .plog import INFO
from collections import OrderedDict
from itertools import chain

# for multiple predictions
def nested_length(l):
if not isinstance(l[0], (list, tuple)):
return 1
return len(l)

def flatten(l):
if not isinstance(l[0], (list, tuple)):
return l
return list(chain.from_iterable(l))
# for multiple predictions
def reform_idx(flatten_pred_res, save_pred_res):
if not isinstance(save_pred_res[0], (list, tuple)):
return flatten_pred_res
re = []
i = 0
for e in save_pred_res:


Loading…
Cancel
Save