From 83dff66d9ff4b8871b8496c9cd75536625bfc070 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 15 Nov 2022 20:04:27 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 123 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index af4c8ec..c3ddff0 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -17,7 +17,7 @@ import numpy as np from collections import defaultdict -from itertools import product +from itertools import product, combinations class KBBase(ABC): def __init__(self): @@ -80,20 +80,30 @@ class add_KB(KBBase): def get_all_candidates(self): return sum([sum(v.values(), []) for v in self.base.values()], []) - def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): + def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address): if key is None: return self.get_all_candidates() candidates = [] - all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) - for address_num in range(length + 1): + + for address_num in range(len(pred_res) + 1): if(address_num > max_address_num): print('No candidates found') return None, None, None - for c in all_candidates: - if(dist_func(c, pred_res) == address_num): - if(self.logic_forward(c) == key): - candidates.append(c) + + if(address_num == 0): + if(self.logic_forward(pred_res) == key): + candidates.append(pred_res) + else: + all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + for address_idx in address_idx_list: + for c in all_address_candidate: + pred_res_array = np.array(pred_res) + pred_res_array[np.array(address_idx)] = c + if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): + candidates.append(pred_res_array) + if(len(candidates) > 0): min_address_num = address_num break @@ -101,10 +111,14 @@ class add_KB(KBBase): for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): if(address_num > max_address_num): return candidates, min_address_num, address_num - 1 - for c in all_candidates: - if(dist_func(c, pred_res) == address_num): - if(self.logic_forward(c) == key): - candidates.append(c) + all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + for address_idx in address_idx_list: + for c in all_candidate: + pred_res_array = np.array(pred_res) + pred_res_array[np.array(address_idx)] = c + if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): + candidates.append(pred_res_array) return candidates, min_address_num, address_num @@ -114,6 +128,91 @@ class add_KB(KBBase): def __len__(self): return sum(self._dict_len(v) for v in self.base.values()) + +# class hwf_KB(KBBase): +# def __init__(self, pseudo_label_list, max_len = 5): +# super().__init__() +# self.pseudo_label_list = pseudo_label_list +# self.base = {} + +# X = self.get_X(self.pseudo_label_list, max_len) +# Y = self.get_Y(X, self.logic_forward) + +# for x, y in zip(X, Y): +# self.base.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) + +# def logic_forward(self, nums): +# return sum(nums) + +# def get_X(self, pseudo_label_list, max_len): +# res = [] +# assert(max_len >= 2) +# for len in range(2, max_len + 1): +# res += list(product(pseudo_label_list, repeat = len)) +# return res + +# def get_Y(self, X, logic_forward): +# return [logic_forward(nums) for nums in X] + +# def get_candidates(self, key, length = None): +# if key is None: +# return self.get_all_candidates() + +# length = self._length(length) +# return sum([self.base[l][key] for l in length], []) + +# def get_all_candidates(self): +# return sum([sum(v.values(), []) for v in self.base.values()], []) + +# def get_abduce_candidates(self, pred_res, key, length, dist_func, max_address_num, require_more_address): +# if key is None: +# return self.get_all_candidates() + +# candidates = [] +# # all_candidates = list(product(self.pseudo_label_list, repeat = len(pred_res))) + +# for address_num in range(length + 1): +# if(address_num > max_address_num): +# print('No candidates found') +# return None, None, None + +# if(address_num == 0): +# if(self.logic_forward(pred_res) == key): +# candidates.append(pred_res) +# else: +# all_address_candidate = list(product(self.pseudo_label_list, repeat = address_num)) +# address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) +# for address_idx in address_idx_list: +# for c in all_address_candidate: +# pred_res_array = np.array(pred_res) +# pred_res_array[np.array(address_idx)] = c +# if(np.count_nonzero(np.array(c) != np.array(pred_res)[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): +# candidates.append(pred_res_array) + +# if(len(candidates) > 0): +# min_address_num = address_num +# break + +# for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): +# if(address_num > max_address_num): +# return candidates, min_address_num, address_num - 1 +# all_candidate = list(product(self.pseudo_label_list, repeat = address_num)) +# address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) +# for address_idx in address_idx_list: +# for c in all_candidate: +# pred_res_array = np.array(pred_res) +# pred_res_array[np.array(address_idx)] = c +# if(np.count_nonzero(np.array(c) != pred_res_array[np.array(address_idx)]) == address_num and self.logic_forward(pred_res_array) == key): +# candidates.append(pred_res_array) + +# return candidates, min_address_num, address_num + + +# def _dict_len(self, dic): +# return sum(len(c) for c in dic.values()) + +# def __len__(self): +# return sum(self._dict_len(v) for v in self.base.values()) class cls_KB(KBBase): def __init__(self, X, Y = None):