Browse Source

[MNT] Move specific kb and abducer to example

pull/3/head
troyyyyy 2 years ago
parent
commit
a580b0dfd5
5 changed files with 217 additions and 113 deletions
  1. +104
    -51
      abl/abducer/abducer_base.py
  2. +0
    -58
      abl/abducer/kb.py
  3. +72
    -2
      examples/hed/hed_example.ipynb
  4. +31
    -1
      examples/hwf/hwf_example.ipynb
  5. +10
    -1
      examples/mnist_add/mnist_add_example.ipynb

+ 104
- 51
abl/abducer/abducer_base.py View File

@@ -228,63 +228,21 @@ class AbducerBase(abc.ABC):
def __call__(self, Z, Y, max_address=-1, require_more_address=0):
    """Make the abducer callable by delegating to ``batch_abduce``.

    All four arguments are forwarded positionally and unchanged; the
    return value is whatever ``batch_abduce`` returns.
    """
    abduced = self.batch_abduce(Z, Y, max_address, require_more_address)
    return abduced
class HED_Abducer(AbducerBase):
    """Abducer variant for the HED example.

    Always initialises the base abducer with ``zoopt=True`` and supplies
    its own ``zoopt_address_score`` plus rule abduction through the
    knowledge base.
    """

    def __init__(self, kb, dist_func='hamming'):
        super().__init__(kb, dist_func, zoopt=True)

    def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Jointly address the examples selected by ``idxs``.

        Collects the predictions, keys and address flags of the chosen
        examples, flattens the flags into one list, and passes the
        positions whose flag is non-zero to ``self.address_by_idx``.
        """
        chosen_pred = [pred_res[i] for i in idxs]
        chosen_key = [key[i] for i in idxs]
        flat_flags = []
        for i in idxs:
            flat_flags.extend(all_address_flag[i])
        address_idx = np.where(np.array(flat_flags) != 0)[0]
        return self.address_by_idx(chosen_pred, chosen_key, address_idx)

    def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
        """Score a ZOOpt solution by how the examples can be grouped.

        The flat solution vector ``sol.get_x()`` is reshaped to mirror
        ``pred_res``. Examples are then greedily collected into groups
        for which ``_address_by_idxs`` returns a non-empty candidate.
        The returned score is the negated sum of the ascending-sorted
        group sizes, each weighted by ``exp(-rank)``.

        ``pred_res_prob`` is accepted but not used here.
        """
        import math

        all_address_flag = reform_idx(sol.get_x(), pred_res)
        pending = list(range(len(pred_res)))
        group_sizes = []
        while pending:
            # Seed a new group with the first still-unassigned example.
            group = [pending.pop(0)]
            best_group = []
            found = False
            # The first pass (idx == -1) evaluates the seed on its own.
            for idx in range(-1, len(pred_res)):
                if idx >= 0 and idx not in group:
                    group.append(idx)
                candidate = self._address_by_idxs(pred_res, key, all_address_flag, group)
                if len(candidate) != 0:
                    if len(group) > len(best_group):
                        found = True
                        best_group = group.copy()
                elif len(group) > 1:
                    # No candidate for this group: drop its last member.
                    group.pop()
            if found:
                # The seed was already popped from ``pending``, hence the +1.
                absorbed = [i for i in pending if i in best_group]
                group_sizes.append(len(absorbed) + 1)
                pending = [i for i in pending if i not in best_group]
        group_sizes.sort()
        score = 0
        for rank, size in enumerate(group_sizes):
            score -= math.exp(-rank) * size
        return score

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to the knowledge base."""
        return self.kb.abduce_rules(pred_res)

if __name__ == '__main__':
from kb import add_KB, prolog_KB, HWF_KB, HED_prolog_KB
from kb import KBBase, prolog_KB
prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]

class add_KB(KBBase):
    """Knowledge base whose logical result is the sum of the pseudo-labels.

    Defaults: ``pseudo_label_list`` is the digits 0-9 and ``len_list`` is
    ``[2]``. Both are created fresh per instance, so no list object is
    shared between instances through the default arguments.
    """

    def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0, use_cache=True):
        # None sentinels replace the former mutable default lists, which
        # were evaluated once and shared by every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if len_list is None:
            len_list = [2]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)
print('add_KB with GKB:')
kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
@@ -372,6 +330,35 @@ if __name__ == '__main__':
print(res)
print()
class HWF_KB(KBBase):
    """Knowledge base that evaluates formulas of digits and operators.

    A formula is a sequence of symbols that alternates digits
    (``'1'``..``'9'``) and operators (``'+'``, ``'-'``, ``'times'``,
    ``'div'``), starting and ending with a digit.

    Defaults: ``pseudo_label_list`` is the 13 symbols above and
    ``len_list`` is ``[1, 3, 5, 7]``. Both are created fresh per instance.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        len_list=None,
        GKB_flag=False,
        max_err=1e-3,
        use_cache=True
    ):
        # None sentinels replace the former mutable default lists, which
        # were evaluated once and shared by every instance.
        if pseudo_label_list is None:
            pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div']
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` has odd length and alternates digit/operator."""
        if len(formula) % 2 == 0:
            return False
        digits = ['1', '2', '3', '4', '5', '6', '7', '8', '9']
        operators = ['+', '-', 'times', 'div']
        for position, symbol in enumerate(formula):
            # Even positions must hold a digit, odd positions an operator.
            allowed = digits if position % 2 == 0 else operators
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula``; return ``np.inf`` if it is not well formed."""
        if not self._valid_candidate(formula):
            return np.inf
        mapping = {str(i): str(i) for i in range(1, 10)}
        mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})
        expression = ''.join(mapping[symbol] for symbol in formula)
        # NOTE(security): eval is only acceptable here because
        # _valid_candidate has restricted every token to the digit/operator
        # whitelist above. Do not call eval on a formula that skipped it.
        return eval(expression)
print('HWF_KB with GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
@@ -432,6 +419,72 @@ if __name__ == '__main__':
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.9, require_more_address=0)
print(res)
print()
class HED_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base for the HED example.

    Adds two queries on top of ``prolog_KB``: a consistency check of
    examples against rules, and abduction of rules from predictions.
    """

    def __init__(self, pseudo_label_list, pl_file):
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules)`` has a solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being placed in the query text.
        """
        rules = str(rules).replace("\'","")
        query = "eval_inst_feature(%s, %s)." % (exs, rules)
        solutions = list(self.prolog.query(query))
        return len(solutions) != 0

    def abduce_rules(self, pred_res):
        """Return the rules bound to ``X`` in the first solution, or None.

        Queries ``consistent_inst_feature(pred_res, X)`` and, when at
        least one solution exists, returns the ``.value`` of every term
        in the first solution's ``X`` binding.
        """
        solutions = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
        if not solutions:
            return None
        return [term.value for term in solutions[0]['X']]
class HED_Abducer(AbducerBase):
    """Abducer variant for the HED example.

    Always initialises the base abducer with ``zoopt=True`` and supplies
    its own ``zoopt_address_score`` plus rule abduction through the
    knowledge base.
    """

    def __init__(self, kb, dist_func='hamming'):
        super().__init__(kb, dist_func, zoopt=True)

    def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Jointly address the examples selected by ``idxs``.

        Collects the predictions, keys and address flags of the chosen
        examples, flattens the flags into one list, and passes the
        positions whose flag is non-zero to ``self.address_by_idx``.
        """
        chosen_pred = [pred_res[i] for i in idxs]
        chosen_key = [key[i] for i in idxs]
        flat_flags = []
        for i in idxs:
            flat_flags.extend(all_address_flag[i])
        address_idx = np.where(np.array(flat_flags) != 0)[0]
        return self.address_by_idx(chosen_pred, chosen_key, address_idx)

    def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
        """Score a ZOOpt solution by how the examples can be grouped.

        The flat solution vector ``sol.get_x()`` is reshaped to mirror
        ``pred_res``. Examples are then greedily collected into groups
        for which ``_address_by_idxs`` returns a non-empty candidate.
        The returned score is the negated sum of the ascending-sorted
        group sizes, each weighted by ``exp(-rank)``.

        ``pred_res_prob`` is accepted but not used here.
        """
        import math

        all_address_flag = reform_idx(sol.get_x(), pred_res)
        pending = list(range(len(pred_res)))
        group_sizes = []
        while pending:
            # Seed a new group with the first still-unassigned example.
            group = [pending.pop(0)]
            best_group = []
            found = False
            # The first pass (idx == -1) evaluates the seed on its own.
            for idx in range(-1, len(pred_res)):
                if idx >= 0 and idx not in group:
                    group.append(idx)
                candidate = self._address_by_idxs(pred_res, key, all_address_flag, group)
                if len(candidate) != 0:
                    if len(group) > len(best_group):
                        found = True
                        best_group = group.copy()
                elif len(group) > 1:
                    # No candidate for this group: drop its last member.
                    group.pop()
            if found:
                # The seed was already popped from ``pending``, hence the +1.
                absorbed = [i for i in pending if i in best_group]
                group_sizes.append(len(absorbed) + 1)
                pending = [i for i in pending if i not in best_group]
        group_sizes.sort()
        score = 0
        for rank, size in enumerate(group_sizes):
            score -= math.exp(-rank) * size
        return score

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to the knowledge base."""
        return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = HED_Abducer(kb)


+ 0
- 58
abl/abducer/kb.py View File

@@ -189,13 +189,6 @@ class KBBase(ABC):
else:
return sum(self._dict_len(v) for v in self.base.values())

class add_KB(KBBase):
    """Knowledge base whose logical result is the sum of the pseudo-labels.

    Defaults: ``pseudo_label_list`` is the digits 0-9 and ``len_list`` is
    ``[2]``. Both are created fresh per instance, so no list object is
    shared between instances through the default arguments.
    """

    def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=0, use_cache=True):
        # None sentinels replace the former mutable default lists, which
        # were evaluated once and shared by every instance.
        if pseudo_label_list is None:
            pseudo_label_list = list(range(10))
        if len_list is None:
            len_list = [2]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def logic_forward(self, nums):
        """Return the sum of ``nums``."""
        return sum(nums)


class prolog_KB(KBBase):
def __init__(self, pseudo_label_list, pl_file):
@@ -244,54 +237,3 @@ class prolog_KB(KBBase):
candidate = reform_idx(candidate, save_pred_res)
candidates.append(candidate)
return candidates


class HED_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base for the HED example.

    Adds two queries on top of ``prolog_KB``: a consistency check of
    examples against rules, and abduction of rules from predictions.
    """

    def __init__(self, pseudo_label_list, pl_file):
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules)`` has a solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being placed in the query text.
        """
        rules = str(rules).replace("\'","")
        query = "eval_inst_feature(%s, %s)." % (exs, rules)
        solutions = list(self.prolog.query(query))
        return len(solutions) != 0

    def abduce_rules(self, pred_res):
        """Return the rules bound to ``X`` in the first solution, or None.

        Queries ``consistent_inst_feature(pred_res, X)`` and, when at
        least one solution exists, returns the ``.value`` of every term
        in the first solution's ``X`` binding.
        """
        solutions = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
        if not solutions:
            return None
        return [term.value for term in solutions[0]['X']]


class HWF_KB(KBBase):
    """Knowledge base that evaluates formulas of digits and operators.

    A formula is a sequence of symbols that alternates digits
    (``'1'``..``'9'``) and operators (``'+'``, ``'-'``, ``'times'``,
    ``'div'``), starting and ending with a digit.

    Defaults: ``pseudo_label_list`` is the 13 symbols above and
    ``len_list`` is ``[1, 3, 5, 7]``. Both are created fresh per instance.
    """

    def __init__(
        self,
        pseudo_label_list=None,
        len_list=None,
        GKB_flag=False,
        max_err=1e-3,
        use_cache=True
    ):
        # None sentinels replace the former mutable default lists, which
        # were evaluated once and shared by every instance.
        if pseudo_label_list is None:
            pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div']
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def _valid_candidate(self, formula):
        """Return True if ``formula`` has odd length and alternates digit/operator."""
        if len(formula) % 2 == 0:
            return False
        digits = ['1', '2', '3', '4', '5', '6', '7', '8', '9']
        operators = ['+', '-', 'times', 'div']
        for position, symbol in enumerate(formula):
            # Even positions must hold a digit, odd positions an operator.
            allowed = digits if position % 2 == 0 else operators
            if symbol not in allowed:
                return False
        return True

    def logic_forward(self, formula):
        """Evaluate ``formula``; return ``np.inf`` if it is not well formed."""
        if not self._valid_candidate(formula):
            return np.inf
        mapping = {str(i): str(i) for i in range(1, 10)}
        mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})
        expression = ''.join(mapping[symbol] for symbol in formula)
        # NOTE(security): eval is only acceptable here because
        # _valid_candidate has restricted every token to the digit/operator
        # whitelist above. Do not call eval on a formula that skipped it.
        return eval(expression)


if __name__ == "__main__":
pass

+ 72
- 2
examples/hed/hed_example.ipynb View File

@@ -10,15 +10,17 @@
"\n",
"sys.path.append(\"../../\")\n",
"\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.abducer.abducer_base import HED_Abducer\n",
"from abl.abducer.kb import HED_prolog_KB\n",
"from abl.abducer.abducer_base import AbducerBase\n",
"from abl.abducer.kb import prolog_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"from abl.utils.utils import reform_idx\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.get_hed import get_hed, split_equation\n",
@@ -58,7 +60,75 @@
],
"source": [
"# Initialize knowledge base and abducer\n",
"class HED_prolog_KB(prolog_KB):\n",
" def __init__(self, pseudo_label_list, pl_file):\n",
" super().__init__(pseudo_label_list, pl_file)\n",
" \n",
" def consist_rule(self, exs, rules):\n",
" rules = str(rules).replace(\"\\'\",\"\")\n",
" return len(list(self.prolog.query(\"eval_inst_feature(%s, %s).\" % (exs, rules)))) != 0\n",
"\n",
" def abduce_rules(self, pred_res):\n",
" prolog_result = list(self.prolog.query(\"consistent_inst_feature(%s, X).\" % pred_res))\n",
" if len(prolog_result) == 0:\n",
" return None\n",
" prolog_rules = prolog_result[0]['X']\n",
" rules = [rule.value for rule in prolog_rules]\n",
" return rules\n",
" \n",
" \n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')\n",
"\n",
"class HED_Abducer(AbducerBase):\n",
" def __init__(self, kb, dist_func='hamming'):\n",
" super().__init__(kb, dist_func, zoopt=True)\n",
" \n",
" def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n",
" pred = []\n",
" k = []\n",
" address_flag = []\n",
" for idx in idxs:\n",
" pred.append(pred_res[idx])\n",
" k.append(key[idx])\n",
" address_flag += list(all_address_flag[idx])\n",
" address_idx = np.where(np.array(address_flag) != 0)[0] \n",
" candidate = self.address_by_idx(pred, k, address_idx)\n",
" return candidate\n",
" \n",
" def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n",
" all_address_flag = reform_idx(sol.get_x(), pred_res)\n",
" lefted_idxs = [i for i in range(len(pred_res))]\n",
" candidate_size = [] \n",
" while lefted_idxs:\n",
" idxs = []\n",
" idxs.append(lefted_idxs.pop(0))\n",
" max_candidate_idxs = []\n",
" found = False\n",
" for idx in range(-1, len(pred_res)):\n",
" if (not idx in idxs) and (idx >= 0):\n",
" idxs.append(idx)\n",
" candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n",
" if len(candidate) == 0:\n",
" if len(idxs) > 1:\n",
" idxs.pop()\n",
" else:\n",
" if len(idxs) > len(max_candidate_idxs):\n",
" found = True\n",
" max_candidate_idxs = idxs.copy() \n",
" removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n",
" if found:\n",
" candidate_size.append(len(removed) + 1)\n",
" lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] \n",
" candidate_size.sort()\n",
" score = 0\n",
" import math\n",
" for i in range(0, len(candidate_size)):\n",
" score -= math.exp(-i) * candidate_size[i]\n",
" return score\n",
"\n",
" def abduce_rules(self, pred_res):\n",
" return self.kb.abduce_rules(pred_res)\n",
" \n",
"abducer = HED_Abducer(kb)"
]
},


+ 31
- 1
examples/hwf/hwf_example.ipynb View File

@@ -10,11 +10,12 @@
"\n",
"sys.path.append(\"../../\")\n",
"\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.abducer.abducer_base import AbducerBase\n",
"from abl.abducer.kb import HWF_KB\n",
"from abl.abducer.kb import KBBase\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
@@ -50,6 +51,35 @@
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class HWF_KB(KBBase):\n",
" def __init__(\n",
" self, \n",
" pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], \n",
" len_list=[1, 3, 5, 7],\n",
" GKB_flag=False,\n",
" max_err=1e-3,\n",
" use_cache=True\n",
" ):\n",
" super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n",
"\n",
" def _valid_candidate(self, formula):\n",
" if len(formula) % 2 == 0:\n",
" return False\n",
" for i in range(len(formula)):\n",
" if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:\n",
" return False\n",
" if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:\n",
" return False\n",
" return True\n",
"\n",
" def logic_forward(self, formula):\n",
" if not self._valid_candidate(formula):\n",
" return np.inf\n",
" mapping = {str(i): str(i) for i in range(1, 10)}\n",
" mapping.update({'+': '+', '-': '-', 'times': '*', 'div': '/'})\n",
" formula = [mapping[f] for f in formula]\n",
" return eval(''.join(formula))\n",
"\n",
"kb = HWF_KB(GKB_flag=True)\n",
"abducer = AbducerBase(kb)"
]


+ 10
- 1
examples/mnist_add/mnist_add_example.ipynb View File

@@ -14,7 +14,7 @@
"import torch\n",
"\n",
"from abl.abducer.abducer_base import AbducerBase\n",
"from abl.abducer.kb import add_KB\n",
"from abl.abducer.kb import KBBase, prolog_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
@@ -50,7 +50,16 @@
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"class add_KB(KBBase):\n",
" def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False, max_err=0, use_cache=True):\n",
" super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)\n",
"\n",
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"kb = add_KB(GKB_flag=True)\n",
"\n",
"# kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='datasets/mnist_add/add.pl')\n",
"abducer = AbducerBase(kb, dist_func=\"confidence\")"
]
},


Loading…
Cancel
Save