diff --git a/weights/all_weights_here.txt b/abl/__init__.py
similarity index 100%
rename from weights/all_weights_here.txt
rename to abl/__init__.py
diff --git a/abl/abducer/__init__.py b/abl/abducer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/abducer/abducer_base.py b/abl/abducer/abducer_base.py
similarity index 54%
rename from abducer/abducer_base.py
rename to abl/abducer/abducer_base.py
index c0fd102..cf14bb7 100644
--- a/abducer/abducer_base.py
+++ b/abl/abducer/abducer_base.py
@@ -10,32 +10,15 @@
 #
 # ================================================================#
 
-import sys
-
-sys.path.append(".")
-sys.path.append("..")
-
 import abc
-from abducer.kb import *
 import numpy as np
 from zoopt import Dimension, Objective, Parameter, Opt
-from utils.utils import confidence_dist, flatten, hamming_dist
-
-import math
-import time
-
+from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist
 
 class AbducerBase(abc.ABC):
-    def __init__(
-        self,
-        kb,
-        dist_func="confidence",
-        zoopt=False,
-        multiple_predictions=False,
-        cache=True,
-    ):
+    def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True):
         self.kb = kb
-        assert dist_func == "hamming" or dist_func == "confidence"
+        assert dist_func == 'hamming' or dist_func == 'confidence'
         self.dist_func = dist_func
         self.zoopt = zoopt
         self.multiple_predictions = multiple_predictions
@@ -46,41 +29,42 @@ class AbducerBase(abc.ABC):
             self.cache_candidates = {}
 
     def _get_cost_list(self, pred_res, pred_res_prob, candidates):
-        if self.dist_func == "hamming":
+        if self.dist_func == 'hamming':
+            if self.multiple_predictions:
+                pred_res = flatten(pred_res)
+                candidates = [flatten(c) for c in candidates]
+
             return hamming_dist(pred_res, candidates)
-        elif self.dist_func == "confidence":
-            mapping = dict(
-                zip(
-                    self.kb.pseudo_label_list,
-                    list(range(len(self.kb.pseudo_label_list))),
-                )
-            )
-            return confidence_dist(
-                pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]
-            )
+
+        elif self.dist_func == 'confidence':
+            if self.multiple_predictions:
+                pred_res_prob = flatten(pred_res_prob)
+                candidates = [flatten(c) for c in candidates]
+
+            mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
+            candidates = [list(map(lambda x: mapping[x], c)) for c in candidates]
+            return confidence_dist(pred_res_prob, candidates)
 
     def _get_one_candidate(self, pred_res, pred_res_prob, candidates):
         if len(candidates) == 0:
             return []
         elif len(candidates) == 1 or self.zoopt:
             return candidates[0]
+
         else:
             cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
             min_address_num = np.min(cost_list)
             idxs = np.where(cost_list == min_address_num)[0]
-            return [candidates[idx] for idx in idxs][0]
+            candidate = [candidates[idx] for idx in idxs][0]
+            return candidate
 
     # for zoopt
     def _zoopt_score_multiple(self, pred_res, key, solution):
        all_address_flag = reform_idx(solution, pred_res)
        score = 0
        for idx in range(len(pred_res)):
-            address_idx = [
-                i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
-            ]
-            candidate = self.kb.address_by_idx(
-                [pred_res[idx]], key[idx], address_idx, True
-            )
+            address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
+            candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
             if len(candidate) > 0:
                 score += 1
         return score
@@ -88,9 +72,7 @@ class AbducerBase(abc.ABC):
     def _zoopt_address_score(self, pred_res, key, sol):
         if not self.multiple_predictions:
             address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0]
-            candidates = self.kb.address_by_idx(
-                pred_res, key, address_idx, self.multiple_predictions
-            )
+            candidates = self.address_by_idx(pred_res, key, address_idx)
             return 1 if len(candidates) > 0 else 0
         else:
             return self._zoopt_score_multiple(pred_res, key, sol.get_x())
@@ -107,7 +89,7 @@
             dim=dimension,
             constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
         )
-        parameter = Parameter(budget=100, autoset=True)
+        parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
         solution = Opt.min(objective, parameter).get_x()
         return solution
 
@@ -118,11 +100,7 @@
             pred_res = flatten(pred_res)
         key = tuple(key)
         if (tuple(pred_res), key) in self.cache_min_address_num:
-            address_num = min(
-                max_address_num,
-                self.cache_min_address_num[(tuple(pred_res), key)]
-                + require_more_address,
-            )
+            address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
             if (tuple(pred_res), key, address_num) in self.cache_candidates:
                 candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
                 if self.zoopt:
@@ -137,6 +115,9 @@
             key = tuple(key)
         self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
         self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
+
+    def address_by_idx(self, pred_res, key, address_idx):
+        return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)
 
     def abduce(self, data, max_address_num=-1, require_more_address=0):
         pred_res, pred_res_prob, key = data
@@ -151,18 +132,12 @@
         if self.zoopt:
             solution = self.zoopt_get_solution(pred_res, key, max_address_num)
             address_idx = [idx for idx, i in enumerate(solution) if i != 0]
-            candidates = self.kb.address_by_idx(
-                pred_res, key, address_idx, self.multiple_predictions
-            )
+            candidates = self.address_by_idx(pred_res, key, address_idx)
             address_num = int(solution.sum())
             min_address_num = address_num
         else:
             candidates, min_address_num, address_num = self.kb.abduce_candidates(
-                pred_res,
-                key,
-                max_address_num,
-                require_more_address,
-                self.multiple_predictions,
+                pred_res, key, max_address_num, require_more_address, self.multiple_predictions
             )
 
         candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
@@ -176,32 +151,22 @@
         return self.kb.abduce_rules(pred_res)
 
     def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0):
-        if self.multiple_predictions:
-            return self.abduce(
-                (Z["cls"], Z["prob"], Y), max_address_num, require_more_address
-            )
-        else:
-            return [
-                self.abduce((z, prob, y), max_address_num, require_more_address)
-                for z, prob, y in zip(Z["cls"], Z["prob"], Y)
-            ]
+        # if self.multiple_predictions:
+        return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
+        # else:
+        #     return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
 
     def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
         return self.batch_abduce(Z, Y, max_address_num, require_more_address)
 
+if __name__ == '__main__':
+    from kb import add_KB, prolog_KB, HWF_KB
+
+    prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
+    prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
 
-if __name__ == "__main__":
-    prob1 = [
-        [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
-        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
-    ]
-    prob2 = [
-        [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
-        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
-    ]
-
-    kb = add_KB()
-    abd = AbducerBase(kb, "confidence")
+    kb = add_KB(GKB_flag=True)
+    abd = AbducerBase(kb, 'confidence')
     res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
     print(res)
     res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -213,9 +178,23 @@ if __name__ == "__main__":
     res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0)
     print(res)
     print()
+
+
+    multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
+                     [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
+
+
+    kb = add_KB()
+    abd = AbducerBase(kb, 'confidence', multiple_predictions=True)
+    res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0)
+    print(res)
+    res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1)
+    print(res)
+    print()
+
 
-    kb = add_prolog_KB()
-    abd = AbducerBase(kb, "confidence")
+    kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
+    abd = AbducerBase(kb, 'confidence')
     res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
     print(res)
     res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -228,8 +207,8 @@ if __name__ == "__main__":
     print(res)
     print()
 
-    kb = add_prolog_KB()
-    abd = AbducerBase(kb, "confidence", zoopt=True)
+    kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
+    abd = AbducerBase(kb, 'confidence', zoopt=True)
     res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
     print(res)
     res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -242,49 +221,63 @@ if __name__ == "__main__":
     print(res)
     print()
 
-    kb = HWF_KB(len_list=[1, 3, 5])
-    abd = AbducerBase(kb, "hamming")
-    res = abd.abduce(
-        (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0
-    )
+    kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
+    abd = AbducerBase(kb, 'hamming')
+    res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
+    print(res)
+    res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
+    print(res)
+    print()
+
+    kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
+    abd = AbducerBase(kb, 'hamming')
+    res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
+    print(res)
+    res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
+    print(res)
+    res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
+    print(res)
+    print()
+
+    kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
+    abd = AbducerBase(kb, 'hamming', multiple_predictions=True)
+    res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0)
+    print(res)
+    print()
+
+    kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
+    abd = AbducerBase(kb, 'hamming')
+    res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
+    print(res)
+    res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
     print(res)
-    res = abd.abduce(
-        (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0
-    )
+
+    kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
+    abd = AbducerBase(kb, 'hamming')
+    res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
     print(res)
-    res = abd.abduce(
-        (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0
-    )
+    res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
     print(res)
-    res = abd.abduce(
-        (["5", "8", "8", "8", "8"], None, 3.17),
-        max_address_num=5,
-        require_more_address=3,
-    )
+    res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
     print(res)
     print()
 
-    kb = HED_prolog_KB()
+    kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
     abd = AbducerBase(kb, zoopt=True, multiple_predictions=True)
-    consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]]
-    consist_exs2 = [
-        [1, "+", 0, "=", 0],
-        [1, "+", 1, "=", 0],
-        [0, "+", 1, "=", 1, 1],
-    ]  # not consistent with rules
-    inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
+    consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
+    inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
+    # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']]
-    rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"]
+    rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])']
 
-    print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs))
-    print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules))
+    print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True))
+    print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
     print()
 
-    res = abd.abduce((consist_exs, None, [1] * len(consist_exs)))
+    res = abd.abduce((consist_exs, None, [None] * len(consist_exs)))
     print(res)
-    res = abd.abduce((inconsist_exs, None, [1] * len(consist_exs)))
+    res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs)))
     print(res)
     print()
 
     abduced_rules = abd.abduce_rules(consist_exs)
-    print(abduced_rules)
+    print(abduced_rules)
\ No newline at end of file
diff --git a/abducer/kb.py b/abl/abducer/kb.py
similarity index 50%
rename from abducer/kb.py
rename to abl/abducer/kb.py
index b9beb3f..5b82f03 100644
--- a/abducer/kb.py
+++ b/abl/abducer/kb.py
@@ -15,93 +15,26 @@
 import bisect
 import copy
 import numpy as np
-import sys
-
-sys.path.append("..")
-
 from collections import defaultdict
 from itertools import product, combinations
-from utils.utils import flatten, reform_idx, hamming_dist, check_equal
+from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal
 from multiprocessing import Pool
 import pyswip
 
-
 class KBBase(ABC):
-    def __init__(self, pseudo_label_list=None):
-        pass
-
-    @abstractmethod
-    def logic_forward(self):
-        pass
-
-    @abstractmethod
-    def abduce_candidates(self):
-        pass
-
-    @abstractmethod
-    def address_by_idx(self):
-        pass
-
-    def _address(self, address_num, pred_res, key, multiple_predictions=False):
-        new_candidates = []
-        if not multiple_predictions:
-            address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
-        else:
-            address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))
-
-        for address_idx in address_idx_list:
-            candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
-            new_candidates += candidates
-        return new_candidates
-
-    def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
-        candidates = []
-
-        for address_num in range(len(flatten(pred_res)) + 1):
-            if address_num == 0:
-                if check_equal(self.logic_forward(pred_res), key):
-                    candidates.append(pred_res)
-            else:
-                new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
-                candidates += new_candidates
-
-            if len(candidates) > 0:
-                min_address_num = address_num
-                break
-
-            if address_num >= max_address_num:
-                return [], 0, 0
-
-        for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
-            if address_num > max_address_num:
-                return candidates, min_address_num, address_num - 1
-            new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
-            candidates += new_candidates
-
-        return candidates, min_address_num, address_num
-
-    def __len__(self):
-        pass
-
-
-class ClsKB(KBBase):
-    def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
-        super().__init__()
-        self.GKB_flag = GKB_flag
+    def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
         self.pseudo_label_list = pseudo_label_list
         self.len_list = len_list
+        self.GKB_flag = GKB_flag
+        self.max_err = max_err
 
         if GKB_flag:
             self.base = {}
             X, Y = self._get_GKB()
             for x, y in zip(X, Y):
                 self.base.setdefault(len(x), defaultdict(list))[y].append(x)
-        else:
-            self.all_address_candidate_dict = {}
-            for address_num in range(max(self.len_list) + 1):
-                self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num))
 
     # For parallel version of _get_GKB
     def _get_XY_list(self, args):
@@ -130,39 +63,80 @@ class ClsKB(KBBase):
             part_X, part_Y = zip(*XY_list)
             X.extend(part_X)
             Y.extend(part_Y)
+        if type(Y[0]) in (int, float):
+            sorted_XY = sorted(list(zip(Y, X)))
+            X = [x for y, x in sorted_XY]
+            Y = [y for y, x in sorted_XY]
         return X, Y
 
-    def logic_forward(self):
+    @abstractmethod
+    def logic_forward(self, pseudo_labels):
         pass
+
+    def _logic_forward(self, xs, multiple_predictions=False):
+        if not multiple_predictions:
+            return self.logic_forward(xs)
+        else:
+            res = [self.logic_forward(x) for x in xs]
+            return res
 
-    def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
+    def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
         if self.GKB_flag:
-            return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
+            return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
         else:
             return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
+
+    @abstractmethod
+    def _find_candidate_GKB(self, pred_res, key):
+        pass
+
+    def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
+        if self.base == {}:
+            return [], 0, 0
 
-    def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
-        if self.base == {} or len(pred_res) not in self.len_list:
-            return []
-
-        all_candidates = self.base[len(pred_res)][key]
-
-        if len(all_candidates) == 0:
-            candidates = []
-            min_address_num = 0
-            address_num = 0
+        if not multiple_predictions:
+            if len(pred_res) not in self.len_list:
+                return [], 0, 0
+            all_candidates = self._find_candidate_GKB(pred_res, key)
+            if len(all_candidates) == 0:
+                return [], 0, 0
+            else:
+                cost_list = hamming_dist(pred_res, all_candidates)
+                min_address_num = np.min(cost_list)
+                address_num = min(max_address_num, min_address_num + require_more_address)
+                idxs = np.where(cost_list <= address_num)[0]
+                candidates = [all_candidates[idx] for idx in idxs]
+                return candidates, min_address_num, address_num
+
         else:
-            cost_list = hamming_dist(pred_res, all_candidates)
-            min_address_num = np.min(cost_list)
+            min_address_num = 0
+            all_candidates_save = []
+            cost_list_save = []
+
+            for p_res, k in zip(pred_res, key):
+                if len(p_res) not in self.len_list:
+                    return [], 0, 0
+                all_candidates = self._find_candidate_GKB(p_res, k)
+                if len(all_candidates) == 0:
+                    return [], 0, 0
+                else:
+                    all_candidates_save.append(all_candidates)
+                    cost_list = hamming_dist(p_res, all_candidates)
+                    min_address_num += np.min(cost_list)
+                    cost_list_save.append(cost_list)
+
+            multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
+            assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
+            multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
+            assert len(multiple_all_candidates) == len(multiple_cost_list)
             address_num = min(max_address_num, min_address_num + require_more_address)
-            idxs = np.where(cost_list <= address_num)[0]
-            candidates = [all_candidates[idx] for idx in idxs]
-
-        return candidates, min_address_num, address_num
-
+            idxs = np.where(multiple_cost_list <= address_num)[0]
+            candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
+            return candidates, min_address_num, address_num
+
     def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
         candidates = []
-        abduce_c = self.all_address_candidate_dict[len(address_idx)]
+        abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
 
         if multiple_predictions:
             save_pred_res = pred_res
@@ -176,10 +150,48 @@ class ClsKB(KBBase):
             if multiple_predictions:
                 candidate = reform_idx(candidate, save_pred_res)
 
-            if self.logic_forward(candidate) == key:
+            if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
                 candidates.append(candidate)
         return candidates
 
+    def _address(self, address_num, pred_res, key, multiple_predictions):
+        new_candidates = []
+        if not multiple_predictions:
+            address_idx_list = combinations(list(range(len(pred_res))), address_num)
+        else:
+            address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num)
+
+        for address_idx in address_idx_list:
+            candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
+            new_candidates += candidates
+        return new_candidates
+
+    def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
+        candidates = []
+
+        for address_num in range(len(flatten(pred_res)) + 1):
+            if address_num == 0:
+                if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
+                    candidates.append(pred_res)
+            else:
+                new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
+                candidates += new_candidates
+
+            if len(candidates) > 0:
+                min_address_num = address_num
+                break
+
+            if address_num >= max_address_num:
+                return [], 0, 0
+
+        for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
+            if address_num > max_address_num:
+                return candidates, min_address_num, address_num - 1
+            new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
+            candidates += new_candidates
+
+        return candidates, min_address_num, address_num
+
     def _dict_len(self, dic):
         if not self.GKB_flag:
             return 0
@@ -193,130 +205,77 @@
         return sum(self._dict_len(v) for v in self.base.values())
 
-class add_KB(ClsKB):
-    def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]):
-        super().__init__(GKB_flag, pseudo_label_list, len_list)
-
-    def logic_forward(self, nums):
-        return sum(nums)
+class ClsKB(KBBase):
+    def __init__(self, pseudo_label_list, len_list, GKB_flag):
+        super().__init__(pseudo_label_list, len_list, GKB_flag)
 
+    def _find_candidate_GKB(self, pred_res, key):
+        return self.base[len(pred_res)][key]
 
-class HWF_KB(ClsKB):
-    def __init__(
-        self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7]
-    ):
-        super().__init__(GKB_flag, pseudo_label_list, len_list)
 
-    def valid_candidate(self, formula):
-        if len(formula) % 2 == 0:
-            return False
-        for i in range(len(formula)):
-            if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
-                return False
-            if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
-                return False
-        return True
+class add_KB(ClsKB):
+    def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
+        super().__init__(pseudo_label_list, len_list, GKB_flag)
 
-    def logic_forward(self, formula):
-        if not self.valid_candidate(formula):
-            return np.inf
-        mapping = {
-            '1': '1',
-            '2': '2',
-            '3': '3',
-            '4': '4',
-            '5': '5',
-            '6': '6',
-            '7': '7',
-            '8': '8',
-            '9': '9',
-            '+': '+',
-            '-': '-',
-            'times': '*',
-            'div': '/',
-        }
-        formula = [mapping[f] for f in formula]
-        return round(eval(''.join(formula)), 2)
+    def logic_forward(self, nums):
+        return sum(nums)
 
 
 class prolog_KB(KBBase):
-    def __init__(self, pseudo_label_list):
-        super().__init__()
-        self.pseudo_label_list = pseudo_label_list
+    def __init__(self, pseudo_label_list, pl_file):
+        super().__init__(pseudo_label_list)
         self.prolog = pyswip.Prolog()
+        self.prolog.consult(pl_file)
 
-    def logic_forward(self):
-        pass
-
-    def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
-        return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
+    def logic_forward(self, pseudo_labels):
+        result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
+        if result == 'true':
+            return True
+        elif result == 'false':
+            return False
+        return result
+
+    def _address_pred_res(self, pred_res, address_idx, multiple_predictions):
+        import re
+        address_pred_res = pred_res.copy()
+        if multiple_predictions:
+            address_pred_res = flatten(address_pred_res)
+
+        for idx in address_idx:
+            address_pred_res[idx] = 'P' + str(idx)
+        if multiple_predictions:
+            address_pred_res = reform_idx(address_pred_res, pred_res)
+
+        # TODO: check whether there is a more concise way to do this
+        regex = r"'P\d+'"
+        return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
+
+    def get_query_string(self, pred_res, key, address_idx, multiple_predictions):
+        query_string = "logic_forward("
+        query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions)
+        key_is_none_flag = key is None or (type(key) == list and key[0] is None)
+        query_string += ",%s)." % key if not key_is_none_flag else ")."
+        return query_string
+
+    def _find_candidate_GKB(self, pred_res, key):
+        pass
+
     def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
         candidates = []
-        # print(address_idx)
-        if not multiple_predictions:
-            query_string = self.get_query_string(pred_res, key, address_idx)
-        else:
-            query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)
-
+        query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions)
         if multiple_predictions:
             save_pred_res = pred_res
             pred_res = flatten(pred_res)
-
-        abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
+        abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
         for c in abduce_c:
             candidate = pred_res.copy()
             for i, idx in enumerate(address_idx):
                 candidate[idx] = c[i]
-
             if multiple_predictions:
                 candidate = reform_idx(candidate, save_pred_res)
-
             candidates.append(candidate)
         return candidates
 
-
-class add_prolog_KB(prolog_KB):
-    def __init__(self, pseudo_label_list=list(range(10))):
-        super().__init__(pseudo_label_list)
-        for i in self.pseudo_label_list:
-            self.prolog.assertz("pseudo_label(%s)" % i)
-        self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")
-
-    def logic_forward(self, nums):
-        return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res']
-
-    def get_query_string(self, pred_res, key, address_idx):
-        query_string = "addition("
-        for idx, i in enumerate(pred_res):
-            tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ','
-            query_string += tmp
-        query_string += "%s)." % key
-        return query_string
-
-
-class HED_prolog_KB(prolog_KB):
-    def __init__(self, pseudo_label_list=[0, 1, '+', '=']):
-        super().__init__(pseudo_label_list)
-        self.prolog.consult('./datasets/hed/learn_add.pl')
-
-    # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
-    def logic_forward(self, exs):
-        return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0
-
-    def get_query_string_need_flatten(self, pred_res, key, address_idx):
-        # flatten
-        flatten_pred_res = flatten(pred_res)
-        # add variables for prolog
-        for idx in range(len(flatten_pred_res)):
-            if idx in address_idx:
-                flatten_pred_res[idx] = 'X' + str(idx)
-        # unflatten
-        new_pred_res = reform_idx(flatten_pred_res, pred_res)
-
-        query_string = "abduce_consistent_insts(%s)." % new_pred_res
-        return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")
-
     def consist_rule(self, exs, rules):
         rules = str(rules).replace("\'","")
         return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
@@ -327,92 +286,79 @@ class HED_prolog_KB(prolog_KB):
         if len(prolog_result) == 0:
             return None
         prolog_rules = prolog_result[0]['X']
-        rules = []
-        for rule in prolog_rules:
-            rules.append(rule.value)
+        rules = [rule.value for rule in prolog_rules]
         return rules
 
-    # def consist_rules(self, pred_res, rules):
-
 
 class RegKB(KBBase):
-    def __init__(self, GKB_flag=False, X=None, Y=None):
-        super().__init__()
-        tmp_dict = {}
-        for x, y in zip(X, Y):
-            tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
-
-        self.base = {}
-        for l in tmp_dict.keys():
-            data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values())))
-            X = [x for y, x in data]
-            Y = [y for y, x in data]
-            self.base[l] = (X, Y)
-
-    def valid_candidate(self):
-        pass
-
-    def logic_forward(self):
-        pass
-
-    def abduce_candidates(self, key, length=None):
-        if key is None:
-            return self.get_all_candidates()
-
-        length = self._length(length)
+    def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=1e-3):
+        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)
+
+    def _find_candidate_GKB(self, pred_res, key):
+        potential_candidates = self.base[len(pred_res)]
+        key_list = list(potential_candidates.keys())
+        key_idx = bisect.bisect_left(key_list, key)
+
+        all_candidates = []
+        for idx in range(key_idx - 1, 0, -1):
+            k = key_list[idx]
+            if abs(k - key) <= self.max_err:
+                all_candidates += potential_candidates[k]
+            else:
+                break
+
+        for idx in range(key_idx, len(key_list)):
+            k = key_list[idx]
+            if abs(k - key) <= self.max_err:
+                all_candidates += potential_candidates[k]
+            else:
+                break
+        return all_candidates
+
 
-        min_err = 999999
-        candidates = []
-        for l in length:
-            X, Y = self.base[l]
-
-            idx = bisect.bisect_left(Y, key)
-            begin = max(0, idx - 1)
-            end = min(idx + 2, len(X))
-
-            for idx in range(begin, end):
-                err = abs(Y[idx] - key)
-                if abs(err - min_err) < 1e-9:
-                    candidates.extend(X[idx])
-                elif err < min_err:
-                    candidates = copy.deepcopy(X[idx])
-                    min_err = err
-        return candidates
+class HWF_KB(RegKB):
+    def __init__(
+        self,
+        pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],
+        len_list=[1, 3, 5, 7],
+        GKB_flag=False,
+        max_err=1e-3
+    ):
+        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)
 
-    def get_all_candidates(self):
-        return sum([sum(D[0], []) for D in self.base.values()], [])
+    def _valid_candidate(self, formula):
+        if len(formula) % 2 == 0:
+            return False
+        for i in range(len(formula)):
+            if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
+                return False
+            if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
+                return False
+        return True
 
-    def __len__(self):
-        return sum([sum(len(x) for x in D[0]) for D in self.base.values()])
+    def logic_forward(self, formula):
+        if not self._valid_candidate(formula):
+            return np.inf
+        mapping = {
+            '1': '1',
+            '2': '2',
+            '3': '3',
+            '4': '4',
+            '5': '5',
+            '6': '6',
+            '7': '7',
+            '8': '8',
+            '9': '9',
+            '+': '+',
+            '-': '-',
+            'times': '*',
+            'div': '/',
+        }
+        formula = [mapping[f] for f in formula]
+        return eval(''.join(formula))
 
 import time
 if __name__ == "__main__":
-    t1 = time.time()
-    kb = HWF_KB(True)
-    t2 = time.time()
-    print(t2 - t1)
-
-    # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
-    # Y = [2, 1, 1, 2, 2]
-    # kb = ClsKB(X, Y)
-    # print('len(kb):', len(kb))
-    # res = kb.get_candidates(2, 5)
-    # print(res)
-    # res = kb.get_candidates(2, 3)
-    # print(res)
-    # res = kb.get_candidates(None)
-    # print(res)
-    # print()
-
-    # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
-    # Y = [2, 1, 1, 2, 1.5, 1.5]
-    # kb = RegKB(X, Y)
-    # print('len(kb):', len(kb))
-    # res = kb.get_candidates(1.6)
-    # print(res)
-    # res = kb.get_candidates(1.6, length = 9)
-    # print(res)
-    # res = kb.get_candidates(None)
-    # print(res)
+    pass
\ No newline at end of file
diff --git a/framework.py b/abl/framework.py
similarity index 98%
rename from framework.py
rename to abl/framework.py
index 2a1fdad..c3a5c6e 100644
--- a/framework.py
+++ b/abl/framework.py
@@ -14,7 +14,7 @@
 import pickle as pk
 import numpy as np
 
-from utils.plog import INFO, DEBUG, clocker
+from .utils.plog import INFO, DEBUG, clocker
 
 def block_sample(X, Z, Y, sample_num, epoch_idx):
     part_num = (len(X) // sample_num)
diff --git a/framework_hed.py b/abl/framework_hed.py
similarity index 95%
rename from framework_hed.py
rename to abl/framework_hed.py
index b7439c3..3f09ff6 100644
--- a/framework_hed.py
+++ b/abl/framework_hed.py
@@ -16,12 +16,15 @@
 import torch.nn as nn
 import numpy as np
 import os
 
-from utils.plog import INFO, DEBUG, clocker
-from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
+from .utils.plog import INFO, DEBUG, clocker
+from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
 
-from models.nn import MLP, SymbolNetAutoencoder
-from models.basic_model import BasicModel, BasicDataset
-from datasets.hed.get_hed import get_pretrain_data
+from .models.nn import MLP, SymbolNetAutoencoder
+from .models.basic_model import BasicModel, BasicDataset
+
+import sys
+sys.path.append("..")
+from examples.datasets.hed.get_hed import get_pretrain_data
 
 def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
     result = {}
@@ -147,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
     for m in mappings:
         pred_res = mapping_res(original_pred_res, m)
         max_abduce_num = 20
-        solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
+        solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num)
         all_address_flag = reform_idx(solution, pred_res)
 
         consistent_idx_tmp = []
@@ -155,7 +158,7 @@
         consistent_pred_res_tmp = []
 
         for idx in range(len(pred_res)):
             address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
-            candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
+            candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx)
             if len(candidate) > 0:
                 consistent_idx_tmp.append(idx)
                 consistent_pred_res_tmp.append(candidate[0][0])
@@ -211,7 +214,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,
     consistent_idx = []
     consistent_pred_res = []
     for idx in range(len(pred_res)):
-        if abducer.kb.logic_forward([pred_res[idx]]):
+        if abducer.kb.logic_forward(pred_res[idx]):
             consistent_idx.append(idx)
             consistent_pred_res.append(pred_res[idx])
diff --git a/abl/models/__init__.py b/abl/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/basic_model.py b/abl/models/basic_model.py
similarity index 100%
rename from models/basic_model.py
rename to abl/models/basic_model.py
diff --git a/models/lenet5.py b/abl/models/lenet5.py
similarity index 100%
rename from models/lenet5.py
rename to abl/models/lenet5.py
diff --git a/models/nn.py b/abl/models/nn.py
similarity index 97%
rename from models/nn.py
rename to abl/models/nn.py
index cecbeea..7a0f560 100644
--- a/models/nn.py
+++ b/abl/models/nn.py
@@ -10,9 +10,6 @@
 #
 # ================================================================#
 
-import sys
-
-sys.path.append("..")
 
 import torchvision
 
@@ -23,8 +20,6 @@
 from torch.autograd import Variable
 import torchvision.transforms as transforms
 import numpy as np
 
-from models.basic_model import BasicModel
-import utils.plog as plog
 
 class LeNet5(nn.Module):
diff --git a/models/wabl_models.py b/abl/models/wabl_models.py
similarity index 99%
rename from models/wabl_models.py
rename to abl/models/wabl_models.py
index 9c97bbc..3b682ee 100644
--- a/models/wabl_models.py
+++ b/abl/models/wabl_models.py
@@ -21,7 +21,6 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.svm import SVC
 from sklearn.gaussian_process import GaussianProcessClassifier
 from sklearn.gaussian_process.kernels import RBF
-from models.basic_model import BasicModel
 
 import pickle as pk
 import random
diff --git a/utils/plog.py b/abl/utils/plog.py
similarity index 100%
rename from utils/plog.py
rename to abl/utils/plog.py
diff --git a/utils/utils.py b/abl/utils/utils.py
similarity index 83%
rename from utils/utils.py
rename to abl/utils/utils.py
index 5cd433d..67c3cf7 100644
--- a/utils/utils.py
+++ b/abl/utils/utils.py
@@ -1,30 +1,23 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from utils.plog import INFO
+from .plog import INFO
 
 from collections import OrderedDict
+from itertools import chain
 
-# for multiple predictions, modify from `learn_add.py`
+# for multiple predictions
 def flatten(l):
-    return (
-        [item for sublist in l for item in flatten(sublist)]
-        if isinstance(l, list)
-        else [l]
-    )
-
-
-# for multiple predictions, modify from `learn_add.py`
+    if not isinstance(l[0], (list, tuple)):
+        return l
+    return list(chain.from_iterable(l))
+
+# for multiple predictions
 def reform_idx(flatten_pred_res, save_pred_res):
     re = []
     i = 0
     for e in save_pred_res:
-        j = 0
-        idx = []
-        while j < len(e):
-            idx.append(flatten_pred_res[i + j])
-            j += 1
-        re.append(idx)
-        i = i + j
+        re.append(flatten_pred_res[i:i + len(e)])
+        i += len(e)
     return re
 
@@ -85,11 +78,10 @@ def remapping_res(pred_res, m):
             remapping[value] = key
     return [[remapping[symbol] for symbol in formula] for formula in pred_res]
 
-
-def check_equal(a, b):
+def check_equal(a, b, max_err=0):
     if isinstance(a, (int, float)) and isinstance(b, (int, float)):
-        return abs(a - b) <= 1e-3
-
+        return abs(a - b) <= max_err
+
     if isinstance(a, list) and isinstance(b, list):
         if len(a) != len(b):
             return False
@@ -119,4 +111,4 @@ def reduce_dimension(data):
             [extract_feature(symbol_img) for symbol_img in equation]
             for equation in equations
         ]
-        data[truth_value][equation_len] = reduced_equations
+        data[truth_value][equation_len] = reduced_equations
diff --git a/datasets/data_generator.py b/examples/datasets/data_generator.py
similarity index 100%
rename from datasets/data_generator.py
rename to examples/datasets/data_generator.py
diff --git a/datasets/hed/BK.pl b/examples/datasets/hed/BK.pl
similarity index 100%
rename from datasets/hed/BK.pl
rename to examples/datasets/hed/BK.pl
diff --git a/datasets/hed/README.md b/examples/datasets/hed/README.md
similarity index 100%
rename from datasets/hed/README.md
rename to examples/datasets/hed/README.md
diff --git a/datasets/hed/get_hed.py b/examples/datasets/hed/get_hed.py
similarity index 100%
rename from datasets/hed/get_hed.py
rename to examples/datasets/hed/get_hed.py
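
The tightened `flatten`/`reform_idx` pair in utils.py above is what the `multiple_predictions` code paths rely on: predictions are flattened for the combinatorial search, then re-chunked to the original nesting. Their intended round trip, as a usage sketch:

    from abl.utils.utils import flatten, reform_idx

    pred_res = [[1, 1], [1, 2, 3]]
    flat = flatten(pred_res)               # [1, 1, 1, 2, 3]
    restored = reform_idx(flat, pred_res)  # [[1, 1], [1, 2, 3]]
    assert restored == pred_res

Note that the new `flatten` removes only one level of nesting (it inspects `l[0]`), which matches how it is called here, whereas the old recursive version flattened arbitrarily deep lists.
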
diff --git a/datasets/hed/learn_add.pl b/examples/datasets/hed/learn_add.pl
similarity index 92%
rename from datasets/hed/learn_add.pl
rename to examples/datasets/hed/learn_add.pl
index af1d6bb..fbf698f 100644
--- a/datasets/hed/learn_add.pl
+++ b/examples/datasets/hed/learn_add.pl
@@ -32,6 +32,9 @@ abduce_consistent_insts(Exs):-
     % (Experimental) Uncomment to use parallel abduction
     % abduce_consistent_exs_concurrent(Exs), !.
 
+logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false.
+logic_forward(Exs) :- abduce_consistent_insts(Exs).
+
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 %% Abduce Delta_C given pseudo-labels
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
diff --git a/datasets/hwf/README.md b/examples/datasets/hwf/README.md
similarity index 100%
rename from datasets/hwf/README.md
rename to examples/datasets/hwf/README.md
diff --git a/datasets/hwf/get_hwf.py b/examples/datasets/hwf/get_hwf.py
similarity index 100%
rename from datasets/hwf/get_hwf.py
rename to examples/datasets/hwf/get_hwf.py
diff --git a/examples/datasets/mnist_add/add.pl b/examples/datasets/mnist_add/add.pl
new file mode 100644
index 0000000..96f0869
--- /dev/null
+++ b/examples/datasets/mnist_add/add.pl
@@ -0,0 +1,2 @@
+pseudo_label(N) :- between(0, 9, N).
+logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2.
diff --git a/datasets/mnist_add/get_mnist_add.py b/examples/datasets/mnist_add/get_mnist_add.py
similarity index 100%
rename from datasets/mnist_add/get_mnist_add.py
rename to examples/datasets/mnist_add/get_mnist_add.py
diff --git a/datasets/mnist_add/test_data.txt b/examples/datasets/mnist_add/test_data.txt
similarity index 100%
rename from datasets/mnist_add/test_data.txt
rename to examples/datasets/mnist_add/test_data.txt
diff --git a/datasets/mnist_add/train_data.txt b/examples/datasets/mnist_add/train_data.txt
similarity index 100%
rename from datasets/mnist_add/train_data.txt
rename to examples/datasets/mnist_add/train_data.txt
diff --git a/example.py b/examples/example.py
similarity index 86%
rename from example.py
rename to examples/example.py
index 3751db3..0a69539 100644
--- a/example.py
+++ b/examples/example.py
@@ -10,24 +10,24 @@
 #
 # ================================================================#
 
-from utils.plog import logger, INFO
-from utils.utils import reduce_dimension
+import sys
+sys.path.append("../")
+
+from abl.utils.plog import logger, INFO
 
 import torch.nn as nn
 import torch
 
-from models.nn import LeNet5, SymbolNet
-from models.basic_model import BasicModel, BasicDataset
-from models.wabl_models import DecisionTree, WABLBasicModel
-from sklearn.neighbors import KNeighborsClassifier
+from abl.models.nn import LeNet5, SymbolNet
+from abl.models.basic_model import BasicModel, BasicDataset
+from abl.models.wabl_models import DecisionTree, WABLBasicModel
 
 from multiprocessing import Pool
 
-from abducer.abducer_base import AbducerBase
-from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
+from abl.abducer.abducer_base import AbducerBase
+from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
 from datasets.mnist_add.get_mnist_add import get_mnist_add
 from datasets.hwf.get_hwf import get_hwf
 from datasets.hed.get_hed import get_hed, split_equation
 
-import framework_hed
-import framework_hed_knn
+from abl import framework_hed
 
 def run_test():
@@ -36,7 +36,7 @@ def run_test():
     # kb = HWF_KB(True)
     # abducer = AbducerBase(kb)
 
-    kb = HED_prolog_KB()
+    kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
     abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
 
     recorder = logger()
diff --git a/nonshare_example.py b/examples/nonshare_example.py
similarity index 100%
rename from nonshare_example.py
rename to examples/nonshare_example.py
diff --git a/share_example.py b/examples/share_example.py
similarity index 100%
rename from share_example.py
rename to examples/share_example.py
diff --git a/examples/weights/all_weights_here.txt b/examples/weights/all_weights_here.txt
new file mode 100644
index 0000000..e69de29
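
Taken together, the new `add.pl` and the generic `prolog_KB` replace the old hard-coded `add_prolog_KB`: `_address_pred_res` rewrites the positions listed in `address_idx` as fresh Prolog variables (`P0`, `P1`, ...) and SWI-Prolog enumerates the bindings that satisfy `logic_forward/2`. A sketch of a single call, assuming it is run from the repository root so the `pl_file` path resolves:

    from abl.abducer.kb import prolog_KB

    kb = prolog_KB(pseudo_label_list=list(range(10)),
                   pl_file='examples/datasets/mnist_add/add.pl')

    # Revising position 0 of [1, 1] against key 8 builds the query
    # "logic_forward([P0, 1],8)."; Prolog binds P0 = 7, so the single
    # repaired candidate should be [7, 1].
    print(kb.get_query_string([1, 1], 8, [0], False))
    print(kb.address_by_idx([1, 1], 8, [0]))
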