| @@ -228,63 +228,21 @@ class AbducerBase(abc.ABC): | |||
def __call__(self, Z, Y, max_address=-1, require_more_address=0):
    """Make the abducer callable by forwarding to ``batch_abduce``.

    All four arguments are passed through positionally and unchanged, and
    whatever ``batch_abduce`` returns is handed back as-is.
    """
    abduced = self.batch_abduce(Z, Y, max_address, require_more_address)
    return abduced
class HED_Abducer(AbducerBase):
    """Abducer that configures ``AbducerBase`` with ``zoopt=True``.

    It supplies the score used to rate a zoopt solution and exposes the
    knowledge base's rule abduction.
    """

    def __init__(self, kb, dist_func='hamming'):
        """Initialise the base abducer with ``kb``, ``dist_func`` and ``zoopt=True``."""
        super().__init__(kb, dist_func, zoopt=True)

    def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Run ``address_by_idx`` on the examples selected by ``idxs``.

        The predictions and keys of the chosen examples are gathered in the
        order given by ``idxs``. Their per-example flags are concatenated and
        the positions holding a non-zero flag are passed on.

        Returns:
            Whatever ``self.address_by_idx`` returns for the selection.
        """
        chosen_pred = [pred_res[i] for i in idxs]
        chosen_key = [key[i] for i in idxs]
        flat_flags = []
        for i in idxs:
            flat_flags.extend(all_address_flag[i])
        flagged_positions = np.where(np.array(flat_flags) != 0)[0]
        return self.address_by_idx(chosen_pred, chosen_key, flagged_positions)

    def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
        """Score a zoopt solution by how the examples group together.

        Starting from the first example not yet grouped, every example index
        is tried in order and kept in the group only if ``_address_by_idxs``
        still returns a non-empty result for the enlarged group. The number
        of previously ungrouped examples each successful group absorbs
        (seed included) is recorded.

        ``pred_res_prob`` is accepted but not used here.

        Returns:
            ``0`` when no group was found, otherwise the negated sum of the
            recorded sizes sorted ascending, each weighted by ``exp(-rank)``.
        """
        import math

        flags = reform_idx(sol.get_x(), pred_res)
        total = len(pred_res)
        remaining = list(range(total))
        group_sizes = []

        while remaining:
            group = [remaining.pop(0)]
            best_group = []
            found = False
            # idx == -1 adds nothing: that pass evaluates the seed on its own.
            for idx in range(-1, total):
                if idx >= 0 and idx not in group:
                    group.append(idx)
                candidate = self._address_by_idxs(pred_res, key, flags, group)
                if len(candidate) != 0:
                    if len(group) > len(best_group):
                        found = True
                        best_group = group.copy()
                elif len(group) > 1:
                    # The latest addition broke the group; drop it again.
                    group.pop()
            if found:
                absorbed = [i for i in remaining if i in best_group]
                # +1 accounts for the seed already popped from ``remaining``.
                group_sizes.append(len(absorbed) + 1)
                remaining = [i for i in remaining if i not in best_group]

        group_sizes.sort()
        score = 0
        for rank, size in enumerate(group_sizes):
            score -= math.exp(-rank) * size
        return score

    def abduce_rules(self, pred_res):
        """Delegate rule abduction for ``pred_res`` to the knowledge base."""
        return self.kb.abduce_rules(pred_res)
| if __name__ == '__main__': | |||
| from kb import add_KB, prolog_KB, HWF_KB, HED_prolog_KB | |||
| from kb import KBBase, prolog_KB | |||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
class add_KB(KBBase):
    """Knowledge base whose logic result is the sum of its inputs.

    Args:
        pseudo_label_list: Forwarded to ``KBBase``; defaults to the integers
            0 through 9.
        len_list: Forwarded to ``KBBase``; defaults to ``[2]``.
        GKB_flag: Forwarded unchanged to ``KBBase``.
        max_err: Forwarded unchanged to ``KBBase``.
        use_cache: Forwarded unchanged to ``KBBase``.
    """

    def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0, use_cache=True):
        # Build the defaults per call: a list in the signature is created once
        # and would be shared (and mutable) across every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if len_list is None:
            len_list = [2]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
| print('add_KB with GKB:') | |||
| kb = add_KB(GKB_flag=True) | |||
| abd = AbducerBase(kb, 'confidence') | |||
| @@ -372,6 +330,35 @@ if __name__ == '__main__': | |||
| print(res) | |||
| print() | |||
class HWF_KB(KBBase):
    """Knowledge base that evaluates a formula given as a list of tokens.

    A valid formula has odd length, with a digit token (``'1'``-``'9'``) at
    every even position and an operator token (``'+'``, ``'-'``, ``'times'``,
    ``'div'``) at every odd position.

    Args:
        pseudo_label_list: Forwarded to ``KBBase``; defaults to the nine digit
            tokens followed by the four operator tokens.
        len_list: Forwarded to ``KBBase``; defaults to ``[1, 3, 5, 7]``.
        GKB_flag: Forwarded unchanged to ``KBBase``.
        max_err: Forwarded unchanged to ``KBBase``.
        use_cache: Forwarded unchanged to ``KBBase``.
    """

    # Tuples rather than sets so membership keeps using ``==`` comparison and
    # never requires the tokens to be hashable.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Operator token -> Python operator; digit tokens map to themselves.
    _PYTHON_OPERATOR = {'+': '+', '-': '-', 'times': '*', 'div': '/'}

    def __init__(
        self,
        pseudo_label_list=None,
        len_list=None,
        GKB_flag=False,
        max_err=1e-3,
        use_cache=True
    ):
        # Build the defaults per call: a list in the signature is created once
        # and would be shared (and mutable) across every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` alternates digit and operator tokens.

        Even-length formulas (including the empty one) are rejected.
        """
        if len(formula) % 2 == 0:
            return False
        for position, token in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if token not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula`` and return its value.

        Returns:
            ``np.inf`` when the formula is not valid, otherwise the result of
            evaluating it as a Python arithmetic expression (``'div'`` maps
            to ``/``, so the result may be a float).
        """
        if not self._valid_candidate(formula):
            return np.inf
        expression = ''.join(self._PYTHON_OPERATOR.get(token, token) for token in formula)
        # NOTE: eval is only reached after _valid_candidate has restricted
        # every token to the fixed digit/operator whitelist. Do not call it on
        # unvalidated input.
        return eval(expression)
| print('HWF_KB with GKB, max_err=0.1') | |||
| kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) | |||
| abd = AbducerBase(kb, 'hamming') | |||
| @@ -432,6 +419,72 @@ if __name__ == '__main__': | |||
| res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.9, require_more_address=0) | |||
| print(res) | |||
| print() | |||
class HED_prolog_KB(prolog_KB):
    """Prolog knowledge base with two extra rule-related queries.

    ``consist_rule`` queries ``eval_inst_feature`` and ``abduce_rules``
    queries ``consistent_inst_feature``. Both predicates are presumably
    defined by the loaded Prolog file, which is not visible here.
    """

    def __init__(self, pseudo_label_list, pl_file):
        """Forward both arguments unchanged to ``prolog_KB``."""
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules)`` has a solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being interpolated into the query text.
        """
        unquoted_rules = str(rules).replace("\'", "")
        query_text = "eval_inst_feature(%s, %s)." % (exs, unquoted_rules)
        # list() drains the query completely before the result is inspected.
        solutions = list(self.prolog.query(query_text))
        return bool(solutions)

    def abduce_rules(self, pred_res):
        """Return rules from the first ``consistent_inst_feature`` solution.

        Returns:
            ``None`` when the query has no solution, otherwise a list with the
            ``.value`` of each item bound to ``X`` in the first solution.
        """
        query_text = "consistent_inst_feature(%s, X)." % pred_res
        solutions = list(self.prolog.query(query_text))
        if not solutions:
            return None
        first_binding = solutions[0]['X']
        return [rule.value for rule in first_binding]
class HED_Abducer(AbducerBase):
    """Abducer that configures ``AbducerBase`` with ``zoopt=True``.

    It supplies the score used to rate a zoopt solution and exposes the
    knowledge base's rule abduction.
    """

    def __init__(self, kb, dist_func='hamming'):
        """Initialise the base abducer with ``kb``, ``dist_func`` and ``zoopt=True``."""
        super().__init__(kb, dist_func, zoopt=True)

    def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Run ``address_by_idx`` on the examples selected by ``idxs``.

        The predictions and keys of the chosen examples are gathered in the
        order given by ``idxs``. Their per-example flags are concatenated and
        the positions holding a non-zero flag are passed on.

        Returns:
            Whatever ``self.address_by_idx`` returns for the selection.
        """
        chosen_pred = [pred_res[i] for i in idxs]
        chosen_key = [key[i] for i in idxs]
        flat_flags = []
        for i in idxs:
            flat_flags.extend(all_address_flag[i])
        flagged_positions = np.where(np.array(flat_flags) != 0)[0]
        return self.address_by_idx(chosen_pred, chosen_key, flagged_positions)

    def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
        """Score a zoopt solution by how the examples group together.

        Starting from the first example not yet grouped, every example index
        is tried in order and kept in the group only if ``_address_by_idxs``
        still returns a non-empty result for the enlarged group. The number
        of previously ungrouped examples each successful group absorbs
        (seed included) is recorded.

        ``pred_res_prob`` is accepted but not used here.

        Returns:
            ``0`` when no group was found, otherwise the negated sum of the
            recorded sizes sorted ascending, each weighted by ``exp(-rank)``.
        """
        import math

        flags = reform_idx(sol.get_x(), pred_res)
        total = len(pred_res)
        remaining = list(range(total))
        group_sizes = []

        while remaining:
            group = [remaining.pop(0)]
            best_group = []
            found = False
            # idx == -1 adds nothing: that pass evaluates the seed on its own.
            for idx in range(-1, total):
                if idx >= 0 and idx not in group:
                    group.append(idx)
                candidate = self._address_by_idxs(pred_res, key, flags, group)
                if len(candidate) != 0:
                    if len(group) > len(best_group):
                        found = True
                        best_group = group.copy()
                elif len(group) > 1:
                    # The latest addition broke the group; drop it again.
                    group.pop()
            if found:
                absorbed = [i for i in remaining if i in best_group]
                # +1 accounts for the seed already popped from ``remaining``.
                group_sizes.append(len(absorbed) + 1)
                remaining = [i for i in remaining if i not in best_group]

        group_sizes.sort()
        score = 0
        for rank, size in enumerate(group_sizes):
            score -= math.exp(-rank) * size
        return score

    def abduce_rules(self, pred_res):
        """Delegate rule abduction for ``pred_res`` to the knowledge base."""
        return self.kb.abduce_rules(pred_res)
| kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||
| abd = HED_Abducer(kb) | |||
| @@ -189,13 +189,6 @@ class KBBase(ABC): | |||
| else: | |||
| return sum(self._dict_len(v) for v in self.base.values()) | |||
class add_KB(KBBase):
    """Knowledge base whose logic result is the sum of its inputs.

    Args:
        pseudo_label_list: Forwarded to ``KBBase``; defaults to the integers
            0 through 9.
        len_list: Forwarded to ``KBBase``; defaults to ``[2]``.
        GKB_flag: Forwarded unchanged to ``KBBase``.
        max_err: Forwarded unchanged to ``KBBase``.
        use_cache: Forwarded unchanged to ``KBBase``.
    """

    def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0, use_cache=True):
        # Build the defaults per call: a list in the signature is created once
        # and would be shared (and mutable) across every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if len_list is None:
            len_list = [2]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
| class prolog_KB(KBBase): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| @@ -244,54 +237,3 @@ class prolog_KB(KBBase): | |||
| candidate = reform_idx(candidate, save_pred_res) | |||
| candidates.append(candidate) | |||
| return candidates | |||
class HED_prolog_KB(prolog_KB):
    """Prolog knowledge base with two extra rule-related queries.

    ``consist_rule`` queries ``eval_inst_feature`` and ``abduce_rules``
    queries ``consistent_inst_feature``. Both predicates are presumably
    defined by the loaded Prolog file, which is not visible here.
    """

    def __init__(self, pseudo_label_list, pl_file):
        """Forward both arguments unchanged to ``prolog_KB``."""
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules)`` has a solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being interpolated into the query text.
        """
        unquoted_rules = str(rules).replace("\'", "")
        query_text = "eval_inst_feature(%s, %s)." % (exs, unquoted_rules)
        # list() drains the query completely before the result is inspected.
        solutions = list(self.prolog.query(query_text))
        return bool(solutions)

    def abduce_rules(self, pred_res):
        """Return rules from the first ``consistent_inst_feature`` solution.

        Returns:
            ``None`` when the query has no solution, otherwise a list with the
            ``.value`` of each item bound to ``X`` in the first solution.
        """
        query_text = "consistent_inst_feature(%s, X)." % pred_res
        solutions = list(self.prolog.query(query_text))
        if not solutions:
            return None
        first_binding = solutions[0]['X']
        return [rule.value for rule in first_binding]
class HWF_KB(KBBase):
    """Knowledge base that evaluates a formula given as a list of tokens.

    A valid formula has odd length, with a digit token (``'1'``-``'9'``) at
    every even position and an operator token (``'+'``, ``'-'``, ``'times'``,
    ``'div'``) at every odd position.

    Args:
        pseudo_label_list: Forwarded to ``KBBase``; defaults to the nine digit
            tokens followed by the four operator tokens.
        len_list: Forwarded to ``KBBase``; defaults to ``[1, 3, 5, 7]``.
        GKB_flag: Forwarded unchanged to ``KBBase``.
        max_err: Forwarded unchanged to ``KBBase``.
        use_cache: Forwarded unchanged to ``KBBase``.
    """

    # Tuples rather than sets so membership keeps using ``==`` comparison and
    # never requires the tokens to be hashable.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    _OPERATORS = ('+', '-', 'times', 'div')
    # Operator token -> Python operator; digit tokens map to themselves.
    _PYTHON_OPERATOR = {'+': '+', '-': '-', 'times': '*', 'div': '/'}

    def __init__(
        self,
        pseudo_label_list=None,
        len_list=None,
        GKB_flag=False,
        max_err=1e-3,
        use_cache=True
    ):
        # Build the defaults per call: a list in the signature is created once
        # and would be shared (and mutable) across every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` alternates digit and operator tokens.

        Even-length formulas (including the empty one) are rejected.
        """
        if len(formula) % 2 == 0:
            return False
        for position, token in enumerate(formula):
            allowed = self._DIGITS if position % 2 == 0 else self._OPERATORS
            if token not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula`` and return its value.

        Returns:
            ``np.inf`` when the formula is not valid, otherwise the result of
            evaluating it as a Python arithmetic expression (``'div'`` maps
            to ``/``, so the result may be a float).
        """
        if not self._valid_candidate(formula):
            return np.inf
        expression = ''.join(self._PYTHON_OPERATOR.get(token, token) for token in formula)
        # NOTE: eval is only reached after _valid_candidate has restricted
        # every token to the fixed digit/operator whitelist. Do not call it on
        # unvalidated input.
        return eval(expression)
if __name__ == "__main__":
    # Nothing runs when this module is executed as a script.
    pass
| @@ -10,15 +10,17 @@ | |||
| "\n", | |||
| "sys.path.append(\"../../\")\n", | |||
| "\n", | |||
| "import numpy as np\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import HED_Abducer\n", | |||
| "from abl.abducer.kb import HED_prolog_KB\n", | |||
| "from abl.abducer.abducer_base import AbducerBase\n", | |||
| "from abl.abducer.kb import prolog_KB\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| "from abl.models.wabl_models import WABLBasicModel\n", | |||
| "from abl.utils.utils import reform_idx\n", | |||
| "\n", | |||
| "from models.nn import SymbolNet\n", | |||
| "from datasets.get_hed import get_hed, split_equation\n", | |||
| @@ -58,7 +60,75 @@ | |||
| ], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "class HED_prolog_KB(prolog_KB):\n", | |||
| " def __init__(self, pseudo_label_list, pl_file):\n", | |||
| " super().__init__(pseudo_label_list, pl_file)\n", | |||
| " \n", | |||
| " def consist_rule(self, exs, rules):\n", | |||
| " rules = str(rules).replace(\"\\'\",\"\")\n", | |||
| " return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n", | |||
| "\n", | |||
| " def abduce_rules(self, pred_res):\n", | |||
| " prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n", | |||
| " if len(prolog_result) == 0:\n", | |||
| " return None\n", | |||
| " prolog_rules = prolog_result[0]['X']\n", | |||
| " rules = [rule.value for rule in prolog_rules]\n", | |||
| " return rules\n", | |||
| " \n", | |||
| " \n", | |||
| "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n", | |||
| "\n", | |||
| "class HED_Abducer(AbducerBase):\n", | |||
| " def __init__(self, kb, dist_func='hamming'):\n", | |||
| " super().__init__(kb, dist_func, zoopt=True)\n", | |||
| " \n", | |||
| " def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", | |||
| " pred = []\n", | |||
| " k = []\n", | |||
| " address_flag = []\n", | |||
| " for idx in idxs:\n", | |||
| " pred.append(pred_res[idx])\n", | |||
| " k.append(key[idx])\n", | |||
| " address_flag += list(all_address_flag[idx])\n", | |||
| " address_idx = np.where(np.array(address_flag) != 0)[0] \n", | |||
| " candidate = self.address_by_idx(pred, k, address_idx)\n", | |||
| " return candidate\n", | |||
| " \n", | |||
| " def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n", | |||
| " all_address_flag = reform_idx(sol.get_x(), pred_res)\n", | |||
| " lefted_idxs = [i for i in range(len(pred_res))]\n", | |||
| " candidate_size = [] \n", | |||
| " while lefted_idxs:\n", | |||
| " idxs = []\n", | |||
| " idxs.append(lefted_idxs.pop(0))\n", | |||
| " max_candidate_idxs = []\n", | |||
| " found = False\n", | |||
| " for idx in range(-1, len(pred_res)):\n", | |||
| " if (not idx in idxs) and (idx >= 0):\n", | |||
| " idxs.append(idx)\n", | |||
| " candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n", | |||
| " if len(candidate) == 0:\n", | |||
| " if len(idxs) > 1:\n", | |||
| " idxs.pop()\n", | |||
| " else:\n", | |||
| " if len(idxs) > len(max_candidate_idxs):\n", | |||
| " found = True\n", | |||
| " max_candidate_idxs = idxs.copy() \n", | |||
| " removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", | |||
| " if found:\n", | |||
| " candidate_size.append(len(removed) + 1)\n", | |||
| " lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] \n", | |||
| " candidate_size.sort()\n", | |||
| " score = 0\n", | |||
| " import math\n", | |||
| " for i in range(0, len(candidate_size)):\n", | |||
| " score -= math.exp(-i) * candidate_size[i]\n", | |||
| " return score\n", | |||
| "\n", | |||
| " def abduce_rules(self, pred_res):\n", | |||
| " return self.kb.abduce_rules(pred_res)\n", | |||
| " \n", | |||
| "abducer = HED_Abducer(kb)" | |||
| ] | |||
| }, | |||
| @@ -10,11 +10,12 @@ | |||
| "\n", | |||
| "sys.path.append(\"../../\")\n", | |||
| "\n", | |||
| "import numpy as np\n", | |||
| "import torch.nn as nn\n", | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import AbducerBase\n", | |||
| "from abl.abducer.kb import HWF_KB\n", | |||
| "from abl.abducer.kb import KBBase\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| @@ -50,6 +51,35 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "class HWF_KB(KBBase):\n", | |||
| " def __init__(\n", | |||
| " self, \n", | |||
| " pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n", | |||
| " len_list=[1, 3, 5, 7],\n", | |||
| " GKB_flag=False,\n", | |||
| " max_err=1e-3,\n", | |||
| " use_cache=True\n", | |||
| " ):\n", | |||
| " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", | |||
| "\n", | |||
| " def _valid_candidate(self, formula):\n", | |||
| " if len(formula) % 2 == 0:\n", | |||
| " return False\n", | |||
| " for i in range(len(formula)):\n", | |||
| " if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n", | |||
| " return False\n", | |||
| " if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n", | |||
| " return False\n", | |||
| " return True\n", | |||
| "\n", | |||
| " def logic_forward(self, formula):\n", | |||
| " if not self._valid_candidate(formula):\n", | |||
| " return np.inf\n", | |||
| " mapping = {str(i): str(i) for i in range(1, 10)}\n", | |||
| " mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n", | |||
| " formula = [mapping[f] for f in formula]\n", | |||
| " return eval(''.join(formula))\n", | |||
| "\n", | |||
| "kb = HWF_KB(GKB_flag=True)\n", | |||
| "abducer = AbducerBase(kb)" | |||
| ] | |||
| @@ -14,7 +14,7 @@ | |||
| "import torch\n", | |||
| "\n", | |||
| "from abl.abducer.abducer_base import AbducerBase\n", | |||
| "from abl.abducer.kb import add_KB\n", | |||
| "from abl.abducer.kb import KBBase, prolog_KB\n", | |||
| "\n", | |||
| "from abl.utils.plog import logger\n", | |||
| "from abl.models.basic_model import BasicModel\n", | |||
| @@ -50,7 +50,16 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Initialize knowledge base and abducer\n", | |||
| "class add_KB(KBBase):\n", | |||
| " def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n", | |||
| " super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n", | |||
| "\n", | |||
| " def logic_forward(self, nums):\n", | |||
| " return sum(nums)\n", | |||
| "\n", | |||
| "kb = add_KB(GKB_flag=True)\n", | |||
| "\n", | |||
| "# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n", | |||
| "abducer = AbducerBase(kb, dist_func=\"confidence\")" | |||
| ] | |||
| }, | |||