From fcda456cb60328b97f45afe00d928cb80cc13090 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Thu, 2 Mar 2023 15:05:37 +0800 Subject: [PATCH 01/31] update TODO in kb.py --- abducer/kb.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index b9beb3f..bbd6c93 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -137,6 +137,7 @@ class ClsKB(KBBase): def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: + # TODO: 这里有可能是multiple_predictions吗 return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) else: return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) @@ -200,7 +201,7 @@ class add_KB(ClsKB): def logic_forward(self, nums): return sum(nums) - +# TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO) class HWF_KB(ClsKB): def __init__( self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] @@ -334,7 +335,13 @@ class HED_prolog_KB(prolog_KB): # def consist_rules(self, pred_res, rules): - +# TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果) +# TODO:我理解的RegKB是这样的: +# TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果 +# TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果 +# TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y +# TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义 +# TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err class RegKB(KBBase): def __init__(self, GKB_flag=False, X=None, Y=None): super().__init__() @@ -355,7 +362,7 @@ class RegKB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, key, length=None): + def _abduce_by_GKB(self, key, length=None): if key is None: return self.get_all_candidates() From 37c1c446d9e3c25146c8d08db5db64b0b22ff213 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Fri, 3 Mar 2023 14:00:51 +0800 Subject: [PATCH 02/31] update package structure --- weights/all_weights_here.txt => abl/__init__.py | 0 abl/abducer/__init__.py | 0 {abducer => abl/abducer}/abducer_base.py | 11 ++++++----- {abducer => abl/abducer}/kb.py | 4 ++-- framework.py => abl/framework.py | 2 +- framework_hed.py => abl/framework_hed.py | 13 ++++++++----- abl/models/__init__.py | 0 {models => abl/models}/basic_model.py | 0 {models => abl/models}/lenet5.py | 0 {models => abl/models}/nn.py | 5 ----- {models => abl/models}/wabl_models.py | 1 - {utils => abl/utils}/plog.py | 0 {utils => abl/utils}/utils.py | 2 +- {datasets => examples/datasets}/data_generator.py | 0 {datasets => examples/datasets}/hed/BK.pl | 0 {datasets => examples/datasets}/hed/README.md | 0 {datasets => examples/datasets}/hed/get_hed.py | 0 {datasets => examples/datasets}/hed/learn_add.pl | 0 {datasets => examples/datasets}/hwf/README.md | 0 {datasets => examples/datasets}/hwf/get_hwf.py | 0 .../datasets}/mnist_add/get_mnist_add.py | 0 .../datasets}/mnist_add/test_data.txt | 0 .../datasets}/mnist_add/train_data.txt | 0 example.py => examples/example.py | 0 nonshare_example.py => examples/nonshare_example.py | 0 share_example.py => examples/share_example.py | 0 examples/weights/all_weights_here.txt | 0 27 files changed, 18 insertions(+), 20 deletions(-) rename weights/all_weights_here.txt => abl/__init__.py (100%) create mode 100644 abl/abducer/__init__.py rename {abducer => abl/abducer}/abducer_base.py (98%) rename {abducer => abl/abducer}/kb.py (99%) rename framework.py => abl/framework.py (98%) rename framework_hed.py => abl/framework_hed.py (97%) create mode 100644 abl/models/__init__.py rename {models => abl/models}/basic_model.py (100%) rename {models => abl/models}/lenet5.py (100%) rename {models => abl/models}/nn.py (97%) rename {models => abl/models}/wabl_models.py (99%) rename {utils => abl/utils}/plog.py (100%) rename {utils => abl/utils}/utils.py (98%) rename {datasets => examples/datasets}/data_generator.py (100%) rename {datasets => examples/datasets}/hed/BK.pl (100%) rename {datasets => examples/datasets}/hed/README.md (100%) rename {datasets => examples/datasets}/hed/get_hed.py (100%) rename {datasets => examples/datasets}/hed/learn_add.pl (100%) rename {datasets => examples/datasets}/hwf/README.md (100%) rename {datasets => examples/datasets}/hwf/get_hwf.py (100%) rename {datasets => examples/datasets}/mnist_add/get_mnist_add.py (100%) rename {datasets => examples/datasets}/mnist_add/test_data.txt (100%) rename {datasets => examples/datasets}/mnist_add/train_data.txt (100%) rename example.py => examples/example.py (100%) rename nonshare_example.py => examples/nonshare_example.py (100%) rename share_example.py => examples/share_example.py (100%) create mode 100644 examples/weights/all_weights_here.txt diff --git a/weights/all_weights_here.txt b/abl/__init__.py similarity index 100% rename from weights/all_weights_here.txt rename to abl/__init__.py diff --git a/abl/abducer/__init__.py b/abl/abducer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/abducer/abducer_base.py b/abl/abducer/abducer_base.py similarity index 98% rename from abducer/abducer_base.py rename to abl/abducer/abducer_base.py index c0fd102..9a483c9 100644 --- a/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -10,16 +10,17 @@ # # ================================================================# -import sys +# import sys -sys.path.append(".") -sys.path.append("..") +# sys.path.append(".") +# sys.path.append("..") import abc -from abducer.kb import * +# TODO 尽量别用import * +from .kb import * import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from utils.utils import confidence_dist, flatten, hamming_dist +from ..utils.utils import confidence_dist, flatten, hamming_dist import math import time diff --git a/abducer/kb.py b/abl/abducer/kb.py similarity index 99% rename from abducer/kb.py rename to abl/abducer/kb.py index bbd6c93..8fb727c 100644 --- a/abducer/kb.py +++ b/abl/abducer/kb.py @@ -21,7 +21,7 @@ sys.path.append("..") from collections import defaultdict from itertools import product, combinations -from utils.utils import flatten, reform_idx, hamming_dist, check_equal +from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal from multiprocessing import Pool @@ -299,7 +299,7 @@ class add_prolog_KB(prolog_KB): class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list=[0, 1, '+', '=']): super().__init__(pseudo_label_list) - self.prolog.consult('./datasets/hed/learn_add.pl') + self.prolog.consult('../examples/datasets/hed/learn_add.pl') # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): diff --git a/framework.py b/abl/framework.py similarity index 98% rename from framework.py rename to abl/framework.py index 2a1fdad..c3a5c6e 100644 --- a/framework.py +++ b/abl/framework.py @@ -14,7 +14,7 @@ import pickle as pk import numpy as np -from utils.plog import INFO, DEBUG, clocker +from .utils.plog import INFO, DEBUG, clocker def block_sample(X, Z, Y, sample_num, epoch_idx): part_num = (len(X) // sample_num) diff --git a/framework_hed.py b/abl/framework_hed.py similarity index 97% rename from framework_hed.py rename to abl/framework_hed.py index b7439c3..d942be6 100644 --- a/framework_hed.py +++ b/abl/framework_hed.py @@ -16,12 +16,15 @@ import torch.nn as nn import numpy as np import os -from utils.plog import INFO, DEBUG, clocker -from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res +from .utils.plog import INFO, DEBUG, clocker +from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res -from models.nn import MLP, SymbolNetAutoencoder -from models.basic_model import BasicModel, BasicDataset -from datasets.hed.get_hed import get_pretrain_data +from .models.nn import MLP, SymbolNetAutoencoder +from .models.basic_model import BasicModel, BasicDataset + +import sys +sys.path.append("..") +from examples.datasets.hed.get_hed import get_pretrain_data def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): result = {} diff --git a/abl/models/__init__.py b/abl/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/basic_model.py b/abl/models/basic_model.py similarity index 100% rename from models/basic_model.py rename to abl/models/basic_model.py diff --git a/models/lenet5.py b/abl/models/lenet5.py similarity index 100% rename from models/lenet5.py rename to abl/models/lenet5.py diff --git a/models/nn.py b/abl/models/nn.py similarity index 97% rename from models/nn.py rename to abl/models/nn.py index cecbeea..7a0f560 100644 --- a/models/nn.py +++ b/abl/models/nn.py @@ -10,9 +10,6 @@ # # ================================================================# -import sys - -sys.path.append("..") import torchvision @@ -23,8 +20,6 @@ from torch.autograd import Variable import torchvision.transforms as transforms import numpy as np -from models.basic_model import BasicModel -import utils.plog as plog class LeNet5(nn.Module): diff --git a/models/wabl_models.py b/abl/models/wabl_models.py similarity index 99% rename from models/wabl_models.py rename to abl/models/wabl_models.py index 9c97bbc..3b682ee 100644 --- a/models/wabl_models.py +++ b/abl/models/wabl_models.py @@ -21,7 +21,6 @@ from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC from sklearn.gaussian_process import GaussianProcessClassifier from sklearn.gaussian_process.kernels import RBF -from models.basic_model import BasicModel import pickle as pk import random diff --git a/utils/plog.py b/abl/utils/plog.py similarity index 100% rename from utils/plog.py rename to abl/utils/plog.py diff --git a/utils/utils.py b/abl/utils/utils.py similarity index 98% rename from utils/utils.py rename to abl/utils/utils.py index 1138361..65c85fc 100644 --- a/utils/utils.py +++ b/abl/utils/utils.py @@ -1,5 +1,5 @@ import numpy as np -from utils.plog import INFO +from .plog import INFO from collections import OrderedDict # for multiple predictions, modify from `learn_add.py` diff --git a/datasets/data_generator.py b/examples/datasets/data_generator.py similarity index 100% rename from datasets/data_generator.py rename to examples/datasets/data_generator.py diff --git a/datasets/hed/BK.pl b/examples/datasets/hed/BK.pl similarity index 100% rename from datasets/hed/BK.pl rename to examples/datasets/hed/BK.pl diff --git a/datasets/hed/README.md b/examples/datasets/hed/README.md similarity index 100% rename from datasets/hed/README.md rename to examples/datasets/hed/README.md diff --git a/datasets/hed/get_hed.py b/examples/datasets/hed/get_hed.py similarity index 100% rename from datasets/hed/get_hed.py rename to examples/datasets/hed/get_hed.py diff --git a/datasets/hed/learn_add.pl b/examples/datasets/hed/learn_add.pl similarity index 100% rename from datasets/hed/learn_add.pl rename to examples/datasets/hed/learn_add.pl diff --git a/datasets/hwf/README.md b/examples/datasets/hwf/README.md similarity index 100% rename from datasets/hwf/README.md rename to examples/datasets/hwf/README.md diff --git a/datasets/hwf/get_hwf.py b/examples/datasets/hwf/get_hwf.py similarity index 100% rename from datasets/hwf/get_hwf.py rename to examples/datasets/hwf/get_hwf.py diff --git a/datasets/mnist_add/get_mnist_add.py b/examples/datasets/mnist_add/get_mnist_add.py similarity index 100% rename from datasets/mnist_add/get_mnist_add.py rename to examples/datasets/mnist_add/get_mnist_add.py diff --git a/datasets/mnist_add/test_data.txt b/examples/datasets/mnist_add/test_data.txt similarity index 100% rename from datasets/mnist_add/test_data.txt rename to examples/datasets/mnist_add/test_data.txt diff --git a/datasets/mnist_add/train_data.txt b/examples/datasets/mnist_add/train_data.txt similarity index 100% rename from datasets/mnist_add/train_data.txt rename to examples/datasets/mnist_add/train_data.txt diff --git a/example.py b/examples/example.py similarity index 100% rename from example.py rename to examples/example.py diff --git a/nonshare_example.py b/examples/nonshare_example.py similarity index 100% rename from nonshare_example.py rename to examples/nonshare_example.py diff --git a/share_example.py b/examples/share_example.py similarity index 100% rename from share_example.py rename to examples/share_example.py diff --git a/examples/weights/all_weights_here.txt b/examples/weights/all_weights_here.txt new file mode 100644 index 0000000..e69de29 From fe33d3b4eae6c7156c8984d20de8d10ad61cc8c7 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Fri, 3 Mar 2023 15:31:08 +0800 Subject: [PATCH 03/31] remove unnecessary lines --- abl/abducer/kb.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 8fb727c..9a66922 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -15,10 +15,6 @@ import bisect import copy import numpy as np -import sys - -sys.path.append("..") - from collections import defaultdict from itertools import product, combinations from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal From 7e4102c7091489e97648ff1f6b005f8d89f960c8 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:22:36 +0800 Subject: [PATCH 04/31] Update kb.py --- abl/abducer/kb.py | 373 +++++++++++++++++++++++++++++----------------- 1 file changed, 239 insertions(+), 134 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 9a66922..66a134f 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -31,6 +31,15 @@ class KBBase(ABC): @abstractmethod def logic_forward(self): pass + + def _logic_forward(self, xs, multiple_predictions=False): + if not multiple_predictions: + return self.logic_forward(xs) + else: + res = [] + for x in xs: + res.append(self.logic_forward(x)) + return res @abstractmethod def abduce_candidates(self): @@ -40,7 +49,7 @@ class KBBase(ABC): def address_by_idx(self): pass - def _address(self, address_num, pred_res, key, multiple_predictions=False): + def _address(self, address_num, pred_res, key, multiple_predictions): new_candidates = [] if not multiple_predictions: address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) @@ -52,12 +61,12 @@ class KBBase(ABC): new_candidates += candidates return new_candidates - def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): + def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): candidates = [] for address_num in range(len(flatten(pred_res)) + 1): if address_num == 0: - if check_equal(self.logic_forward(pred_res), key): + if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): candidates.append(pred_res) else: new_candidates = self._address(address_num, pred_res, key, multiple_predictions) @@ -88,16 +97,14 @@ class ClsKB(KBBase): self.GKB_flag = GKB_flag self.pseudo_label_list = pseudo_label_list self.len_list = len_list + self.max_err = 0 if GKB_flag: self.base = {} X, Y = self._get_GKB() for x, y in zip(X, Y): self.base.setdefault(len(x), defaultdict(list))[y].append(x) - else: - self.all_address_candidate_dict = {} - for address_num in range(max(self.len_list) + 1): - self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num)) + # For parallel version of _get_GKB def _get_XY_list(self, args): @@ -133,33 +140,57 @@ class ClsKB(KBBase): def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: - # TODO: 这里有可能是multiple_predictions吗 - return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) else: return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address): - if self.base == {} or len(pred_res) not in self.len_list: - return [] - - all_candidates = self.base[len(pred_res)][key] + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + if self.base == {}: + return [], 0, 0 - if len(all_candidates) == 0: - candidates = [] - min_address_num = 0 - address_num = 0 + if not multiple_predictions: + if len(pred_res) not in self.len_list: + return [], 0, 0 + all_candidates = self.base[len(pred_res)][key] + if len(all_candidates) == 0: + return [], 0, 0 + else: + cost_list = hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates, min_address_num, address_num + else: - cost_list = hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) + min_address_num = 0 + all_candidates_save = [] + cost_list_save = [] + + for p_res, k in zip(pred_res, key): + if len(p_res) not in self.len_list: + return [], 0, 0 + all_candidates = self.base[len(p_res)][k] + if len(all_candidates) == 0: + return [], 0, 0 + else: + all_candidates_save.append(all_candidates) + cost_list = hamming_dist(p_res, all_candidates) + min_address_num += np.min(cost_list) + cost_list_save.append(cost_list) + + multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] + assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) + multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) + assert len(multiple_all_candidates) == len(multiple_cost_list) address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - - return candidates, min_address_num, address_num + idxs = np.where(multiple_cost_list <= address_num)[0] + candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] + return candidates, min_address_num, address_num def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - abduce_c = self.all_address_candidate_dict[len(address_idx)] + abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) if multiple_predictions: save_pred_res = pred_res @@ -173,7 +204,7 @@ class ClsKB(KBBase): if multiple_predictions: candidate = reform_idx(candidate, save_pred_res) - if self.logic_forward(candidate) == key: + if check_equal(self._logic_forward(candidate, multiple_predictions), key): candidates.append(candidate) return candidates @@ -197,50 +228,13 @@ class add_KB(ClsKB): def logic_forward(self, nums): return sum(nums) -# TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO) -class HWF_KB(ClsKB): - def __init__( - self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] - ): - super().__init__(GKB_flag, pseudo_label_list, len_list) - - def valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: - return False - if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: - return False - return True - - def logic_forward(self, formula): - if not self.valid_candidate(formula): - return np.inf - mapping = { - '1': '1', - '2': '2', - '3': '3', - '4': '4', - '5': '5', - '6': '6', - '7': '7', - '8': '8', - '9': '9', - '+': '+', - '-': '-', - 'times': '*', - 'div': '/', - } - formula = [mapping[f] for f in formula] - return round(eval(''.join(formula)), 2) - class prolog_KB(KBBase): def __init__(self, pseudo_label_list): super().__init__() self.pseudo_label_list = pseudo_label_list self.prolog = pyswip.Prolog() + self.max_err = 0 def logic_forward(self): pass @@ -295,11 +289,11 @@ class add_prolog_KB(prolog_KB): class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list=[0, 1, '+', '=']): super().__init__(pseudo_label_list) - self.prolog.consult('../examples/datasets/hed/learn_add.pl') + self.prolog.consult('./datasets/hed/learn_add.pl') # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): - return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0 + return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 def get_query_string_need_flatten(self, pred_res, key, address_idx): # flatten @@ -329,93 +323,204 @@ class HED_prolog_KB(prolog_KB): rules.append(rule.value) return rules - # def consist_rules(self, pred_res, rules): -# TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果) -# TODO:我理解的RegKB是这样的: -# TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果 -# TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果 -# TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y -# TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义 -# TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err class RegKB(KBBase): - def __init__(self, GKB_flag=False, X=None, Y=None): + def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): super().__init__() - tmp_dict = {} - for x, y in zip(X, Y): - tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x)) - - self.base = {} - for l in tmp_dict.keys(): - data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values()))) - X = [x for y, x in data] - Y = [y for y, x in data] - self.base[l] = (X, Y) - - def valid_candidate(self): - pass + self.GKB_flag = GKB_flag + self.pseudo_label_list = pseudo_label_list + self.len_list = len_list + self.max_err = max_err + + if GKB_flag: + self.base = {} + X, Y = self._get_GKB() + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(x) + + # For parallel version of _get_GKB + def _get_XY_list(self, args): + pre_x, post_x_it = args[0], args[1] + XY_list = [] + for post_x in post_x_it: + x = (pre_x,) + post_x + y = self.logic_forward(x) + if y != np.inf: + XY_list.append((x, y)) + return XY_list + + # Parallel _get_GKB + def _get_GKB(self): + X, Y = [], [] + for length in self.len_list: + arg_list = [] + for pre_x in self.pseudo_label_list: + post_x_it = product(self.pseudo_label_list, repeat=length - 1) + arg_list.append((pre_x, post_x_it)) + with Pool(processes=len(arg_list)) as pool: + ret_list = pool.map(self._get_XY_list, arg_list) + for XY_list in ret_list: + if len(XY_list) == 0: + continue + part_X, part_Y = zip(*XY_list) + X.extend(part_X) + Y.extend(part_Y) + return X, Y def logic_forward(self): pass - def _abduce_by_GKB(self, key, length=None): - if key is None: - return self.get_all_candidates() + def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): + if self.GKB_flag: + return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) + else: + return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - length = self._length(length) + def _regression_find_candidate_GKB(self, pred_res, key): + potential_candidates = self.base[len(pred_res)] + key_list = sorted(potential_candidates) + key_idx = bisect.bisect_left(key_list, key) + + all_candidates = [] + for idx in range(key_idx - 1, 0, -1): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + + for idx in range(key_idx, len(key_list)): + k = key_list[idx] + if abs(k - key) <= self.max_err: + all_candidates += potential_candidates[k] + else: + break + return all_candidates + + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + if self.base == {}: + return [], 0, 0 - min_err = 999999 + if not multiple_predictions: + if len(pred_res) not in self.len_list: + return [], 0, 0 + all_candidates = self._regression_find_candidate_GKB(pred_res, key) + if len(all_candidates) == 0: + return [], 0, 0 + else: + cost_list = hamming_dist(pred_res, all_candidates) + min_address_num = np.min(cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(cost_list <= address_num)[0] + candidates = [all_candidates[idx] for idx in idxs] + return candidates, min_address_num, address_num + + else: + min_address_num = 0 + all_candidates_save = [] + cost_list_save = [] + + for p_res, k in zip(pred_res, key): + if len(p_res) not in self.len_list: + return [], 0, 0 + all_candidates = self._regression_find_candidate_GKB(p_res, k) + if len(all_candidates) == 0: + return [], 0, 0 + else: + all_candidates_save.append(all_candidates) + cost_list = hamming_dist(p_res, all_candidates) + min_address_num += np.min(cost_list) + cost_list_save.append(cost_list) + + multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] + assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) + multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) + assert len(multiple_all_candidates) == len(multiple_cost_list) + address_num = min(max_address_num, min_address_num + require_more_address) + idxs = np.where(multiple_cost_list <= address_num)[0] + candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] + return candidates, min_address_num, address_num + + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - for l in length: - X, Y = self.base[l] - - idx = bisect.bisect_left(Y, key) - begin = max(0, idx - 1) - end = min(idx + 2, len(X)) - - for idx in range(begin, end): - err = abs(Y[idx] - key) - if abs(err - min_err) < 1e-9: - candidates.extend(X[idx]) - elif err < min_err: - candidates = copy.deepcopy(X[idx]) - min_err = err - return candidates + abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) - def get_all_candidates(self): - return sum([sum(D[0], []) for D in self.base.values()], []) + if multiple_predictions: + save_pred_res = pred_res + pred_res = flatten(pred_res) + + for c in abduce_c: + candidate = pred_res.copy() + for i, idx in enumerate(address_idx): + candidate[idx] = c[i] + + if multiple_predictions: + candidate = reform_idx(candidate, save_pred_res) + + if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): + candidates.append(candidate) + return candidates + + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) def __len__(self): - return sum([sum(len(x) for x in D[0]) for D in self.base.values()]) + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) + + +class HWF_KB(RegKB): + def __init__( + self, GKB_flag=False, + pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], + len_list=[1, 3, 5, 7], + max_err=1e-3 + ): + super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) + + def valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + return False + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + return False + return True + + def logic_forward(self, formula): + if not self.valid_candidate(formula): + return np.inf + mapping = { + '1': '1', + '2': '2', + '3': '3', + '4': '4', + '5': '5', + '6': '6', + '7': '7', + '8': '8', + '9': '9', + '+': '+', + '-': '-', + 'times': '*', + 'div': '/', + } + formula = [mapping[f] for f in formula] + return round(eval(''.join(formula)), 2) import time if __name__ == "__main__": t1 = time.time() - kb = HWF_KB(True) + kb = add_KB(True) t2 = time.time() print(t2 - t1) - # X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"] - # Y = [2, 1, 1, 2, 2] - # kb = ClsKB(X, Y) - # print('len(kb):', len(kb)) - # res = kb.get_candidates(2, 5) - # print(res) - # res = kb.get_candidates(2, 3) - # print(res) - # res = kb.get_candidates(None) - # print(res) - # print() - - # X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"] - # Y = [2, 1, 1, 2, 1.5, 1.5] - # kb = RegKB(X, Y) - # print('len(kb):', len(kb)) - # res = kb.get_candidates(1.6) - # print(res) - # res = kb.get_candidates(1.6, length = 9) - # print(res) - # res = kb.get_candidates(None) - # print(res) + From 09083f9c2e847ddc290522db02db1212e8da4be8 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:24:12 +0800 Subject: [PATCH 05/31] Update abducer_base.py --- abl/abducer/abducer_base.py | 192 +++++++++++++++++------------------- 1 file changed, 90 insertions(+), 102 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 9a483c9..811f0ea 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -10,33 +10,16 @@ # # ================================================================# -# import sys - -# sys.path.append(".") -# sys.path.append("..") - import abc -# TODO 尽量别用import * from .kb import * import numpy as np from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import confidence_dist, flatten, hamming_dist -import math -import time - - class AbducerBase(abc.ABC): - def __init__( - self, - kb, - dist_func="confidence", - zoopt=False, - multiple_predictions=False, - cache=True, - ): + def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): self.kb = kb - assert dist_func == "hamming" or dist_func == "confidence" + assert dist_func == 'hamming' or dist_func == 'confidence' self.dist_func = dist_func self.zoopt = zoopt self.multiple_predictions = multiple_predictions @@ -47,41 +30,42 @@ class AbducerBase(abc.ABC): self.cache_candidates = {} def _get_cost_list(self, pred_res, pred_res_prob, candidates): - if self.dist_func == "hamming": + if self.dist_func == 'hamming': + if self.multiple_predictions: + pred_res = flatten(pred_res) + candidates = [flatten(c) for c in candidates] + return hamming_dist(pred_res, candidates) - elif self.dist_func == "confidence": - mapping = dict( - zip( - self.kb.pseudo_label_list, - list(range(len(self.kb.pseudo_label_list))), - ) - ) - return confidence_dist( - pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates] - ) + + elif self.dist_func == 'confidence': + if self.multiple_predictions: + pred_res_prob = flatten(pred_res_prob) + candidates = [flatten(c) for c in candidates] + + mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list))))) + candidates = [list(map(lambda x: mapping[x], c)) for c in candidates] + return confidence_dist(pred_res_prob, candidates) def _get_one_candidate(self, pred_res, pred_res_prob, candidates): if len(candidates) == 0: return [] elif len(candidates) == 1 or self.zoopt: return candidates[0] + else: cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) min_address_num = np.min(cost_list) idxs = np.where(cost_list == min_address_num)[0] - return [candidates[idx] for idx in idxs][0] + candidate = [candidates[idx] for idx in idxs][0] + return candidate # for zoopt def _zoopt_score_multiple(self, pred_res, key, solution): all_address_flag = reform_idx(solution, pred_res) score = 0 for idx in range(len(pred_res)): - address_idx = [ - i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 - ] - candidate = self.kb.address_by_idx( - [pred_res[idx]], key[idx], address_idx, True - ) + address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] + candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) if len(candidate) > 0: score += 1 return score @@ -89,9 +73,7 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, key, sol): if not self.multiple_predictions: address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] - candidates = self.kb.address_by_idx( - pred_res, key, address_idx, self.multiple_predictions - ) + candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) return 1 if len(candidates) > 0 else 0 else: return self._zoopt_score_multiple(pred_res, key, sol.get_x()) @@ -108,7 +90,7 @@ class AbducerBase(abc.ABC): dim=dimension, constraint=lambda sol: self._constrain_address_num(sol, max_address_num), ) - parameter = Parameter(budget=100, autoset=True) + parameter = Parameter(budget=100, intermediate_result=False, autoset=True) solution = Opt.min(objective, parameter).get_x() return solution @@ -119,11 +101,7 @@ class AbducerBase(abc.ABC): pred_res = flatten(pred_res) key = tuple(key) if (tuple(pred_res), key) in self.cache_min_address_num: - address_num = min( - max_address_num, - self.cache_min_address_num[(tuple(pred_res), key)] - + require_more_address, - ) + address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address) if (tuple(pred_res), key, address_num) in self.cache_candidates: candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] if self.zoopt: @@ -152,18 +130,12 @@ class AbducerBase(abc.ABC): if self.zoopt: solution = self.zoopt_get_solution(pred_res, key, max_address_num) address_idx = [idx for idx, i in enumerate(solution) if i != 0] - candidates = self.kb.address_by_idx( - pred_res, key, address_idx, self.multiple_predictions - ) + candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) address_num = int(solution.sum()) min_address_num = address_num else: candidates, min_address_num, address_num = self.kb.abduce_candidates( - pred_res, - key, - max_address_num, - require_more_address, - self.multiple_predictions, + pred_res, key, max_address_num, require_more_address, self.multiple_predictions ) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) @@ -177,32 +149,21 @@ class AbducerBase(abc.ABC): return self.kb.abduce_rules(pred_res) def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): - if self.multiple_predictions: - return self.abduce( - (Z["cls"], Z["prob"], Y), max_address_num, require_more_address - ) - else: - return [ - self.abduce((z, prob, y), max_address_num, require_more_address) - for z, prob, y in zip(Z["cls"], Z["prob"], Y) - ] + # if self.multiple_predictions: + return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address) + # else: + # return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)] def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) -if __name__ == "__main__": - prob1 = [ - [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ] - prob2 = [ - [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], - [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], - ] +if __name__ == '__main__': + prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] + prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] - kb = add_KB() - abd = AbducerBase(kb, "confidence") + kb = add_KB(True) + abd = AbducerBase(kb, 'confidence') res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -214,9 +175,23 @@ if __name__ == "__main__": res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) print(res) print() + + + multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], + [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + + + kb = add_KB() + abd = AbducerBase(kb, 'confidence', multiple_predictions=True) + res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0) + print(res) + res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1) + print(res) + print() + kb = add_prolog_KB() - abd = AbducerBase(kb, "confidence") + abd = AbducerBase(kb, 'confidence') res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -230,7 +205,7 @@ if __name__ == "__main__": print() kb = add_prolog_KB() - abd = AbducerBase(kb, "confidence", zoopt=True) + abd = AbducerBase(kb, 'confidence', zoopt=True) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) @@ -243,42 +218,55 @@ if __name__ == "__main__": print(res) print() - kb = HWF_KB(len_list=[1, 3, 5]) - abd = AbducerBase(kb, "hamming") - res = abd.abduce( - (["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0 - ) + kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 0.1) + abd = AbducerBase(kb, 'hamming') + res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + print(res) + res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + print(res) + + kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 1) + abd = AbducerBase(kb, 'hamming') + res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) + print(res) + res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) + print(res) + res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) + print(res) + print() + + kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) + abd = AbducerBase(kb, 'hamming', multiple_predictions=True) + res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0) + print(res) + print() + + kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) + abd = AbducerBase(kb, 'hamming') + res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) + print(res) + res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) print(res) - res = abd.abduce( - (["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0 - ) + + kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) + abd = AbducerBase(kb, 'hamming') + res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) print(res) - res = abd.abduce( - (["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0 - ) + res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0) print(res) - res = abd.abduce( - (["5", "8", "8", "8", "8"], None, 3.17), - max_address_num=5, - require_more_address=3, - ) + res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3) print(res) print() kb = HED_prolog_KB() abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) - consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]] - consist_exs2 = [ - [1, "+", 0, "=", 0], - [1, "+", 1, "=", 0], - [0, "+", 1, "=", 1, 1], - ] # not consistent with rules - inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] + consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] + inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] - rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"] + rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])'] - print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs)) - print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules)) + print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True)) + print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print() res = abd.abduce((consist_exs, None, [1] * len(consist_exs))) From 6ead422a1c8f0a12117370b6d83e79ca3dd8a229 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:24:43 +0800 Subject: [PATCH 06/31] Update utils.py --- abl/utils/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 65c85fc..05b97a4 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -4,7 +4,10 @@ from collections import OrderedDict # for multiple predictions, modify from `learn_add.py` def flatten(l): - return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] + # return [item for sublist in l for item in flatten(sublist)] if isinstance(l, (list, tuple)) else [l] + if not isinstance(l[0], (list, tuple)): + return l + return [item for sublist in l for item in sublist] if isinstance(l, (list, tuple)) else [l] # for multiple predictions, modify from `learn_add.py` def reform_idx(flatten_pred_res, save_pred_res): From af8753bdcacec34af50a39840252c6aad9c84d2b Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:36:26 +0800 Subject: [PATCH 07/31] Update abducer_base.py --- abl/abducer/abducer_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 811f0ea..36a0e76 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -11,7 +11,6 @@ # ================================================================# import abc -from .kb import * import numpy as np from zoopt import Dimension, Objective, Parameter, Opt from ..utils.utils import confidence_dist, flatten, hamming_dist @@ -159,6 +158,8 @@ class AbducerBase(abc.ABC): if __name__ == '__main__': + from kb import add_KB, add_prolog_KB, HWF_KB, HED_prolog_KB + prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] From 956d64ba1a6da8ea3ab784199df50110acc04c48 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:38:25 +0800 Subject: [PATCH 08/31] Update abducer_base.py --- abl/abducer/abducer_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 36a0e76..9fcfa0a 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -13,7 +13,7 @@ import abc import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import confidence_dist, flatten, hamming_dist +from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist class AbducerBase(abc.ABC): def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True): From cf15052331af5b89678cb397ffa09b98924ffa2f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:48:46 +0800 Subject: [PATCH 09/31] Update utils.py --- abl/utils/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 05b97a4..fbb5bfb 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -76,9 +76,9 @@ def remapping_res(pred_res, m): remapping[value] = key return [[remapping[symbol] for symbol in formula] for formula in pred_res] -def check_equal(a, b): +def check_equal(a, b, max_err=0): if isinstance(a, (int, float)) and isinstance(b, (int, float)): - return abs(a - b) <= 1e-3 + return abs(a - b) <= max_err if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): @@ -89,4 +89,4 @@ def check_equal(a, b): return True else: - return a == b + return a == b From 930b549d4eab631fd700a90c1cfce867d13feb49 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 17:01:31 +0800 Subject: [PATCH 10/31] Update kb.py --- abl/abducer/kb.py | 170 +++++++++++++++------------------------------- 1 file changed, 56 insertions(+), 114 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 66a134f..cd15dd3 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -25,8 +25,46 @@ import pyswip class KBBase(ABC): - def __init__(self, pseudo_label_list=None): - pass + def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): + self.pseudo_label_list = pseudo_label_list + self.len_list = len_list + self.GKB_flag = GKB_flag + self.max_err = max_err + + if GKB_flag: + self.base = {} + X, Y = self._get_GKB() + for x, y in zip(X, Y): + self.base.setdefault(len(x), defaultdict(list))[y].append(x) + + # For parallel version of _get_GKB + def _get_XY_list(self, args): + pre_x, post_x_it = args[0], args[1] + XY_list = [] + for post_x in post_x_it: + x = (pre_x,) + post_x + y = self.logic_forward(x) + if y != np.inf: + XY_list.append((x, y)) + return XY_list + + # Parallel _get_GKB + def _get_GKB(self): + X, Y = [], [] + for length in self.len_list: + arg_list = [] + for pre_x in self.pseudo_label_list: + post_x_it = product(self.pseudo_label_list, repeat=length - 1) + arg_list.append((pre_x, post_x_it)) + with Pool(processes=len(arg_list)) as pool: + ret_list = pool.map(self._get_XY_list, arg_list) + for XY_list in ret_list: + if len(XY_list) == 0: + continue + part_X, part_Y = zip(*XY_list) + X.extend(part_X) + Y.extend(part_Y) + return X, Y @abstractmethod def logic_forward(self): @@ -87,53 +125,22 @@ class KBBase(ABC): return candidates, min_address_num, address_num + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) + def __len__(self): - pass + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) class ClsKB(KBBase): - def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None): - super().__init__() - self.GKB_flag = GKB_flag - self.pseudo_label_list = pseudo_label_list - self.len_list = len_list - self.max_err = 0 - - if GKB_flag: - self.base = {} - X, Y = self._get_GKB() - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) - - - # For parallel version of _get_GKB - def _get_XY_list(self, args): - pre_x, post_x_it = args[0], args[1] - XY_list = [] - for post_x in post_x_it: - x = (pre_x,) + post_x - y = self.logic_forward(x) - if y != np.inf: - XY_list.append((x, y)) - return XY_list - - # Parallel _get_GKB - def _get_GKB(self): - X, Y = [], [] - for length in self.len_list: - arg_list = [] - for pre_x in self.pseudo_label_list: - post_x_it = product(self.pseudo_label_list, repeat=length - 1) - arg_list.append((pre_x, post_x_it)) - with Pool(processes=len(arg_list)) as pool: - ret_list = pool.map(self._get_XY_list, arg_list) - for XY_list in ret_list: - if len(XY_list) == 0: - continue - part_X, part_Y = zip(*XY_list) - X.extend(part_X) - Y.extend(part_Y) - return X, Y + def __init__(self, pseudo_label_list, len_list, GKB_flag): + super().__init__(pseudo_label_list, len_list, GKB_flag) def logic_forward(self): pass @@ -208,22 +215,10 @@ class ClsKB(KBBase): candidates.append(candidate) return candidates - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - class add_KB(ClsKB): - def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]): - super().__init__(GKB_flag, pseudo_label_list, len_list) + def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): + super().__init__(pseudo_label_list, len_list, GKB_flag) def logic_forward(self, nums): return sum(nums) @@ -231,10 +226,8 @@ class add_KB(ClsKB): class prolog_KB(KBBase): def __init__(self, pseudo_label_list): - super().__init__() - self.pseudo_label_list = pseudo_label_list + super().__init__(pseudo_label_list) self.prolog = pyswip.Prolog() - self.max_err = 0 def logic_forward(self): pass @@ -326,46 +319,7 @@ class HED_prolog_KB(prolog_KB): class RegKB(KBBase): def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): - super().__init__() - self.GKB_flag = GKB_flag - self.pseudo_label_list = pseudo_label_list - self.len_list = len_list - self.max_err = max_err - - if GKB_flag: - self.base = {} - X, Y = self._get_GKB() - for x, y in zip(X, Y): - self.base.setdefault(len(x), defaultdict(list))[y].append(x) - - # For parallel version of _get_GKB - def _get_XY_list(self, args): - pre_x, post_x_it = args[0], args[1] - XY_list = [] - for post_x in post_x_it: - x = (pre_x,) + post_x - y = self.logic_forward(x) - if y != np.inf: - XY_list.append((x, y)) - return XY_list - - # Parallel _get_GKB - def _get_GKB(self): - X, Y = [], [] - for length in self.len_list: - arg_list = [] - for pre_x in self.pseudo_label_list: - post_x_it = product(self.pseudo_label_list, repeat=length - 1) - arg_list.append((pre_x, post_x_it)) - with Pool(processes=len(arg_list)) as pool: - ret_list = pool.map(self._get_XY_list, arg_list) - for XY_list in ret_list: - if len(XY_list) == 0: - continue - part_X, part_Y = zip(*XY_list) - X.extend(part_X) - Y.extend(part_Y) - return X, Y + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) def logic_forward(self): pass @@ -461,18 +415,6 @@ class RegKB(KBBase): candidates.append(candidate) return candidates - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - class HWF_KB(RegKB): def __init__( From 173138f64802a0ea02fdd2a5a51137bb68e65638 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 17:02:04 +0800 Subject: [PATCH 11/31] Update kb.py --- abl/abducer/kb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index cd15dd3..7ac26c8 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -461,7 +461,7 @@ import time if __name__ == "__main__": t1 = time.time() - kb = add_KB(True) + kb = add_KB(GKB_flag=True) t2 = time.time() print(t2 - t1) From 4478e4e3de5ea0ea0dec93562750dc334eec5e8a Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 18:40:15 +0800 Subject: [PATCH 12/31] Rearrange abduce_by_GKB and address_by_idx to Base --- abl/abducer/kb.py | 258 ++++++++++++++-------------------------------- 1 file changed, 80 insertions(+), 178 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 7ac26c8..7c494c5 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -10,20 +10,6 @@ # # ================================================================# -from abc import ABC, abstractmethod -import bisect -import copy -import numpy as np - -from collections import defaultdict -from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal - -from multiprocessing import Pool - -import pyswip - - class KBBase(ABC): def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): self.pseudo_label_list = pseudo_label_list @@ -79,78 +65,16 @@ class KBBase(ABC): res.append(self.logic_forward(x)) return res - @abstractmethod - def abduce_candidates(self): - pass - - @abstractmethod - def address_by_idx(self): - pass - - def _address(self, address_num, pred_res, key, multiple_predictions): - new_candidates = [] - if not multiple_predictions: - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) - else: - address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) - - for address_idx in address_idx_list: - candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) - new_candidates += candidates - return new_candidates - - def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - candidates = [] - - for address_num in range(len(flatten(pred_res)) + 1): - if address_num == 0: - if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): - candidates.append(pred_res) - else: - new_candidates = self._address(address_num, pred_res, key, multiple_predictions) - candidates += new_candidates - - if len(candidates) > 0: - min_address_num = address_num - break - - if address_num >= max_address_num: - return [], 0, 0 - - for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): - if address_num > max_address_num: - return candidates, min_address_num, address_num - 1 - new_candidates = self._address(address_num, pred_res, key, multiple_predictions) - candidates += new_candidates - - return candidates, min_address_num, address_num - - def _dict_len(self, dic): - if not self.GKB_flag: - return 0 - else: - return sum(len(c) for c in dic.values()) - - def __len__(self): - if not self.GKB_flag: - return 0 - else: - return sum(self._dict_len(v) for v in self.base.values()) - - -class ClsKB(KBBase): - def __init__(self, pseudo_label_list, len_list, GKB_flag): - super().__init__(pseudo_label_list, len_list, GKB_flag) - - def logic_forward(self): - pass - def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) else: return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - + + @abstractmethod + def _find_candidate_GKB(self): + pass + def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): if self.base == {}: return [], 0, 0 @@ -158,7 +82,7 @@ class ClsKB(KBBase): if not multiple_predictions: if len(pred_res) not in self.len_list: return [], 0, 0 - all_candidates = self.base[len(pred_res)][key] + all_candidates = self._find_candidate_GKB(pred_res, key) if len(all_candidates) == 0: return [], 0, 0 else: @@ -168,7 +92,7 @@ class ClsKB(KBBase): idxs = np.where(cost_list <= address_num)[0] candidates = [all_candidates[idx] for idx in idxs] return candidates, min_address_num, address_num - + else: min_address_num = 0 all_candidates_save = [] @@ -177,7 +101,7 @@ class ClsKB(KBBase): for p_res, k in zip(pred_res, key): if len(p_res) not in self.len_list: return [], 0, 0 - all_candidates = self.base[len(p_res)][k] + all_candidates = self._regression_find_candidate_GKB(p_res, k) if len(all_candidates) == 0: return [], 0, 0 else: @@ -194,7 +118,7 @@ class ClsKB(KBBase): idxs = np.where(multiple_cost_list <= address_num)[0] candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] return candidates, min_address_num, address_num - + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) @@ -211,10 +135,71 @@ class ClsKB(KBBase): if multiple_predictions: candidate = reform_idx(candidate, save_pred_res) - if check_equal(self._logic_forward(candidate, multiple_predictions), key): + if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): candidates.append(candidate) return candidates + def _address(self, address_num, pred_res, key, multiple_predictions): + new_candidates = [] + if not multiple_predictions: + address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + else: + address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) + + for address_idx in address_idx_list: + candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) + new_candidates += candidates + return new_candidates + + def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): + candidates = [] + + for address_num in range(len(flatten(pred_res)) + 1): + if address_num == 0: + if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err): + candidates.append(pred_res) + else: + new_candidates = self._address(address_num, pred_res, key, multiple_predictions) + candidates += new_candidates + + if len(candidates) > 0: + min_address_num = address_num + break + + if address_num >= max_address_num: + return [], 0, 0 + + for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1): + if address_num > max_address_num: + return candidates, min_address_num, address_num - 1 + new_candidates = self._address(address_num, pred_res, key, multiple_predictions) + candidates += new_candidates + + return candidates, min_address_num, address_num + + def _dict_len(self, dic): + if not self.GKB_flag: + return 0 + else: + return sum(len(c) for c in dic.values()) + + def __len__(self): + if not self.GKB_flag: + return 0 + else: + return sum(self._dict_len(v) for v in self.base.values()) + + +class ClsKB(KBBase): + def __init__(self, pseudo_label_list, len_list, GKB_flag): + super().__init__(pseudo_label_list, len_list, GKB_flag) + + def logic_forward(self): + pass + + def _find_candidate_GKB(self, pred_res, key): + return self.base[len(pred_res)][key] + class add_KB(ClsKB): def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False): @@ -232,16 +217,13 @@ class prolog_KB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - + def _find_candidate_GKB(self): + pass + def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] # print(address_idx) - if not multiple_predictions: - query_string = self.get_query_string(pred_res, key, address_idx) - else: - query_string = self.get_query_string_need_flatten(pred_res, key, address_idx) + query_string = self.get_query_string(pred_res, key, address_idx) if multiple_predictions: save_pred_res = pred_res @@ -284,21 +266,18 @@ class HED_prolog_KB(prolog_KB): super().__init__(pseudo_label_list) self.prolog.consult('./datasets/hed/learn_add.pl') - # corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py` def logic_forward(self, exs): return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 - def get_query_string_need_flatten(self, pred_res, key, address_idx): - # flatten + def get_query_string(self, pred_res, key, address_idx): flatten_pred_res = flatten(pred_res) # add variables for prolog for idx in range(len(flatten_pred_res)): if idx in address_idx: flatten_pred_res[idx] = 'X' + str(idx) - # unflatten - new_pred_res = reform_idx(flatten_pred_res, pred_res) + pred_res = reform_idx(flatten_pred_res, pred_res) - query_string = "abduce_consistent_insts(%s)." % new_pred_res + query_string = "abduce_consistent_insts(%s)." % pred_res return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") def consist_rule(self, exs, rules): @@ -324,13 +303,7 @@ class RegKB(KBBase): def logic_forward(self): pass - def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): - if self.GKB_flag: - return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) - else: - return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) - - def _regression_find_candidate_GKB(self, pred_res, key): + def _find_candidate_GKB(self, pred_res, key): potential_candidates = self.base[len(pred_res)] key_list = sorted(potential_candidates) key_idx = bisect.bisect_left(key_list, key) @@ -351,70 +324,6 @@ class RegKB(KBBase): break return all_candidates - def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): - if self.base == {}: - return [], 0, 0 - - if not multiple_predictions: - if len(pred_res) not in self.len_list: - return [], 0, 0 - all_candidates = self._regression_find_candidate_GKB(pred_res, key) - if len(all_candidates) == 0: - return [], 0, 0 - else: - cost_list = hamming_dist(pred_res, all_candidates) - min_address_num = np.min(cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(cost_list <= address_num)[0] - candidates = [all_candidates[idx] for idx in idxs] - return candidates, min_address_num, address_num - - else: - min_address_num = 0 - all_candidates_save = [] - cost_list_save = [] - - for p_res, k in zip(pred_res, key): - if len(p_res) not in self.len_list: - return [], 0, 0 - all_candidates = self._regression_find_candidate_GKB(p_res, k) - if len(all_candidates) == 0: - return [], 0, 0 - else: - all_candidates_save.append(all_candidates) - cost_list = hamming_dist(p_res, all_candidates) - min_address_num += np.min(cost_list) - cost_list_save.append(cost_list) - - multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)] - assert len(multiple_all_candidates[0]) == len(flatten(pred_res)) - multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)]) - assert len(multiple_all_candidates) == len(multiple_cost_list) - address_num = min(max_address_num, min_address_num + require_more_address) - idxs = np.where(multiple_cost_list <= address_num)[0] - candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] - return candidates, min_address_num, address_num - - def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): - candidates = [] - abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) - - if multiple_predictions: - save_pred_res = pred_res - pred_res = flatten(pred_res) - - for c in abduce_c: - candidate = pred_res.copy() - for i, idx in enumerate(address_idx): - candidate[idx] = c[i] - - if multiple_predictions: - candidate = reform_idx(candidate, save_pred_res) - - if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err): - candidates.append(candidate) - return candidates - class HWF_KB(RegKB): def __init__( @@ -456,13 +365,6 @@ class HWF_KB(RegKB): formula = [mapping[f] for f in formula] return round(eval(''.join(formula)), 2) - -import time - if __name__ == "__main__": - t1 = time.time() - kb = add_KB(GKB_flag=True) - t2 = time.time() - print(t2 - t1) - + pass From 9d9684723846c3e63af3913e78ee45d39e474665 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 18:42:19 +0800 Subject: [PATCH 13/31] Rearrange abduce_by_GKB and address_by_idx to Base --- abl/abducer/kb.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 7c494c5..8db7bba 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -10,6 +10,19 @@ # # ================================================================# +from abc import ABC, abstractmethod +import bisect +import copy +import numpy as np + +from collections import defaultdict +from itertools import product, combinations +from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal + +from multiprocessing import Pool + +import pyswip + class KBBase(ABC): def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): self.pseudo_label_list = pseudo_label_list From 1d1873b6eba87726043838ed69cad12f997a64a6 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 18:56:24 +0800 Subject: [PATCH 14/31] Rearrange sorted in RegKB --- abl/abducer/kb.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 8db7bba..137647b 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -63,6 +63,9 @@ class KBBase(ABC): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) + sorted_XY = sorted(list(zip(Y, X))) + X = [x for y, x in sorted_XY] + Y = [y for y, x in sorted_XY] return X, Y @abstractmethod @@ -318,7 +321,7 @@ class RegKB(KBBase): def _find_candidate_GKB(self, pred_res, key): potential_candidates = self.base[len(pred_res)] - key_list = sorted(potential_candidates) + key_list = list(potential_candidates.keys()) key_idx = bisect.bisect_left(key_list, key) all_candidates = [] @@ -378,6 +381,9 @@ class HWF_KB(RegKB): formula = [mapping[f] for f in formula] return round(eval(''.join(formula)), 2) + +import time + if __name__ == "__main__": pass From ce406011bf538f0ec5567f360a9079113c192fe1 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 18:59:18 +0800 Subject: [PATCH 15/31] Rearrange sorted in RegKB --- abl/abducer/kb.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 137647b..3be0875 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -63,9 +63,10 @@ class KBBase(ABC): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - sorted_XY = sorted(list(zip(Y, X))) - X = [x for y, x in sorted_XY] - Y = [y for y, x in sorted_XY] + if self.max_err != 0: + sorted_XY = sorted(list(zip(Y, X))) + X = [x for y, x in sorted_XY] + Y = [y for y, x in sorted_XY] return X, Y @abstractmethod From 253c9e20ba5b1fa75815405439513d5976bd66a3 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 19:08:21 +0800 Subject: [PATCH 16/31] Update kb.py --- abl/abducer/kb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 3be0875..7eee9f5 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -118,7 +118,7 @@ class KBBase(ABC): for p_res, k in zip(pred_res, key): if len(p_res) not in self.len_list: return [], 0, 0 - all_candidates = self._regression_find_candidate_GKB(p_res, k) + all_candidates = self._find_candidate_GKB(p_res, k) if len(all_candidates) == 0: return [], 0, 0 else: From 1cd3570693f18225f80566af925e14dd5be40b27 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Fri, 3 Mar 2023 20:02:15 +0800 Subject: [PATCH 17/31] make code short --- abl/abducer/kb.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 3be0875..627d6c8 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -77,9 +77,7 @@ class KBBase(ABC): if not multiple_predictions: return self.logic_forward(xs) else: - res = [] - for x in xs: - res.append(self.logic_forward(x)) + res = [self.logic_forward(x) for x in xs] return res def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): From ea66df54fbf5a1987c1c76c8edbc7e5ab336e40d Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Fri, 3 Mar 2023 20:04:19 +0800 Subject: [PATCH 18/31] remove round in reg --- abl/abducer/kb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 4178c4e..0cff3fa 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -378,7 +378,7 @@ class HWF_KB(RegKB): 'div': '/', } formula = [mapping[f] for f in formula] - return round(eval(''.join(formula)), 2) + return eval(''.join(formula)) import time From 53586d94f862c86ed0c9bec494801f410fc472fd Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 3 Mar 2023 21:53:06 +0800 Subject: [PATCH 19/31] Create add.pl --- examples/datasets/mnist_add/add.pl | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 examples/datasets/mnist_add/add.pl diff --git a/examples/datasets/mnist_add/add.pl b/examples/datasets/mnist_add/add.pl new file mode 100644 index 0000000..96f0869 --- /dev/null +++ b/examples/datasets/mnist_add/add.pl @@ -0,0 +1,2 @@ +pseudo_label(N) :- between(0, 9, N). +logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2. From 8ba7c509e80c2cc7c01914a53d832ae8d13360dc Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 05:27:41 +0000 Subject: [PATCH 20/31] Update prologKB --- abl/abducer/abducer_base.py | 13 +++--- abl/abducer/kb.py | 63 ++++++++++++++---------------- abl/framework_hed.py | 4 +- examples/datasets/hed/learn_add.pl | 3 ++ 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 9fcfa0a..777533a 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -163,7 +163,7 @@ if __name__ == '__main__': prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] - kb = add_KB(True) + kb = add_KB(GKB_flag=True) abd = AbducerBase(kb, 'confidence') res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) @@ -219,14 +219,15 @@ if __name__ == '__main__': print(res) print() - kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 0.1) + kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 0.1) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) print(res) res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) print(res) + print() - kb = HWF_KB(True, len_list=[1, 3, 5], max_err = 1) + kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 1) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) print(res) @@ -270,11 +271,11 @@ if __name__ == '__main__': print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print() - res = abd.abduce((consist_exs, None, [1] * len(consist_exs))) + res = abd.abduce((consist_exs, None, [None] * len(consist_exs))) print(res) - res = abd.abduce((inconsist_exs, None, [1] * len(consist_exs))) + res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs))) print(res) print() abduced_rules = abd.abduce_rules(consist_exs) - print(abduced_rules) + print(abduced_rules) \ No newline at end of file diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 0cff3fa..e23e637 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -229,30 +229,52 @@ class prolog_KB(KBBase): super().__init__(pseudo_label_list) self.prolog = pyswip.Prolog() - def logic_forward(self): - pass + def logic_forward(self, pseudo_labels): + result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] + if result == 'true': + return True + elif result == 'false': + return False + return result + + def _address_pred_res(self, pred_res, address_idx, multiple_predictions): + import re + address_pred_res = pred_res.copy() + if multiple_predictions: + address_pred_res = flatten(address_pred_res) + + for idx in range(len(address_pred_res)): + if idx in address_idx: + address_pred_res[idx] = 'P' + str(idx) + if multiple_predictions: + address_pred_res = reform_idx(address_pred_res, pred_res) + + regex = r"'P\d+'" + return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) + + def get_query_string(self, pred_res, key, address_idx, multiple_predictions): + query_string = "logic_forward(" + query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions) + key_is_none_flag = key is None or (type(key) == list and key[0] is None) + query_string += ",%s)." % key if not key_is_none_flag else ")." + return query_string def _find_candidate_GKB(self): pass def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - # print(address_idx) - query_string = self.get_query_string(pred_res, key, address_idx) - + query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions) if multiple_predictions: save_pred_res = pred_res pred_res = flatten(pred_res) - abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): candidate[idx] = c[i] - if multiple_predictions: candidate = reform_idx(candidate, save_pred_res) - candidates.append(candidate) return candidates @@ -264,37 +286,12 @@ class add_prolog_KB(prolog_KB): self.prolog.assertz("pseudo_label(%s)" % i) self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") - def logic_forward(self, nums): - return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res'] - - def get_query_string(self, pred_res, key, address_idx): - query_string = "addition(" - for idx, i in enumerate(pred_res): - tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ',' - query_string += tmp - query_string += "%s)." % key - return query_string - class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list=[0, 1, '+', '=']): super().__init__(pseudo_label_list) self.prolog.consult('./datasets/hed/learn_add.pl') - def logic_forward(self, exs): - return len(list(self.prolog.query("abduce_consistent_insts([%s])." % exs))) != 0 - - def get_query_string(self, pred_res, key, address_idx): - flatten_pred_res = flatten(pred_res) - # add variables for prolog - for idx in range(len(flatten_pred_res)): - if idx in address_idx: - flatten_pred_res[idx] = 'X' + str(idx) - pred_res = reform_idx(flatten_pred_res, pred_res) - - query_string = "abduce_consistent_insts(%s)." % pred_res - return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='") - def consist_rule(self, exs, rules): rules = str(rules).replace("\'","") return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 diff --git a/abl/framework_hed.py b/abl/framework_hed.py index d942be6..171c247 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -150,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for m in mappings: pred_res = mapping_res(original_pred_res, m) max_abduce_num = 20 - solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num) + solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num) all_address_flag = reform_idx(solution, pred_res) consistent_idx_tmp = [] @@ -158,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True) + candidate = abducer.kb.address_by_idx([pred_res[idx]], None, address_idx, True) if len(candidate) > 0: consistent_idx_tmp.append(idx) consistent_pred_res_tmp.append(candidate[0][0]) diff --git a/examples/datasets/hed/learn_add.pl b/examples/datasets/hed/learn_add.pl index af1d6bb..fbf698f 100644 --- a/examples/datasets/hed/learn_add.pl +++ b/examples/datasets/hed/learn_add.pl @@ -32,6 +32,9 @@ abduce_consistent_insts(Exs):- % (Experimental) Uncomment to use parallel abduction % abduce_consistent_exs_concurrent(Exs), !. +logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false. +logic_forward(Exs) :- abduce_consistent_insts(Exs). + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% Abduce Delta_C given pseudo-labels %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% From 3c736686d1a65d7a0d339c907547e6ebf224255b Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 13:31:04 +0800 Subject: [PATCH 21/31] Change the sorting condition in _get_GKB --- abl/abducer/kb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index e23e637..7aff683 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -63,7 +63,7 @@ class KBBase(ABC): part_X, part_Y = zip(*XY_list) X.extend(part_X) Y.extend(part_Y) - if self.max_err != 0: + if type(Y[0]) in (int, float): sorted_XY = sorted(list(zip(Y, X))) X = [x for y, x in sorted_XY] Y = [y for y, x in sorted_XY] From 2160df0bb79ccc20f2f484bc1a29af07105fc3d3 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 13:57:17 +0800 Subject: [PATCH 22/31] Update kb.py --- abl/abducer/kb.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 7aff683..71bdcbe 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -282,9 +282,7 @@ class prolog_KB(KBBase): class add_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list=list(range(10))): super().__init__(pseudo_label_list) - for i in self.pseudo_label_list: - self.prolog.assertz("pseudo_label(%s)" % i) - self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2") + self.prolog.consult('../datasets/mnist_add/add.pl') class HED_prolog_KB(prolog_KB): From 52f8451d56475b5ac0a30786b4d56c788679a561 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 08:02:17 +0000 Subject: [PATCH 23/31] Update prolog_KB --- abl/abducer/abducer_base.py | 8 ++++---- abl/abducer/kb.py | 15 ++------------- examples/example.py | 4 ++-- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 777533a..cb50ee3 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -158,7 +158,7 @@ class AbducerBase(abc.ABC): if __name__ == '__main__': - from kb import add_KB, add_prolog_KB, HWF_KB, HED_prolog_KB + from kb import add_KB, prolog_KB, HWF_KB prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] @@ -191,7 +191,7 @@ if __name__ == '__main__': print() - kb = add_prolog_KB() + kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') abd = AbducerBase(kb, 'confidence') res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) @@ -205,7 +205,7 @@ if __name__ == '__main__': print(res) print() - kb = add_prolog_KB() + kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl') abd = AbducerBase(kb, 'confidence', zoopt=True) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) print(res) @@ -260,7 +260,7 @@ if __name__ == '__main__': print(res) print() - kb = HED_prolog_KB() + kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/hed.pl') abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 71bdcbe..c5fc675 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -225,9 +225,10 @@ class add_KB(ClsKB): class prolog_KB(KBBase): - def __init__(self, pseudo_label_list): + def __init__(self, pseudo_label_list, pl_file): super().__init__(pseudo_label_list) self.prolog = pyswip.Prolog() + self.prolog.consult(pl_file) def logic_forward(self, pseudo_labels): result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res'] @@ -278,18 +279,6 @@ class prolog_KB(KBBase): candidates.append(candidate) return candidates - -class add_prolog_KB(prolog_KB): - def __init__(self, pseudo_label_list=list(range(10))): - super().__init__(pseudo_label_list) - self.prolog.consult('../datasets/mnist_add/add.pl') - - -class HED_prolog_KB(prolog_KB): - def __init__(self, pseudo_label_list=[0, 1, '+', '=']): - super().__init__(pseudo_label_list) - self.prolog.consult('./datasets/hed/learn_add.pl') - def consist_rule(self, exs, rules): rules = str(rules).replace("\'","") return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 diff --git a/examples/example.py b/examples/example.py index 30c50fc..ee0b898 100644 --- a/examples/example.py +++ b/examples/example.py @@ -20,7 +20,7 @@ from models.wabl_models import DecisionTree, WABLBasicModel from multiprocessing import Pool from abducer.abducer_base import AbducerBase -from abducer.kb import add_KB, HWF_KB, HED_prolog_KB +from abducer.kb import add_KB, HWF_KB, prolog_KB from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf from datasets.hed.get_hed import get_hed, split_equation @@ -33,7 +33,7 @@ def run_test(): # kb = HWF_KB(True) # abducer = AbducerBase(kb) - kb = HED_prolog_KB() + kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/hed.pl') abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) recorder = logger() From 42a43b3cca4d5ebf85efcdb02990de136a58ce86 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Mon, 6 Mar 2023 16:24:28 +0800 Subject: [PATCH 24/31] update para of method --- abl/abducer/kb.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index c5fc675..632bd42 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -24,6 +24,7 @@ from multiprocessing import Pool import pyswip class KBBase(ABC): + # TODO:有些不能是默认参数,必须给定 def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): self.pseudo_label_list = pseudo_label_list self.len_list = len_list @@ -70,7 +71,7 @@ class KBBase(ABC): return X, Y @abstractmethod - def logic_forward(self): + def logic_forward(self, pseudo_labels): pass def _logic_forward(self, xs, multiple_predictions=False): @@ -87,7 +88,7 @@ class KBBase(ABC): return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) @abstractmethod - def _find_candidate_GKB(self): + def _find_candidate_GKB(self, pred_res, key): pass def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): @@ -260,7 +261,7 @@ class prolog_KB(KBBase): query_string += ",%s)." % key if not key_is_none_flag else ")." return query_string - def _find_candidate_GKB(self): + def _find_candidate_GKB(self, pred_res, key): pass def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): From 4e909320037df506320b74b54f11605eecadb9ac Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Mon, 6 Mar 2023 16:39:12 +0800 Subject: [PATCH 25/31] update package for example.py --- examples/example.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/example.py b/examples/example.py index ee0b898..a6a5194 100644 --- a/examples/example.py +++ b/examples/example.py @@ -10,21 +10,24 @@ # # ================================================================# -from utils.plog import logger, INFO +import sys +sys.path.append("../") + +from abl.utils.plog import logger, INFO import torch.nn as nn import torch -from models.nn import LeNet5, SymbolNet -from models.basic_model import BasicModel, BasicDataset -from models.wabl_models import DecisionTree, WABLBasicModel +from abl.models.nn import LeNet5, SymbolNet +from abl.models.basic_model import BasicModel, BasicDataset +from abl.models.wabl_models import DecisionTree, WABLBasicModel from multiprocessing import Pool -from abducer.abducer_base import AbducerBase -from abducer.kb import add_KB, HWF_KB, prolog_KB +from abl.abducer.abducer_base import AbducerBase +from abl.abducer.kb import add_KB, HWF_KB, prolog_KB from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf from datasets.hed.get_hed import get_hed, split_equation -import framework_hed +from abl import framework_hed def run_test(): From c98e516e43898b0c431c339f62009aa8a63bb701 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 08:47:45 +0000 Subject: [PATCH 26/31] change pl_file name for HED --- abl/abducer/abducer_base.py | 2 +- abl/abducer/kb.py | 3 +-- examples/example.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index cb50ee3..712eb46 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -260,7 +260,7 @@ if __name__ == '__main__': print(res) print() - kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/hed.pl') + kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 632bd42..d937dcf 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -24,8 +24,7 @@ from multiprocessing import Pool import pyswip class KBBase(ABC): - # TODO:有些不能是默认参数,必须给定 - def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0): + def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): self.pseudo_label_list = pseudo_label_list self.len_list = len_list self.GKB_flag = GKB_flag diff --git a/examples/example.py b/examples/example.py index a6a5194..c74f1b3 100644 --- a/examples/example.py +++ b/examples/example.py @@ -36,7 +36,7 @@ def run_test(): # kb = HWF_KB(True) # abducer = AbducerBase(kb) - kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/hed.pl') + kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) recorder = logger() From fe52f91385c17ba639f3a9f7635d6f78825f7e4c Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Mon, 6 Mar 2023 19:39:59 +0800 Subject: [PATCH 27/31] speed up --- abl/utils/utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index fbb5bfb..acc7f94 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -7,6 +7,7 @@ def flatten(l): # return [item for sublist in l for item in flatten(sublist)] if isinstance(l, (list, tuple)) else [l] if not isinstance(l[0], (list, tuple)): return l + # TODO 稍微对比一下和itertools.chain.from_iterable(nested_list)的速度区别,看看哪个好 return [item for sublist in l for item in sublist] if isinstance(l, (list, tuple)) else [l] # for multiple predictions, modify from `learn_add.py` @@ -14,13 +15,8 @@ def reform_idx(flatten_pred_res, save_pred_res): re = [] i = 0 for e in save_pred_res: - j = 0 - idx = [] - while j < len(e): - idx.append(flatten_pred_res[i + j]) - j += 1 - re.append(idx) - i = i + j + re.append(flatten_pred_res[i:i + len(e)]) + i += len(e) return re def hamming_dist(A, B): From e436421fb52e3d6502f630aff25c8e704397030f Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 6 Mar 2023 20:09:11 +0800 Subject: [PATCH 28/31] Update framework_hed.py --- abl/framework_hed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abl/framework_hed.py b/abl/framework_hed.py index 171c247..a823b13 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -214,7 +214,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, consistent_idx = [] consistent_pred_res = [] for idx in range(len(pred_res)): - if abducer.kb.logic_forward([pred_res[idx]]): + if abducer.kb.logic_forward(pred_res[idx]): consistent_idx.append(idx) consistent_pred_res.append(pred_res[idx]) From bfd6dc8a5c37c680c683a3c3d9c4efe756853248 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Tue, 7 Mar 2023 00:01:08 +0800 Subject: [PATCH 29/31] update TODO in kb.py --- abl/abducer/kb.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index d937dcf..9ae919a 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -80,6 +80,7 @@ class KBBase(ABC): res = [self.logic_forward(x) for x in xs] return res + # TODO:这里max_address_num默认值-1,后面运行会有问题吗 def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): if self.GKB_flag: return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) @@ -134,8 +135,10 @@ class KBBase(ABC): candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] return candidates, min_address_num, address_num + # TODO:应该也是内部使用的方法? def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] + # TODO:product combinations本身就是迭代器,如果没有其他用途,不用转list,直接放到循环那即可,省去一些时间,下面的同理 abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) if multiple_predictions: @@ -209,7 +212,8 @@ class ClsKB(KBBase): def __init__(self, pseudo_label_list, len_list, GKB_flag): super().__init__(pseudo_label_list, len_list, GKB_flag) - def logic_forward(self): + # TODO:这里以及RegKB可以不实现logic_forward吗,这样用户继承后不实现logic_forward就会报错 + def logic_forward(self, pseudo_labels): pass def _find_candidate_GKB(self, pred_res, key): @@ -243,13 +247,15 @@ class prolog_KB(KBBase): address_pred_res = pred_res.copy() if multiple_predictions: address_pred_res = flatten(address_pred_res) - + + # TODO:可以直接对address_idx循环? for idx in range(len(address_pred_res)): if idx in address_idx: address_pred_res[idx] = 'P' + str(idx) if multiple_predictions: address_pred_res = reform_idx(address_pred_res, pred_res) + # TODO:不知道有没有更简洁的方法 regex = r"'P\d+'" return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) @@ -269,6 +275,7 @@ class prolog_KB(KBBase): if multiple_predictions: save_pred_res = pred_res pred_res = flatten(pred_res) + # TODO:这里后面的那个list应该也不需要 abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] for c in abduce_c: candidate = pred_res.copy() @@ -289,17 +296,15 @@ class prolog_KB(KBBase): if len(prolog_result) == 0: return None prolog_rules = prolog_result[0]['X'] - rules = [] - for rule in prolog_rules: - rules.append(rule.value) + rules = [rule.value for rule in prolog_rules] return rules - +# TODO:和ClsKB的参数顺序不统一 class RegKB(KBBase): def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) - def logic_forward(self): + def logic_forward(self, pseudo_labels): pass def _find_candidate_GKB(self, pred_res, key): @@ -333,6 +338,7 @@ class HWF_KB(RegKB): ): super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) + # TODO:应该是静态方法 def valid_candidate(self, formula): if len(formula) % 2 == 0: return False From e911ef1103d74cc05b2bd9e3c973764b3c707111 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 7 Mar 2023 00:17:55 +0000 Subject: [PATCH 30/31] Update several files --- abl/abducer/abducer_base.py | 4 ++-- abl/abducer/kb.py | 44 +++++++++++++------------------------ abl/utils/utils.py | 9 ++++---- 3 files changed, 21 insertions(+), 36 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 712eb46..bd76831 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -219,7 +219,7 @@ if __name__ == '__main__': print(res) print() - kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 0.1) + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0) print(res) @@ -227,7 +227,7 @@ if __name__ == '__main__': print(res) print() - kb = HWF_KB(GKB_flag=True, len_list=[1, 3, 5], max_err = 1) + kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) abd = AbducerBase(kb, 'hamming') res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0) print(res) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 9ae919a..5b82f03 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -80,8 +80,7 @@ class KBBase(ABC): res = [self.logic_forward(x) for x in xs] return res - # TODO:这里max_address_num默认值-1,后面运行会有问题吗 - def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): + def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False): if self.GKB_flag: return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions) else: @@ -135,11 +134,9 @@ class KBBase(ABC): candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs] return candidates, min_address_num, address_num - # TODO:应该也是内部使用的方法? def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): candidates = [] - # TODO:product combinations本身就是迭代器,如果没有其他用途,不用转list,直接放到循环那即可,省去一些时间,下面的同理 - abduce_c = list(product(self.pseudo_label_list, repeat=len(address_idx))) + abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) if multiple_predictions: save_pred_res = pred_res @@ -160,9 +157,9 @@ class KBBase(ABC): def _address(self, address_num, pred_res, key, multiple_predictions): new_candidates = [] if not multiple_predictions: - address_idx_list = list(combinations(list(range(len(pred_res))), address_num)) + address_idx_list = combinations(list(range(len(pred_res))), address_num) else: - address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num)) + address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num) for address_idx in address_idx_list: candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions) @@ -212,10 +209,6 @@ class ClsKB(KBBase): def __init__(self, pseudo_label_list, len_list, GKB_flag): super().__init__(pseudo_label_list, len_list, GKB_flag) - # TODO:这里以及RegKB可以不实现logic_forward吗,这样用户继承后不实现logic_forward就会报错 - def logic_forward(self, pseudo_labels): - pass - def _find_candidate_GKB(self, pred_res, key): return self.base[len(pred_res)][key] @@ -248,10 +241,8 @@ class prolog_KB(KBBase): if multiple_predictions: address_pred_res = flatten(address_pred_res) - # TODO:可以直接对address_idx循环? - for idx in range(len(address_pred_res)): - if idx in address_idx: - address_pred_res[idx] = 'P' + str(idx) + for idx in address_idx: + address_pred_res[idx] = 'P' + str(idx) if multiple_predictions: address_pred_res = reform_idx(address_pred_res, pred_res) @@ -275,8 +266,7 @@ class prolog_KB(KBBase): if multiple_predictions: save_pred_res = pred_res pred_res = flatten(pred_res) - # TODO:这里后面的那个list应该也不需要 - abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))] + abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] for c in abduce_c: candidate = pred_res.copy() for i, idx in enumerate(address_idx): @@ -299,14 +289,11 @@ class prolog_KB(KBBase): rules = [rule.value for rule in prolog_rules] return rules -# TODO:和ClsKB的参数顺序不统一 + class RegKB(KBBase): - def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None, max_err=1e-3): + def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=1e-3): super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) - def logic_forward(self, pseudo_labels): - pass - def _find_candidate_GKB(self, pred_res, key): potential_candidates = self.base[len(pred_res)] key_list = list(potential_candidates.keys()) @@ -331,15 +318,15 @@ class RegKB(KBBase): class HWF_KB(RegKB): def __init__( - self, GKB_flag=False, + self, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7], + GKB_flag=False, max_err=1e-3 ): - super().__init__(GKB_flag, pseudo_label_list, len_list, max_err) + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err) - # TODO:应该是静态方法 - def valid_candidate(self, formula): + def _valid_candidate(self, formula): if len(formula) % 2 == 0: return False for i in range(len(formula)): @@ -350,7 +337,7 @@ class HWF_KB(RegKB): return True def logic_forward(self, formula): - if not self.valid_candidate(formula): + if not self._valid_candidate(formula): return np.inf mapping = { '1': '1', @@ -374,5 +361,4 @@ class HWF_KB(RegKB): import time if __name__ == "__main__": - pass - + pass \ No newline at end of file diff --git a/abl/utils/utils.py b/abl/utils/utils.py index acc7f94..d986065 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -1,16 +1,15 @@ import numpy as np from .plog import INFO from collections import OrderedDict +from itertools import chain -# for multiple predictions, modify from `learn_add.py` +# for multiple predictions def flatten(l): - # return [item for sublist in l for item in flatten(sublist)] if isinstance(l, (list, tuple)) else [l] if not isinstance(l[0], (list, tuple)): return l - # TODO 稍微对比一下和itertools.chain.from_iterable(nested_list)的速度区别,看看哪个好 - return [item for sublist in l for item in sublist] if isinstance(l, (list, tuple)) else [l] + return list(chain.from_iterable(l)) -# for multiple predictions, modify from `learn_add.py` +# for multiple predictions def reform_idx(flatten_pred_res, save_pred_res): re = [] i = 0 From b75b2897b953353ddade8afae51499367a21b0a8 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Tue, 7 Mar 2023 00:27:49 +0000 Subject: [PATCH 31/31] Rearrange address_by_idx to abducer_base --- abl/abducer/abducer_base.py | 10 ++++++---- abl/framework_hed.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index bd76831..cf14bb7 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -64,7 +64,7 @@ class AbducerBase(abc.ABC): score = 0 for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = self.kb.address_by_idx([pred_res[idx]], key[idx], address_idx, True) + candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx) if len(candidate) > 0: score += 1 return score @@ -72,7 +72,7 @@ class AbducerBase(abc.ABC): def _zoopt_address_score(self, pred_res, key, sol): if not self.multiple_predictions: address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.address_by_idx(pred_res, key, address_idx) return 1 if len(candidates) > 0 else 0 else: return self._zoopt_score_multiple(pred_res, key, sol.get_x()) @@ -115,6 +115,9 @@ class AbducerBase(abc.ABC): key = tuple(key) self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates + + def address_by_idx(self, pred_res, key, address_idx): + return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) def abduce(self, data, max_address_num=-1, require_more_address=0): pred_res, pred_res_prob, key = data @@ -129,7 +132,7 @@ class AbducerBase(abc.ABC): if self.zoopt: solution = self.zoopt_get_solution(pred_res, key, max_address_num) address_idx = [idx for idx, i in enumerate(solution) if i != 0] - candidates = self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions) + candidates = self.address_by_idx(pred_res, key, address_idx) address_num = int(solution.sum()) min_address_num = address_num else: @@ -156,7 +159,6 @@ class AbducerBase(abc.ABC): def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address_num, require_more_address) - if __name__ == '__main__': from kb import add_KB, prolog_KB, HWF_KB diff --git a/abl/framework_hed.py b/abl/framework_hed.py index a823b13..3f09ff6 100644 --- a/abl/framework_hed.py +++ b/abl/framework_hed.py @@ -158,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.kb.address_by_idx([pred_res[idx]], None, address_idx, True) + candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx) if len(candidate) > 0: consistent_idx_tmp.append(idx) consistent_pred_res_tmp.append(candidate[0][0])