| @@ -10,32 +10,15 @@ | |||
| # | |||
| # ================================================================# | |||
| import sys | |||
| sys.path.append(".") | |||
| sys.path.append("..") | |||
| import abc | |||
| from abducer.kb import * | |||
| import numpy as np | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from utils.utils import confidence_dist, flatten, hamming_dist | |||
| import math | |||
| import time | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist | |||
| class AbducerBase(abc.ABC): | |||
| def __init__( | |||
| self, | |||
| kb, | |||
| dist_func="confidence", | |||
| zoopt=False, | |||
| multiple_predictions=False, | |||
| cache=True, | |||
| ): | |||
| def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): | |||
| self.kb = kb | |||
| assert dist_func == "hamming" or dist_func == "confidence" | |||
| assert dist_func == 'hamming' or dist_func == 'confidence' | |||
| self.dist_func = dist_func | |||
| self.zoopt = zoopt | |||
| self.multiple_predictions = multiple_predictions | |||
| @@ -46,41 +29,42 @@ class AbducerBase(abc.ABC): | |||
| self.cache_candidates = {} | |||
| def _get_cost_list(self, pred_res, pred_res_prob, candidates): | |||
| if self.dist_func == "hamming": | |||
| if self.dist_func == 'hamming': | |||
| if self.multiple_predictions: | |||
| pred_res = flatten(pred_res) | |||
| candidates = [flatten(c) for c in candidates] | |||
| return hamming_dist(pred_res, candidates) | |||
| elif self.dist_func == "confidence": | |||
| mapping = dict( | |||
| zip( | |||
| self.kb.pseudo_label_list, | |||
| list(range(len(self.kb.pseudo_label_list))), | |||
| ) | |||
| ) | |||
| return confidence_dist( | |||
| pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] | |||
| ) | |||
| elif self.dist_func == 'confidence': | |||
| if self.multiple_predictions: | |||
| pred_res_prob = flatten(pred_res_prob) | |||
| candidates = [flatten(c) for c in candidates] | |||
| mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) | |||
| candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] | |||
| return confidence_dist(pred_res_prob, candidates) | |||
| def _get_one_candidate(self, pred_res, pred_res_prob, candidates): | |||
| if len(candidates) == 0: | |||
| return [] | |||
| elif len(candidates) == 1 or self.zoopt: | |||
| return candidates[0] | |||
| else: | |||
| cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) | |||
| min_address_num = np.min(cost_list) | |||
| idxs = np.where(cost_list == min_address_num)[0] | |||
| return [candidates[idx] for idx in idxs][0] | |||
| candidate = [candidates[idx] for idx in idxs][0] | |||
| return candidate | |||
| # for zoopt | |||
| def _zoopt_score_multiple(self, pred_res, key, solution): | |||
| all_address_flag = reform_idx(solution, pred_res) | |||
| score = 0 | |||
| for idx in range(len(pred_res)): | |||
| address_idx = [ | |||
| i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 | |||
| ] | |||
| candidate = self.kb.address_by_idx( | |||
| [pred_res[idx]], key[idx], address_idx, True | |||
| ) | |||
| address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | |||
| candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) | |||
| if len(candidate) > 0: | |||
| score += 1 | |||
| return score | |||
| @@ -88,9 +72,7 @@ class AbducerBase(abc.ABC): | |||
| def _zoopt_address_score(self, pred_res, key, sol): | |||
| if not self.multiple_predictions: | |||
| address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] | |||
| candidates = self.kb.address_by_idx( | |||
| pred_res, key, address_idx, self.multiple_predictions | |||
| ) | |||
| candidates = self.address_by_idx(pred_res, key, address_idx) | |||
| return 1 if len(candidates) > 0 else 0 | |||
| else: | |||
| return self._zoopt_score_multiple(pred_res, key, sol.get_x()) | |||
| @@ -107,7 +89,7 @@ class AbducerBase(abc.ABC): | |||
| dim=dimension, | |||
| constraint=lambda sol: self._constrain_address_num(sol, max_address_num), | |||
| ) | |||
| parameter = Parameter(budget=100, autoset=True) | |||
| parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
| solution = Opt.min(objective, parameter).get_x() | |||
| return solution | |||
| @@ -118,11 +100,7 @@ class AbducerBase(abc.ABC): | |||
| pred_res = flatten(pred_res) | |||
| key = tuple(key) | |||
| if (tuple(pred_res), key) in self.cache_min_address_num: | |||
| address_num = min( | |||
| max_address_num, | |||
| self.cache_min_address_num[(tuple(pred_res), key)] | |||
| + require_more_address, | |||
| ) | |||
| address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) | |||
| if (tuple(pred_res), key, address_num) in self.cache_candidates: | |||
| candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] | |||
| if self.zoopt: | |||
| @@ -137,6 +115,9 @@ class AbducerBase(abc.ABC): | |||
| key = tuple(key) | |||
| self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num | |||
| self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates | |||
| def address_by_idx(self, pred_res, key, address_idx): | |||
| return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) | |||
| def abduce(self, data, max_address_num=-1, require_more_address=0): | |||
| pred_res, pred_res_prob, key = data | |||
| @@ -151,18 +132,12 @@ class AbducerBase(abc.ABC): | |||
| if self.zoopt: | |||
| solution = self.zoopt_get_solution(pred_res, key, max_address_num) | |||
| address_idx = [idx for idx, i in enumerate(solution) if i != 0] | |||
| candidates = self.kb.address_by_idx( | |||
| pred_res, key, address_idx, self.multiple_predictions | |||
| ) | |||
| candidates = self.address_by_idx(pred_res, key, address_idx) | |||
| address_num = int(solution.sum()) | |||
| min_address_num = address_num | |||
| else: | |||
| candidates, min_address_num, address_num = self.kb.abduce_candidates( | |||
| pred_res, | |||
| key, | |||
| max_address_num, | |||
| require_more_address, | |||
| self.multiple_predictions, | |||
| pred_res, key, max_address_num, require_more_address, self.multiple_predictions | |||
| ) | |||
| candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) | |||
| @@ -176,32 +151,22 @@ class AbducerBase(abc.ABC): | |||
| return self.kb.abduce_rules(pred_res) | |||
| def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| if self.multiple_predictions: | |||
| return self.abduce( | |||
| (Z["cls"], Z["prob"], Y), max_address_num, require_more_address | |||
| ) | |||
| else: | |||
| return [ | |||
| self.abduce((z, prob, y), max_address_num, require_more_address) | |||
| for z, prob, y in zip(Z["cls"], Z["prob"], Y) | |||
| ] | |||
| # if self.multiple_predictions: | |||
| return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) | |||
| # else: | |||
| # return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] | |||
| def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): | |||
| return self.batch_abduce(Z, Y, max_address_num, require_more_address) | |||
| if __name__ == '__main__': | |||
| from kb import add_KB, prolog_KB, HWF_KB | |||
| prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| if __name__ == "__main__": | |||
| prob1 = [ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| prob2 = [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ] | |||
| kb = add_KB() | |||
| abd = AbducerBase(kb, "confidence") | |||
| kb = add_KB(GKB_flag=True) | |||
| abd = AbducerBase(kb, 'confidence') | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -213,9 +178,23 @@ if __name__ == "__main__": | |||
| res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| print() | |||
| multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], | |||
| [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| kb = add_KB() | |||
| abd = AbducerBase(kb, 'confidence', multiple_predictions=True) | |||
| res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1) | |||
| print(res) | |||
| print() | |||
| kb = add_prolog_KB() | |||
| abd = AbducerBase(kb, "confidence") | |||
| kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') | |||
| abd = AbducerBase(kb, 'confidence') | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -228,8 +207,8 @@ if __name__ == "__main__": | |||
| print(res) | |||
| print() | |||
| kb = add_prolog_KB() | |||
| abd = AbducerBase(kb, "confidence", zoopt=True) | |||
| kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') | |||
| abd = AbducerBase(kb, 'confidence', zoopt=True) | |||
| res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) | |||
| @@ -242,49 +221,63 @@ if __name__ == "__main__": | |||
| print(res) | |||
| print() | |||
| kb = HWF_KB(len_list=[1, 3, 5]) | |||
| abd = AbducerBase(kb, "hamming") | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 | |||
| ) | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| print() | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) | |||
| print(res) | |||
| print() | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) | |||
| abd = AbducerBase(kb, 'hamming', multiple_predictions=True) | |||
| res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0) | |||
| print(res) | |||
| print() | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 | |||
| ) | |||
| kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce( | |||
| (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 | |||
| ) | |||
| res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) | |||
| print(res) | |||
| res = abd.abduce( | |||
| (["5", "8", "8", "8", "8"], None, 3.17), | |||
| max_address_num=5, | |||
| require_more_address=3, | |||
| ) | |||
| res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) | |||
| print(res) | |||
| print() | |||
| kb = HED_prolog_KB() | |||
| kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||
| abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||
| consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] | |||
| consist_exs2 = [ | |||
| [1, "+", 0, "=", 0], | |||
| [1, "+", 1, "=", 0], | |||
| [0, "+", 1, "=", 1, 1], | |||
| ] # not consistent with rules | |||
| inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
| consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] | |||
| inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] | |||
| # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] | |||
| rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] | |||
| rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])'] | |||
| print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) | |||
| print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) | |||
| print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True)) | |||
| print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) | |||
| print() | |||
| res = abd.abduce((consist_exs, None, [1] * len(consist_exs))) | |||
| res = abd.abduce((consist_exs, None, [None] * len(consist_exs))) | |||
| print(res) | |||
| res = abd.abduce((inconsist_exs, None, [1] * len(consist_exs))) | |||
| res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs))) | |||
| print(res) | |||
| print() | |||
| abduced_rules = abd.abduce_rules(consist_exs) | |||
| print(abduced_rules) | |||
| print(abduced_rules) | |||
| @@ -15,93 +15,26 @@ import bisect | |||
| import copy | |||
| import numpy as np | |||
| import sys | |||
| sys.path.append("..") | |||
| from collections import defaultdict | |||
| from itertools import product, combinations | |||
| from utils.utils import flatten, reform_idx, hamming_dist, check_equal | |||
| from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal | |||
| from multiprocessing import Pool | |||
| import pyswip | |||
| class KBBase(ABC): | |||
| def __init__(self, pseudo_label_list=None): | |||
| pass | |||
| @abstractmethod | |||
| def logic_forward(self): | |||
| pass | |||
| @abstractmethod | |||
| def abduce_candidates(self): | |||
| pass | |||
| @abstractmethod | |||
| def address_by_idx(self): | |||
| pass | |||
| def _address(self, address_num, pred_res, key, multiple_predictions=False): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) | |||
| else: | |||
| address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self.logic_forward(pred_res), key): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def __len__(self): | |||
| pass | |||
| class ClsKB(KBBase): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None): | |||
| super().__init__() | |||
| self.GKB_flag = GKB_flag | |||
| def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): | |||
| self.pseudo_label_list = pseudo_label_list | |||
| self.len_list = len_list | |||
| self.GKB_flag = GKB_flag | |||
| self.max_err = max_err | |||
| if GKB_flag: | |||
| self.base = {} | |||
| X, Y = self._get_GKB() | |||
| for x, y in zip(X, Y): | |||
| self.base.setdefault(len(x), defaultdict(list))[y].append(x) | |||
| else: | |||
| self.all_address_candidate_dict = {} | |||
| for address_num in range(max(self.len_list) + 1): | |||
| self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) | |||
| # For parallel version of _get_GKB | |||
| def _get_XY_list(self, args): | |||
| @@ -130,39 +63,80 @@ class ClsKB(KBBase): | |||
| part_X, part_Y = zip(*XY_list) | |||
| X.extend(part_X) | |||
| Y.extend(part_Y) | |||
| if type(Y[0]) in (int, float): | |||
| sorted_XY = sorted(list(zip(Y, X))) | |||
| X = [x for y, x in sorted_XY] | |||
| Y = [y for y, x in sorted_XY] | |||
| return X, Y | |||
| def logic_forward(self): | |||
| @abstractmethod | |||
| def logic_forward(self, pseudo_labels): | |||
| pass | |||
| def _logic_forward(self, xs, multiple_predictions=False): | |||
| if not multiple_predictions: | |||
| return self.logic_forward(xs) | |||
| else: | |||
| res = [self.logic_forward(x) for x in xs] | |||
| return res | |||
| def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): | |||
| if self.GKB_flag: | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) | |||
| return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| else: | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| @abstractmethod | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| pass | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| if self.base == {}: | |||
| return [], 0, 0 | |||
| def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): | |||
| if self.base == {} or len(pred_res) not in self.len_list: | |||
| return [] | |||
| all_candidates = self.base[len(pred_res)][key] | |||
| if len(all_candidates) == 0: | |||
| candidates = [] | |||
| min_address_num = 0 | |||
| address_num = 0 | |||
| if not multiple_predictions: | |||
| if len(pred_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self._find_candidate_GKB(pred_res, key) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| else: | |||
| cost_list = hamming_dist(pred_res, all_candidates) | |||
| min_address_num = np.min(cost_list) | |||
| min_address_num = 0 | |||
| all_candidates_save = [] | |||
| cost_list_save = [] | |||
| for p_res, k in zip(pred_res, key): | |||
| if len(p_res) not in self.len_list: | |||
| return [], 0, 0 | |||
| all_candidates = self._find_candidate_GKB(p_res, k) | |||
| if len(all_candidates) == 0: | |||
| return [], 0, 0 | |||
| else: | |||
| all_candidates_save.append(all_candidates) | |||
| cost_list = hamming_dist(p_res, all_candidates) | |||
| min_address_num += np.min(cost_list) | |||
| cost_list_save.append(cost_list) | |||
| multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] | |||
| assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) | |||
| multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) | |||
| assert len(multiple_all_candidates) == len(multiple_cost_list) | |||
| address_num = min(max_address_num, min_address_num + require_more_address) | |||
| idxs = np.where(cost_list <= address_num)[0] | |||
| candidates = [all_candidates[idx] for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| idxs = np.where(multiple_cost_list <= address_num)[0] | |||
| candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] | |||
| return candidates, min_address_num, address_num | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| abduce_c = self.all_address_candidate_dict[len(address_idx)] | |||
| abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| @@ -176,10 +150,48 @@ class ClsKB(KBBase): | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| if self.logic_forward(candidate) == key: | |||
| if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): | |||
| candidates.append(candidate) | |||
| return candidates | |||
| def _address(self, address_num, pred_res, key, multiple_predictions): | |||
| new_candidates = [] | |||
| if not multiple_predictions: | |||
| address_idx_list = combinations(list(range(len(pred_res))), address_num) | |||
| else: | |||
| address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num) | |||
| for address_idx in address_idx_list: | |||
| candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) | |||
| new_candidates += candidates | |||
| return new_candidates | |||
| def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| candidates = [] | |||
| for address_num in range(len(flatten(pred_res)) + 1): | |||
| if address_num == 0: | |||
| if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): | |||
| candidates.append(pred_res) | |||
| else: | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| if len(candidates) > 0: | |||
| min_address_num = address_num | |||
| break | |||
| if address_num >= max_address_num: | |||
| return [], 0, 0 | |||
| for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): | |||
| if address_num > max_address_num: | |||
| return candidates, min_address_num, address_num - 1 | |||
| new_candidates = self._address(address_num, pred_res, key, multiple_predictions) | |||
| candidates += new_candidates | |||
| return candidates, min_address_num, address_num | |||
| def _dict_len(self, dic): | |||
| if not self.GKB_flag: | |||
| return 0 | |||
| @@ -193,130 +205,77 @@ class ClsKB(KBBase): | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
| class add_KB(ClsKB): | |||
| def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class ClsKB(KBBase): | |||
| def __init__(self, pseudo_label_list, len_list, GKB_flag): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| return self.base[len(pred_res)][key] | |||
| class HWF_KB(ClsKB): | |||
| def __init__( | |||
| self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] | |||
| ): | |||
| super().__init__(GKB_flag, pseudo_label_list, len_list) | |||
| def valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | |||
| return False | |||
| return True | |||
| class add_KB(ClsKB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag) | |||
| def logic_forward(self, formula): | |||
| if not self.valid_candidate(formula): | |||
| return np.inf | |||
| mapping = { | |||
| '1': '1', | |||
| '2': '2', | |||
| '3': '3', | |||
| '4': '4', | |||
| '5': '5', | |||
| '6': '6', | |||
| '7': '7', | |||
| '8': '8', | |||
| '9': '9', | |||
| '+': '+', | |||
| '-': '-', | |||
| 'times': '*', | |||
| 'div': '/', | |||
| } | |||
| formula = [mapping[f] for f in formula] | |||
| return round(eval(''.join(formula)), 2) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class prolog_KB(KBBase): | |||
| def __init__(self, pseudo_label_list): | |||
| super().__init__() | |||
| self.pseudo_label_list = pseudo_label_list | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog = pyswip.Prolog() | |||
| self.prolog.consult(pl_file) | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): | |||
| return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) | |||
| def logic_forward(self, pseudo_labels): | |||
| result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] | |||
| if result == 'true': | |||
| return True | |||
| elif result == 'false': | |||
| return False | |||
| return result | |||
| def _address_pred_res(self, pred_res, address_idx, multiple_predictions): | |||
| import re | |||
| address_pred_res = pred_res.copy() | |||
| if multiple_predictions: | |||
| address_pred_res = flatten(address_pred_res) | |||
| for idx in address_idx: | |||
| address_pred_res[idx] = 'P' + str(idx) | |||
| if multiple_predictions: | |||
| address_pred_res = reform_idx(address_pred_res, pred_res) | |||
| # TODO:不知道有没有更简洁的方法 | |||
| regex = r"'P\d+'" | |||
| return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) | |||
| def get_query_string(self, pred_res, key, address_idx, multiple_predictions): | |||
| query_string = "logic_forward(" | |||
| query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions) | |||
| key_is_none_flag = key is None or (type(key) == list and key[0] is None) | |||
| query_string += ",%s)." % key if not key_is_none_flag else ")." | |||
| return query_string | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| pass | |||
| def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): | |||
| candidates = [] | |||
| # print(address_idx) | |||
| if not multiple_predictions: | |||
| query_string = self.get_query_string(pred_res, key, address_idx) | |||
| else: | |||
| query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) | |||
| query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions) | |||
| if multiple_predictions: | |||
| save_pred_res = pred_res | |||
| pred_res = flatten(pred_res) | |||
| abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] | |||
| abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] | |||
| for c in abduce_c: | |||
| candidate = pred_res.copy() | |||
| for i, idx in enumerate(address_idx): | |||
| candidate[idx] = c[i] | |||
| if multiple_predictions: | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| candidates.append(candidate) | |||
| return candidates | |||
| class add_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list=list(range(10))): | |||
| super().__init__(pseudo_label_list) | |||
| for i in self.pseudo_label_list: | |||
| self.prolog.assertz("pseudo_label(%s)" % i) | |||
| self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") | |||
| def logic_forward(self, nums): | |||
| return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res'] | |||
| def get_query_string(self, pred_res, key, address_idx): | |||
| query_string = "addition(" | |||
| for idx, i in enumerate(pred_res): | |||
| tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' | |||
| query_string += tmp | |||
| query_string += "%s)." % key | |||
| return query_string | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list=[0, 1, '+', '=']): | |||
| super().__init__(pseudo_label_list) | |||
| self.prolog.consult('./datasets/hed/learn_add.pl') | |||
| # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` | |||
| def logic_forward(self, exs): | |||
| return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 | |||
| def get_query_string_need_flatten(self, pred_res, key, address_idx): | |||
| # flatten | |||
| flatten_pred_res = flatten(pred_res) | |||
| # add variables for prolog | |||
| for idx in range(len(flatten_pred_res)): | |||
| if idx in address_idx: | |||
| flatten_pred_res[idx] = 'X' + str(idx) | |||
| # unflatten | |||
| new_pred_res = reform_idx(flatten_pred_res, pred_res) | |||
| query_string = "abduce_consistent_insts(%s)." % new_pred_res | |||
| return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") | |||
| def consist_rule(self, exs, rules): | |||
| rules = str(rules).replace("\'","") | |||
| return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
| @@ -327,92 +286,79 @@ class HED_prolog_KB(prolog_KB): | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]['X'] | |||
| rules = [] | |||
| for rule in prolog_rules: | |||
| rules.append(rule.value) | |||
| rules = [rule.value for rule in prolog_rules] | |||
| return rules | |||
| # def consist_rules(self, pred_res, rules): | |||
| class RegKB(KBBase): | |||
| def __init__(self, GKB_flag=False, X=None, Y=None): | |||
| super().__init__() | |||
| tmp_dict = {} | |||
| for x, y in zip(X, Y): | |||
| tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) | |||
| self.base = {} | |||
| for l in tmp_dict.keys(): | |||
| data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) | |||
| X = [x for y, x in data] | |||
| Y = [y for y, x in data] | |||
| self.base[l] = (X, Y) | |||
| def valid_candidate(self): | |||
| pass | |||
| def logic_forward(self): | |||
| pass | |||
| def abduce_candidates(self, key, length=None): | |||
| if key is None: | |||
| return self.get_all_candidates() | |||
| length = self._length(length) | |||
| def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=1e-3): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| def _find_candidate_GKB(self, pred_res, key): | |||
| potential_candidates = self.base[len(pred_res)] | |||
| key_list = list(potential_candidates.keys()) | |||
| key_idx = bisect.bisect_left(key_list, key) | |||
| all_candidates = [] | |||
| for idx in range(key_idx - 1, 0, -1): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| for idx in range(key_idx, len(key_list)): | |||
| k = key_list[idx] | |||
| if abs(k - key) <= self.max_err: | |||
| all_candidates += potential_candidates[k] | |||
| else: | |||
| break | |||
| return all_candidates | |||
| min_err = 999999 | |||
| candidates = [] | |||
| for l in length: | |||
| X, Y = self.base[l] | |||
| idx = bisect.bisect_left(Y, key) | |||
| begin = max(0, idx - 1) | |||
| end = min(idx + 2, len(X)) | |||
| for idx in range(begin, end): | |||
| err = abs(Y[idx] - key) | |||
| if abs(err - min_err) < 1e-9: | |||
| candidates.extend(X[idx]) | |||
| elif err < min_err: | |||
| candidates = copy.deepcopy(X[idx]) | |||
| min_err = err | |||
| return candidates | |||
| class HWF_KB(RegKB): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], | |||
| len_list=[1, 3, 5, 7], | |||
| GKB_flag=False, | |||
| max_err=1e-3 | |||
| ): | |||
| super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) | |||
| def get_all_candidates(self): | |||
| return sum([sum(D[0], []) for D in self.base.values()], []) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: | |||
| return False | |||
| return True | |||
| def __len__(self): | |||
| return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return np.inf | |||
| mapping = { | |||
| '1': '1', | |||
| '2': '2', | |||
| '3': '3', | |||
| '4': '4', | |||
| '5': '5', | |||
| '6': '6', | |||
| '7': '7', | |||
| '8': '8', | |||
| '9': '9', | |||
| '+': '+', | |||
| '-': '-', | |||
| 'times': '*', | |||
| 'div': '/', | |||
| } | |||
| formula = [mapping[f] for f in formula] | |||
| return eval(''.join(formula)) | |||
| import time | |||
| if __name__ == "__main__": | |||
| t1 = time.time() | |||
| kb = HWF_KB(True) | |||
| t2 = time.time() | |||
| print(t2 - t1) | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] | |||
| # Y = [2, 1, 1, 2, 2] | |||
| # kb = ClsKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(2, 5) | |||
| # print(res) | |||
| # res = kb.get_candidates(2, 3) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||
| # print() | |||
| # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] | |||
| # Y = [2, 1, 1, 2, 1.5, 1.5] | |||
| # kb = RegKB(X, Y) | |||
| # print('len(kb):', len(kb)) | |||
| # res = kb.get_candidates(1.6) | |||
| # print(res) | |||
| # res = kb.get_candidates(1.6, length = 9) | |||
| # print(res) | |||
| # res = kb.get_candidates(None) | |||
| # print(res) | |||
| pass | |||
| @@ -14,7 +14,7 @@ import pickle as pk | |||
| import numpy as np | |||
| from utils.plog import INFO, DEBUG, clocker | |||
| from .utils.plog import INFO, DEBUG, clocker | |||
| def block_sample(X, Z, Y, sample_num, epoch_idx): | |||
| part_num = (len(X) // sample_num) | |||
| @@ -16,12 +16,15 @@ import torch.nn as nn | |||
| import numpy as np | |||
| import os | |||
| from utils.plog import INFO, DEBUG, clocker | |||
| from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res | |||
| from .utils.plog import INFO, DEBUG, clocker | |||
| from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res | |||
| from models.nn import MLP, SymbolNetAutoencoder | |||
| from models.basic_model import BasicModel, BasicDataset | |||
| from datasets.hed.get_hed import get_pretrain_data | |||
| from .models.nn import MLP, SymbolNetAutoencoder | |||
| from .models.basic_model import BasicModel, BasicDataset | |||
| import sys | |||
| sys.path.append("..") | |||
| from examples.datasets.hed.get_hed import get_pretrain_data | |||
| def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): | |||
| result = {} | |||
| @@ -147,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): | |||
| for m in mappings: | |||
| pred_res = mapping_res(original_pred_res, m) | |||
| max_abduce_num = 20 | |||
| solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num) | |||
| solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num) | |||
| all_address_flag = reform_idx(solution, pred_res) | |||
| consistent_idx_tmp = [] | |||
| @@ -155,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): | |||
| for idx in range(len(pred_res)): | |||
| address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] | |||
| candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True) | |||
| candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx) | |||
| if len(candidate) > 0: | |||
| consistent_idx_tmp.append(idx) | |||
| consistent_pred_res_tmp.append(candidate[0][0]) | |||
| @@ -211,7 +214,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, | |||
| consistent_idx = [] | |||
| consistent_pred_res = [] | |||
| for idx in range(len(pred_res)): | |||
| if abducer.kb.logic_forward([pred_res[idx]]): | |||
| if abducer.kb.logic_forward(pred_res[idx]): | |||
| consistent_idx.append(idx) | |||
| consistent_pred_res.append(pred_res[idx]) | |||
| @@ -10,9 +10,6 @@ | |||
| # | |||
| # ================================================================# | |||
| import sys | |||
| sys.path.append("..") | |||
| import torchvision | |||
| @@ -23,8 +20,6 @@ from torch.autograd import Variable | |||
| import torchvision.transforms as transforms | |||
| import numpy as np | |||
| from models.basic_model import BasicModel | |||
| import utils.plog as plog | |||
| class LeNet5(nn.Module): | |||
| @@ -21,7 +21,6 @@ from sklearn.preprocessing import StandardScaler | |||
| from sklearn.svm import SVC | |||
| from sklearn.gaussian_process import GaussianProcessClassifier | |||
| from sklearn.gaussian_process.kernels import RBF | |||
| from models.basic_model import BasicModel | |||
| import pickle as pk | |||
| import random | |||
| @@ -1,30 +1,23 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| import numpy as np | |||
| from utils.plog import INFO | |||
| from .plog import INFO | |||
| from collections import OrderedDict | |||
| from itertools import chain | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| # for multiple predictions | |||
| def flatten(l): | |||
| return ( | |||
| [item for sublist in l for item in flatten(sublist)] | |||
| if isinstance(l, list) | |||
| else [l] | |||
| ) | |||
| # for multiple predictions, modify from `learn_add.py` | |||
| if not isinstance(l[0], (list, tuple)): | |||
| return l | |||
| return list(chain.from_iterable(l)) | |||
| # for multiple predictions | |||
| def reform_idx(flatten_pred_res, save_pred_res): | |||
| re = [] | |||
| i = 0 | |||
| for e in save_pred_res: | |||
| j = 0 | |||
| idx = [] | |||
| while j < len(e): | |||
| idx.append(flatten_pred_res[i + j]) | |||
| j += 1 | |||
| re.append(idx) | |||
| i = i + j | |||
| re.append(flatten_pred_res[i:i + len(e)]) | |||
| i += len(e) | |||
| return re | |||
| @@ -85,11 +78,10 @@ def remapping_res(pred_res, m): | |||
| remapping[value] = key | |||
| return [[remapping[symbol] for symbol in formula] for formula in pred_res] | |||
| def check_equal(a, b): | |||
| def check_equal(a, b, max_err=0): | |||
| if isinstance(a, (int, float)) and isinstance(b, (int, float)): | |||
| return abs(a - b) <= 1e-3 | |||
| return abs(a - b) <= max_err | |||
| if isinstance(a, list) and isinstance(b, list): | |||
| if len(a) != len(b): | |||
| return False | |||
| @@ -119,4 +111,4 @@ def reduce_dimension(data): | |||
| [extract_feature(symbol_img) for symbol_img in equation] | |||
| for equation in equations | |||
| ] | |||
| data[truth_value][equation_len] = reduced_equations | |||
| data[truth_value][equation_len] = reduced_equations | |||
| @@ -32,6 +32,9 @@ abduce_consistent_insts(Exs):- | |||
| % (Experimental) Uncomment to use parallel abduction | |||
| % abduce_consistent_exs_concurrent(Exs), !. | |||
| logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false. | |||
| logic_forward(Exs) :- abduce_consistent_insts(Exs). | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| %% Abduce Delta_C given pseudo-labels | |||
| %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |||
| @@ -0,0 +1,2 @@ | |||
| pseudo_label(N) :- between(0, 9, N). | |||
| logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2. | |||
| @@ -10,24 +10,24 @@ | |||
| # | |||
| # ================================================================# | |||
| from utils.plog import logger, INFO | |||
| from utils.utils import reduce_dimension | |||
| import sys | |||
| sys.path.append("../") | |||
| from abl.utils.plog import logger, INFO | |||
| import torch.nn as nn | |||
| import torch | |||
| from models.nn import LeNet5, SymbolNet | |||
| from models.basic_model import BasicModel, BasicDataset | |||
| from models.wabl_models import DecisionTree, WABLBasicModel | |||
| from sklearn.neighbors import KNeighborsClassifier | |||
| from abl.models.nn import LeNet5, SymbolNet | |||
| from abl.models.basic_model import BasicModel, BasicDataset | |||
| from abl.models.wabl_models import DecisionTree, WABLBasicModel | |||
| from multiprocessing import Pool | |||
| from abducer.abducer_base import AbducerBase | |||
| from abducer.kb import add_KB, HWF_KB, HED_prolog_KB | |||
| from abl.abducer.abducer_base import AbducerBase | |||
| from abl.abducer.kb import add_KB, HWF_KB, prolog_KB | |||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
| from datasets.hwf.get_hwf import get_hwf | |||
| from datasets.hed.get_hed import get_hed, split_equation | |||
| import framework_hed | |||
| import framework_hed_knn | |||
| from abl import framework_hed | |||
| def run_test(): | |||
| @@ -36,7 +36,7 @@ def run_test(): | |||
| # kb = HWF_KB(True) | |||
| # abducer = AbducerBase(kb) | |||
| kb = HED_prolog_KB() | |||
| kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||
| abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||
| recorder = logger() | |||