From 2587ac2f7e4d64b2539c10d97ff04927970a0e26 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 7 Dec 2022 12:52:14 +0800 Subject: [PATCH] Add kb and abducer for HED --- abducer/abducer_base.py | 80 +++++++++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/abducer/abducer_base.py b/abducer/abducer_base.py index e2072ec..8fde5a1 100644 --- a/abducer/abducer_base.py +++ b/abducer/abducer_base.py @@ -14,19 +14,20 @@ import sys sys.path.append("..") import abc -from abducer.kb import add_KB, hwf_KB, add_prolog_KB +from abducer.kb import add_KB, HWF_KB, add_prolog_KB, HED_prolog_KB import numpy as np from zoopt import Dimension, Objective, Parameter, Opt import time class AbducerBase(abc.ABC): - def __init__(self, kb, dist_func = 'confidence', zoopt = False, cache = True): + def __init__(self, kb, dist_func = 'confidence', zoopt = False, multiple_predictions = False, cache = True): self.kb = kb assert(dist_func == 'hamming' or dist_func == 'confidence') self.dist_func = dist_func - self.cache = cache self.zoopt = zoopt + self.multiple_predictions = multiple_predictions + self.cache = cache if self.cache: self.cache_min_address_num = {} @@ -58,10 +59,10 @@ class AbducerBase(abc.ABC): elif self.dist_func == 'confidence': return self.confidence_dist(pred_res_prob, candidates) - def get_min_cost_candidate(self, pred_res, pred_res_prob, candidates): + def get_one_candidate(self, pred_res, pred_res_prob, candidates): if len(candidates) == 0: return [] - elif len(candidates) == 1: + elif len(candidates) == 1 or self.zoopt: return candidates[0] else: cost_list = self.get_cost_list(pred_res, pred_res_prob, candidates) @@ -71,8 +72,19 @@ class AbducerBase(abc.ABC): + + # for multiple_prediction + def flatten(self, l): + if self.multiple_predictions: + return [item for sublist in l for item in sublist] + else: + return l + + + + # for zoopt def zoopt_address_score(self, pred_res, key, address_idx): - candidates = self.kb.address_by_idx(pred_res, key, address_idx) + candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) return 0 if len(candidates) > 0 else 1 def constraint_address_num(self, solution, max_address_num): @@ -80,9 +92,9 @@ class AbducerBase(abc.ABC): return max_address_num - x.sum() def zoopt_get_address_idx(self, pred_res, key, max_address_num): - dimension = Dimension(size=len(pred_res), - regs=[[0, 1]] * len(pred_res), - tys=[False] * len(pred_res)) + dimension = Dimension(size=len(self.flatten(pred_res)), + regs=[[0, 1]] * len(self.flatten(pred_res)), + tys=[False] * len(self.flatten(pred_res))) objective = Objective(lambda sol: self.zoopt_address_score(pred_res, key, [idx for idx, i in enumerate(sol.get_x()) if i != 0]), dim=dimension, constraint=lambda sol: self.constraint_address_num(sol, max_address_num)) @@ -90,7 +102,7 @@ class AbducerBase(abc.ABC): solution = Opt.min(objective, parameter).get_x() address_idx = [idx for idx, i in enumerate(solution) if i != 0] - address_num = solution.sum() + address_num = int(solution.sum()) return address_idx, address_num @@ -102,32 +114,34 @@ class AbducerBase(abc.ABC): if max_address_num == -1: max_address_num = len(pred_res) - if self.cache and (tuple(pred_res), key) in self.cache_min_address_num: - address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) - if (tuple(pred_res), key, address_num) in self.cache_candidates: - candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] + if self.cache and (tuple(self.flatten(pred_res)), key) in self.cache_min_address_num: + address_num = min(max_address_num, self.cache_min_address_num[(tuple(self.flatten(pred_res)), key)] + require_more_address) + if (tuple(self.flatten(pred_res)), key, address_num) in self.cache_candidates: + candidates = self.cache_candidates[(tuple(self.flatten(pred_res)), key, address_num)] if self.zoopt: return candidates[0] else: - return self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) + return self.get_one_candidate(pred_res, pred_res_prob, candidates) if self.zoopt: address_idx, address_num = self.zoopt_get_address_idx(pred_res, key, max_address_num) - candidates = self.kb.address_by_idx(pred_res, key, address_idx) + candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) min_address_num = address_num - candidate = candidates[0] else: - candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address) - candidate = self.get_min_cost_candidate(pred_res, pred_res_prob, candidates) + candidates, min_address_num, address_num = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address, self.multiple_predictions) + + candidate = self.get_one_candidate(pred_res, pred_res_prob, candidates) if self.cache: - self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num - self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates + self.cache_min_address_num[(tuple(self.flatten(pred_res)), key)] = min_address_num + self.cache_candidates[(tuple(self.flatten(pred_res)), key, address_num)] = candidates - return candidate + def abduce_rules(self, pred_res): + return self.kb.abduce_rules(pred_res) + def batch_abduce(self, Z, Y, max_address_num = -1, require_more_address = 0): return [ @@ -185,7 +199,7 @@ if __name__ == '__main__': print(res) print() - kb = hwf_KB(len_list = [1, 3, 5]) + kb = HWF_KB(len_list = [1, 3, 5]) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '2'], None, 3), max_address_num = 2, require_more_address = 0) print(res) @@ -197,3 +211,23 @@ if __name__ == '__main__': print(res) print() + kb = HED_prolog_KB() + abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) + consist_re = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 0, '=', 1, 1]] + consist_re2 = [[1, '+', 0, '=', 0], [1, '+', 1, '=', 0], [0, '+', 1, '=', 1, 1]] # not consistent with rules + inconsist_re = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] + rules = ['my_op([0], [0], [1, 1])', 'my_op([1], [1], [0])', 'my_op([1], [0], [0])'] + + print(kb.logic_forward(consist_re), kb.logic_forward(inconsist_re)) + print(kb.consist_rule(consist_re, rules), kb.consist_rule(consist_re2, rules)) + print() + + res = abd.abduce((consist_re, None, None)) + print(res) + res = abd.abduce((inconsist_re, None, None)) + print(res) + print() + + abduced_rules = abd.abduce_rules(consist_re) + print(abduced_rules) +