Browse Source

Add kb and abducer for HED

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
2587ac2f7e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 57 additions and 23 deletions
  1. +57
    -23
      abducer/abducer_base.py

+ 57
- 23
abducer/abducer_base.py View File

@@ -14,19 +14,20 @@ import sys
sys.path.append("..")

import abc
from abducer.kb import add_KB, hwf_KB, add_prolog_KB
from abducer.kb import add_KB, HWF_KB, add_prolog_KB, HED_prolog_KB
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt

import time

class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = 'confidence', zoopt = False, cache = True):
def __init__(self, kb, dist_func = 'confidence', zoopt = False, multiple_predictions = False, cache = True):
self.kb = kb
assert(dist_func == 'hamming' or dist_func == 'confidence')
self.dist_func = dist_func
self.cache = cache
self.zoopt = zoopt
self.multiple_predictions = multiple_predictions
self.cache = cache
if self.cache:
self.cache_min_address_num = {}
@@ -58,10 +59,10 @@ class AbducerBase(abc.ABC):
elif self.dist_func == 'confidence':
return self.confidence_dist(pred_res_prob, candidates)

def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates):
def get_one_candidate(self, pred_res, pred_res_prob, candidates):
if len(candidates) == 0:
return []
elif len(candidates) == 1:
elif len(candidates) == 1 or self.zoopt:
return candidates[0]
else:
cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates)
@@ -71,8 +72,19 @@ class AbducerBase(abc.ABC):




# for multiple_prediction
def flatten(self, l):
if self.multiple_predictions:
return [item for sublist in l for item in sublist]
else:
return l

# for zoopt
def zoopt_address_score(self, pred_res, key, address_idx):
candidates = self.kb.address_by_idx(pred_res, key, address_idx)
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
return 0 if len(candidates) > 0 else 1
def constraint_address_num(self, solution, max_address_num):
@@ -80,9 +92,9 @@ class AbducerBase(abc.ABC):
return max_address_num - x.sum()

def zoopt_get_address_idx(self, pred_res, key, max_address_num):
dimension = Dimension(size=len(pred_res),
regs=[[0, 1]] * len(pred_res),
tys=[False] * len(pred_res))
dimension = Dimension(size=len(self.flatten(pred_res)),
regs=[[0, 1]] * len(self.flatten(pred_res)),
tys=[False] * len(self.flatten(pred_res)))
objective = Objective(lambda sol: self.zoopt_address_score(pred_res, key, [idx for idx, i in enumerate(sol.get_x()) if i != 0]),
dim=dimension,
constraint=lambda sol: self.constraint_address_num(sol, max_address_num))
@@ -90,7 +102,7 @@ class AbducerBase(abc.ABC):
solution = Opt.min(objective, parameter).get_x()
address_idx = [idx for idx, i in enumerate(solution) if i != 0]
address_num = solution.sum()
address_num = int(solution.sum())
return address_idx, address_num
@@ -102,32 +114,34 @@ class AbducerBase(abc.ABC):
if max_address_num == -1:
max_address_num = len(pred_res)
if self.cache and (tuple(pred_res), key) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
if (tuple(pred_res), key, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
if self.cache and (tuple(self.flatten(pred_res)), key) in self.cache_min_address_num:
address_num = min(max_address_num, self.cache_min_address_num[(tuple(self.flatten(pred_res)), key)] + require_more_address)
if (tuple(self.flatten(pred_res)), key, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(self.flatten(pred_res)), key, address_num)]
if self.zoopt:
return candidates[0]
else:
return self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
return self.get_one_candidate(pred_res, pred_res_prob, candidates)
if self.zoopt:
address_idx, address_num = self.zoopt_get_address_idx(pred_res, key, max_address_num)
candidates = self.kb.address_by_idx(pred_res, key, address_idx)
candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
min_address_num = address_num
candidate = candidates[0]
else:
candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address)
candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates)
candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address, self.multiple_predictions)
candidate = self.get_one_candidate(pred_res, pred_res_prob, candidates)
if self.cache:
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
self.cache_min_address_num[(tuple(self.flatten(pred_res)), key)] = min_address_num
self.cache_candidates[(tuple(self.flatten(pred_res)), key, address_num)] = candidates

return candidate
def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)
def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0):
return [
@@ -185,7 +199,7 @@ if __name__ == '__main__':
print(res)
print()
kb = hwf_KB(len_list = [1, 3, 5])
kb = HWF_KB(len_list = [1, 3, 5])
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0)
print(res)
@@ -197,3 +211,23 @@ if __name__ == '__main__':
print(res)
print()
kb = HED_prolog_KB()
abd = AbducerBase(kb, zoopt=True, multiple_predictions=True)
consist_re = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]]
consist_re2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules
inconsist_re = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])']
print(kb.logic_forward(consist_re), kb.logic_forward(inconsist_re))
print(kb.consist_rule(consist_re, rules), kb.consist_rule(consist_re2, rules))
print()
res = abd.abduce((consist_re, None, None))
print(res)
res = abd.abduce((inconsist_re, None, None))
print(res)
print()
abduced_rules = abd.abduce_rules(consist_re)
print(abduced_rules)

Loading…
Cancel
Save