From a580b0dfd506a7b4ae51e739e8c05ec09b3bf623 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Fri, 31 Mar 2023 21:58:46 +0800 Subject: [PATCH] [MNT] Move specific kb and abducer to example --- abl/abducer/abducer_base.py | 155 ++++++++++++++------- abl/abducer/kb.py | 58 -------- examples/hed/hed_example.ipynb | 74 +++++++++- examples/hwf/hwf_example.ipynb | 32 ++++- examples/mnist_add/mnist_add_example.ipynb | 11 +- 5 files changed, 217 insertions(+), 113 deletions(-) diff --git a/abl/abducer/abducer_base.py b/abl/abducer/abducer_base.py index 6a7ec86..a117cc5 100644 --- a/abl/abducer/abducer_base.py +++ b/abl/abducer/abducer_base.py @@ -228,63 +228,21 @@ class AbducerBase(abc.ABC): def __call__(self, Z, Y, max_address=-1, require_more_address=0): return self.batch_abduce(Z, Y, max_address, require_more_address) -class HED_Abducer(AbducerBase): - def __init__(self, kb, dist_func='hamming'): - super().__init__(kb, dist_func, zoopt=True) - - def _address_by_idxs(self, pred_res, key, all_address_flag, idxs): - pred = [] - k = [] - address_flag = [] - for idx in idxs: - pred.append(pred_res[idx]) - k.append(key[idx]) - address_flag += list(all_address_flag[idx]) - address_idx = np.where(np.array(address_flag) != 0)[0] - candidate = self.address_by_idx(pred, k, address_idx) - return candidate - - def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): - all_address_flag = reform_idx(sol.get_x(), pred_res) - lefted_idxs = [i for i in range(len(pred_res))] - candidate_size = [] - while lefted_idxs: - idxs = [] - idxs.append(lefted_idxs.pop(0)) - max_candidate_idxs = [] - found = False - for idx in range(-1, len(pred_res)): - if (not idx in idxs) and (idx >= 0): - idxs.append(idx) - candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs) - if len(candidate) == 0: - if len(idxs) > 1: - idxs.pop() - else: - if len(idxs) > len(max_candidate_idxs): - found = True - max_candidate_idxs = idxs.copy() - removed = [i for i in lefted_idxs if i in max_candidate_idxs] - if found: - candidate_size.append(len(removed) + 1) - lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] - candidate_size.sort() - score = 0 - import math - for i in range(0, len(candidate_size)): - score -= math.exp(-i) * candidate_size[i] - return score - - def abduce_rules(self, pred_res): - return self.kb.abduce_rules(pred_res) - if __name__ == '__main__': - from kb import add_KB, prolog_KB, HWF_KB, HED_prolog_KB + from kb import KBBase, prolog_KB prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] + class add_KB(KBBase): + def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True): + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) + + def logic_forward(self, nums): + return sum(nums) + + print('add_KB with GKB:') kb = add_KB(GKB_flag=True) abd = AbducerBase(kb, 'confidence') @@ -372,6 +330,35 @@ if __name__ == '__main__': print(res) print() + class HWF_KB(KBBase): + def __init__( + self, + pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], + len_list=[1, 3, 5, 7], + GKB_flag=False, + max_err=1e-3, + use_cache=True + ): + super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) + + def _valid_candidate(self, formula): + if len(formula) % 2 == 0: + return False + for i in range(len(formula)): + if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: + return False + if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: + return False + return True + + def logic_forward(self, formula): + if not self._valid_candidate(formula): + return np.inf + mapping = {str(i): str(i) for i in range(1, 10)} + mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'}) + formula = [mapping[f] for f in formula] + return eval(''.join(formula)) + print('HWF_KB with GKB, max_err=0.1') kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) abd = AbducerBase(kb, 'hamming') @@ -432,6 +419,72 @@ if __name__ == '__main__': res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.9, require_more_address=0) print(res) print() + + class HED_prolog_KB(prolog_KB): + def __init__(self, pseudo_label_list, pl_file): + super().__init__(pseudo_label_list, pl_file) + + def consist_rule(self, exs, rules): + rules = str(rules).replace("\'","") + return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 + + def abduce_rules(self, pred_res): + prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) + if len(prolog_result) == 0: + return None + prolog_rules = prolog_result[0]['X'] + rules = [rule.value for rule in prolog_rules] + return rules + + class HED_Abducer(AbducerBase): + def __init__(self, kb, dist_func='hamming'): + super().__init__(kb, dist_func, zoopt=True) + + def _address_by_idxs(self, pred_res, key, all_address_flag, idxs): + pred = [] + k = [] + address_flag = [] + for idx in idxs: + pred.append(pred_res[idx]) + k.append(key[idx]) + address_flag += list(all_address_flag[idx]) + address_idx = np.where(np.array(address_flag) != 0)[0] + candidate = self.address_by_idx(pred, k, address_idx) + return candidate + + def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): + all_address_flag = reform_idx(sol.get_x(), pred_res) + lefted_idxs = [i for i in range(len(pred_res))] + candidate_size = [] + while lefted_idxs: + idxs = [] + idxs.append(lefted_idxs.pop(0)) + max_candidate_idxs = [] + found = False + for idx in range(-1, len(pred_res)): + if (not idx in idxs) and (idx >= 0): + idxs.append(idx) + candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs) + if len(candidate) == 0: + if len(idxs) > 1: + idxs.pop() + else: + if len(idxs) > len(max_candidate_idxs): + found = True + max_candidate_idxs = idxs.copy() + removed = [i for i in lefted_idxs if i in max_candidate_idxs] + if found: + candidate_size.append(len(removed) + 1) + lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] + candidate_size.sort() + score = 0 + import math + for i in range(0, len(candidate_size)): + score -= math.exp(-i) * candidate_size[i] + return score + + def abduce_rules(self, pred_res): + return self.kb.abduce_rules(pred_res) kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') abd = HED_Abducer(kb) diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index 1db41cb..76a169d 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -189,13 +189,6 @@ class KBBase(ABC): else: return sum(self._dict_len(v) for v in self.base.values()) -class add_KB(KBBase): - def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True): - super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) - - def logic_forward(self, nums): - return sum(nums) - class prolog_KB(KBBase): def __init__(self, pseudo_label_list, pl_file): @@ -244,54 +237,3 @@ class prolog_KB(KBBase): candidate = reform_idx(candidate, save_pred_res) candidates.append(candidate) return candidates - - -class HED_prolog_KB(prolog_KB): - def __init__(self, pseudo_label_list, pl_file): - super().__init__(pseudo_label_list, pl_file) - - def consist_rule(self, exs, rules): - rules = str(rules).replace("\'","") - return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 - - def abduce_rules(self, pred_res): - prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) - if len(prolog_result) == 0: - return None - prolog_rules = prolog_result[0]['X'] - rules = [rule.value for rule in prolog_rules] - return rules - - -class HWF_KB(KBBase): - def __init__( - self, - pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], - len_list=[1, 3, 5, 7], - GKB_flag=False, - max_err=1e-3, - use_cache=True - ): - super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache) - - def _valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']: - return False - if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']: - return False - return True - - def logic_forward(self, formula): - if not self._valid_candidate(formula): - return np.inf - mapping = {str(i): str(i) for i in range(1, 10)} - mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'}) - formula = [mapping[f] for f in formula] - return eval(''.join(formula)) - - -if __name__ == "__main__": - pass \ No newline at end of file diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index 922cc96..0a126d8 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -10,15 +10,17 @@ "\n", "sys.path.append(\"../../\")\n", "\n", + "import numpy as np\n", "import torch.nn as nn\n", "import torch\n", "\n", - "from abl.abducer.abducer_base import HED_Abducer\n", - "from abl.abducer.kb import HED_prolog_KB\n", + "from abl.abducer.abducer_base import AbducerBase\n", + "from abl.abducer.kb import prolog_KB\n", "\n", "from abl.utils.plog import logger\n", "from abl.models.basic_model import BasicModel\n", "from abl.models.wabl_models import WABLBasicModel\n", + "from abl.utils.utils import reform_idx\n", "\n", "from models.nn import SymbolNet\n", "from datasets.get_hed import get_hed, split_equation\n", @@ -58,7 +60,75 @@ ], "source": [ "# Initialize knowledge base and abducer\n", + "class HED_prolog_KB(prolog_KB):\n", + " def __init__(self, pseudo_label_list, pl_file):\n", + " super().__init__(pseudo_label_list, pl_file)\n", + " \n", + " def consist_rule(self, exs, rules):\n", + " rules = str(rules).replace(\"\\'\",\"\")\n", + " return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n", + "\n", + " def abduce_rules(self, pred_res):\n", + " prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n", + " if len(prolog_result) == 0:\n", + " return None\n", + " prolog_rules = prolog_result[0]['X']\n", + " rules = [rule.value for rule in prolog_rules]\n", + " return rules\n", + " \n", + " \n", "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n", + "\n", + "class HED_Abducer(AbducerBase):\n", + " def __init__(self, kb, dist_func='hamming'):\n", + " super().__init__(kb, dist_func, zoopt=True)\n", + " \n", + " def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", + " pred = []\n", + " k = []\n", + " address_flag = []\n", + " for idx in idxs:\n", + " pred.append(pred_res[idx])\n", + " k.append(key[idx])\n", + " address_flag += list(all_address_flag[idx])\n", + " address_idx = np.where(np.array(address_flag) != 0)[0] \n", + " candidate = self.address_by_idx(pred, k, address_idx)\n", + " return candidate\n", + " \n", + " def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n", + " all_address_flag = reform_idx(sol.get_x(), pred_res)\n", + " lefted_idxs = [i for i in range(len(pred_res))]\n", + " candidate_size = [] \n", + " while lefted_idxs:\n", + " idxs = []\n", + " idxs.append(lefted_idxs.pop(0))\n", + " max_candidate_idxs = []\n", + " found = False\n", + " for idx in range(-1, len(pred_res)):\n", + " if (not idx in idxs) and (idx >= 0):\n", + " idxs.append(idx)\n", + " candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n", + " if len(candidate) == 0:\n", + " if len(idxs) > 1:\n", + " idxs.pop()\n", + " else:\n", + " if len(idxs) > len(max_candidate_idxs):\n", + " found = True\n", + " max_candidate_idxs = idxs.copy() \n", + " removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", + " if found:\n", + " candidate_size.append(len(removed) + 1)\n", + " lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] \n", + " candidate_size.sort()\n", + " score = 0\n", + " import math\n", + " for i in range(0, len(candidate_size)):\n", + " score -= math.exp(-i) * candidate_size[i]\n", + " return score\n", + "\n", + " def abduce_rules(self, pred_res):\n", + " return self.kb.abduce_rules(pred_res)\n", + " \n", "abducer = HED_Abducer(kb)" ] }, diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index 98f46a6..80eb05a 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -10,11 +10,12 @@ "\n", "sys.path.append(\"../../\")\n", "\n", + "import numpy as np\n", "import torch.nn as nn\n", "import torch\n", "\n", "from abl.abducer.abducer_base import AbducerBase\n", - "from abl.abducer.kb import HWF_KB\n", + "from abl.abducer.kb import KBBase\n", "\n", "from abl.utils.plog import logger\n", "from abl.models.basic_model import BasicModel\n", @@ -50,6 +51,35 @@ "outputs": [], "source": [ "# Initialize knowledge base and abducer\n", + "class HWF_KB(KBBase):\n", + " def __init__(\n", + " self, \n", + " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", + " len_list=[1, 3, 5, 7],\n", + " GKB_flag=False,\n", + " max_err=1e-3,\n", + " use_cache=True\n", + " ):\n", + " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", + "\n", + " def _valid_candidate(self, formula):\n", + " if len(formula) % 2 == 0:\n", + " return False\n", + " for i in range(len(formula)):\n", + " if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", + " return False\n", + " if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", + " return False\n", + " return True\n", + "\n", + " def logic_forward(self, formula):\n", + " if not self._valid_candidate(formula):\n", + " return np.inf\n", + " mapping = {str(i): str(i) for i in range(1, 10)}\n", + " mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", + " formula = [mapping[f] for f in formula]\n", + " return eval(''.join(formula))\n", + "\n", "kb = HWF_KB(GKB_flag=True)\n", "abducer = AbducerBase(kb)" ] diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 0388622..83b52b4 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -14,7 +14,7 @@ "import torch\n", "\n", "from abl.abducer.abducer_base import AbducerBase\n", - "from abl.abducer.kb import add_KB\n", + "from abl.abducer.kb import KBBase, prolog_KB\n", "\n", "from abl.utils.plog import logger\n", "from abl.models.basic_model import BasicModel\n", @@ -50,7 +50,16 @@ "outputs": [], "source": [ "# Initialize knowledge base and abducer\n", + "class add_KB(KBBase):\n", + " def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n", + " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", + "\n", + " def logic_forward(self, nums):\n", + " return sum(nums)\n", + "\n", "kb = add_KB(GKB_flag=True)\n", + "\n", + "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", "abducer = AbducerBase(kb, dist_func=\"confidence\")" ] },