Browse Source

Merge branch 'Dev' of https://github.com/AbductiveLearning/ABL-Package into Dev

pull/3/head
Gao Enhao 3 years ago
parent
commit
057be2a315
28 changed files with 361 additions and 428 deletions
  1. +0
    -0
      abl/__init__.py
  2. +0
    -0
      abl/abducer/__init__.py
  3. +103
    -110
      abl/abducer/abducer_base.py
  4. +216
    -270
      abl/abducer/kb.py
  5. +1
    -1
      abl/framework.py
  6. +11
    -8
      abl/framework_hed.py
  7. +0
    -0
      abl/models/__init__.py
  8. +0
    -0
      abl/models/basic_model.py
  9. +0
    -0
      abl/models/lenet5.py
  10. +0
    -5
      abl/models/nn.py
  11. +0
    -1
      abl/models/wabl_models.py
  12. +0
    -0
      abl/utils/plog.py
  13. +14
    -22
      abl/utils/utils.py
  14. +0
    -0
      examples/datasets/data_generator.py
  15. +0
    -0
      examples/datasets/hed/BK.pl
  16. +0
    -0
      examples/datasets/hed/README.md
  17. +0
    -0
      examples/datasets/hed/get_hed.py
  18. +3
    -0
      examples/datasets/hed/learn_add.pl
  19. +0
    -0
      examples/datasets/hwf/README.md
  20. +0
    -0
      examples/datasets/hwf/get_hwf.py
  21. +2
    -0
      examples/datasets/mnist_add/add.pl
  22. +0
    -0
      examples/datasets/mnist_add/get_mnist_add.py
  23. +0
    -0
      examples/datasets/mnist_add/test_data.txt
  24. +0
    -0
      examples/datasets/mnist_add/train_data.txt
  25. +11
    -11
      examples/example.py
  26. +0
    -0
      examples/nonshare_example.py
  27. +0
    -0
      examples/share_example.py
  28. +0
    -0
      examples/weights/all_weights_here.txt

weights/all_weights_here.txt → abl/__init__.py View File


+ 0
- 0
abl/abducer/__init__.py View File


abducer/abducer_base.py → abl/abducer/abducer_base.py View File

@@ -10,32 +10,15 @@
# #
# ================================================================# # ================================================================#


import sys

sys.path.append(".")
sys.path.append("..")

import abc import abc
from abducer.kb import *
import numpy as np import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt from zoopt import Dimension, Objective, Parameter, Opt
from utils.utils import confidence_dist, flatten, hamming_dist

import math
import time

from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist


class AbducerBase(abc.ABC): class AbducerBase(abc.ABC):
def __init__(
self,
kb,
dist_func="confidence",
zoopt=False,
multiple_predictions=False,
cache=True,
):
def __init__(self, kb, dist_func='confidence', zoopt=False, multiple_predictions=False, cache=True):
self.kb = kb self.kb = kb
assert dist_func == "hamming" or dist_func == "confidence"
assert dist_func == 'hamming' or dist_func == 'confidence'
self.dist_func = dist_func self.dist_func = dist_func
self.zoopt = zoopt self.zoopt = zoopt
self.multiple_predictions = multiple_predictions self.multiple_predictions = multiple_predictions
@@ -46,41 +29,42 @@ class AbducerBase(abc.ABC):
self.cache_candidates = {} self.cache_candidates = {}


def _get_cost_list(self, pred_res, pred_res_prob, candidates): def _get_cost_list(self, pred_res, pred_res_prob, candidates):
if self.dist_func == "hamming":
if self.dist_func == 'hamming':
if self.multiple_predictions:
pred_res = flatten(pred_res)
candidates = [flatten(c) for c in candidates]
return hamming_dist(pred_res, candidates) return hamming_dist(pred_res, candidates)
elif self.dist_func == "confidence":
mapping = dict(
zip(
self.kb.pseudo_label_list,
list(range(len(self.kb.pseudo_label_list))),
)
)
return confidence_dist(
pred_res_prob, [list(map(lambda x: mapping[x], c)) for c in candidates]
)
elif self.dist_func == 'confidence':
if self.multiple_predictions:
pred_res_prob = flatten(pred_res_prob)
candidates = [flatten(c) for c in candidates]
mapping = dict(zip(self.kb.pseudo_label_list, list(range(len(self.kb.pseudo_label_list)))))
candidates = [list(map(lambda x: mapping[x], c)) for c in candidates]
return confidence_dist(pred_res_prob, candidates)


def _get_one_candidate(self, pred_res, pred_res_prob, candidates): def _get_one_candidate(self, pred_res, pred_res_prob, candidates):
if len(candidates) == 0: if len(candidates) == 0:
return [] return []
elif len(candidates) == 1 or self.zoopt: elif len(candidates) == 1 or self.zoopt:
return candidates[0] return candidates[0]
else: else:
cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates) cost_list = self._get_cost_list(pred_res, pred_res_prob, candidates)
min_address_num = np.min(cost_list) min_address_num = np.min(cost_list)
idxs = np.where(cost_list == min_address_num)[0] idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]
candidate = [candidates[idx] for idx in idxs][0]
return candidate


# for zoopt # for zoopt
def _zoopt_score_multiple(self, pred_res, key, solution): def _zoopt_score_multiple(self, pred_res, key, solution):
all_address_flag = reform_idx(solution, pred_res) all_address_flag = reform_idx(solution, pred_res)
score = 0 score = 0
for idx in range(len(pred_res)): for idx in range(len(pred_res)):
address_idx = [
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
]
candidate = self.kb.address_by_idx(
[pred_res[idx]], key[idx], address_idx, True
)
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = self.address_by_idx([pred_res[idx]], key[idx], address_idx)
if len(candidate) > 0: if len(candidate) > 0:
score += 1 score += 1
return score return score
@@ -88,9 +72,7 @@ class AbducerBase(abc.ABC):
def _zoopt_address_score(self, pred_res, key, sol): def _zoopt_address_score(self, pred_res, key, sol):
if not self.multiple_predictions: if not self.multiple_predictions:
address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0] address_idx = [idx for idx, i in enumerate(sol.get_x()) if i != 0]
candidates = self.kb.address_by_idx(
pred_res, key, address_idx, self.multiple_predictions
)
candidates = self.address_by_idx(pred_res, key, address_idx)
return 1 if len(candidates) > 0 else 0 return 1 if len(candidates) > 0 else 0
else: else:
return self._zoopt_score_multiple(pred_res, key, sol.get_x()) return self._zoopt_score_multiple(pred_res, key, sol.get_x())
@@ -107,7 +89,7 @@ class AbducerBase(abc.ABC):
dim=dimension, dim=dimension,
constraint=lambda sol: self._constrain_address_num(sol, max_address_num), constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
) )
parameter = Parameter(budget=100, autoset=True)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x() solution = Opt.min(objective, parameter).get_x()


return solution return solution
@@ -118,11 +100,7 @@ class AbducerBase(abc.ABC):
pred_res = flatten(pred_res) pred_res = flatten(pred_res)
key = tuple(key) key = tuple(key)
if (tuple(pred_res), key) in self.cache_min_address_num: if (tuple(pred_res), key) in self.cache_min_address_num:
address_num = min(
max_address_num,
self.cache_min_address_num[(tuple(pred_res), key)]
+ require_more_address,
)
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), key)] + require_more_address)
if (tuple(pred_res), key, address_num) in self.cache_candidates: if (tuple(pred_res), key, address_num) in self.cache_candidates:
candidates = self.cache_candidates[(tuple(pred_res), key, address_num)] candidates = self.cache_candidates[(tuple(pred_res), key, address_num)]
if self.zoopt: if self.zoopt:
@@ -137,6 +115,9 @@ class AbducerBase(abc.ABC):
key = tuple(key) key = tuple(key)
self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num self.cache_min_address_num[(tuple(pred_res), key)] = min_address_num
self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates self.cache_candidates[(tuple(pred_res), key, address_num)] = candidates
def address_by_idx(self, pred_res, key, address_idx):
return self.kb.address_by_idx(pred_res, key, address_idx, self.multiple_predictions)


def abduce(self, data, max_address_num=-1, require_more_address=0): def abduce(self, data, max_address_num=-1, require_more_address=0):
pred_res, pred_res_prob, key = data pred_res, pred_res_prob, key = data
@@ -151,18 +132,12 @@ class AbducerBase(abc.ABC):
if self.zoopt: if self.zoopt:
solution = self.zoopt_get_solution(pred_res, key, max_address_num) solution = self.zoopt_get_solution(pred_res, key, max_address_num)
address_idx = [idx for idx, i in enumerate(solution) if i != 0] address_idx = [idx for idx, i in enumerate(solution) if i != 0]
candidates = self.kb.address_by_idx(
pred_res, key, address_idx, self.multiple_predictions
)
candidates = self.address_by_idx(pred_res, key, address_idx)
address_num = int(solution.sum()) address_num = int(solution.sum())
min_address_num = address_num min_address_num = address_num
else: else:
candidates, min_address_num, address_num = self.kb.abduce_candidates( candidates, min_address_num, address_num = self.kb.abduce_candidates(
pred_res,
key,
max_address_num,
require_more_address,
self.multiple_predictions,
pred_res, key, max_address_num, require_more_address, self.multiple_predictions
) )


candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
@@ -176,32 +151,22 @@ class AbducerBase(abc.ABC):
return self.kb.abduce_rules(pred_res) return self.kb.abduce_rules(pred_res)


def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0): def batch_abduce(self, Z, Y, max_address_num=-1, require_more_address=0):
if self.multiple_predictions:
return self.abduce(
(Z["cls"], Z["prob"], Y), max_address_num, require_more_address
)
else:
return [
self.abduce((z, prob, y), max_address_num, require_more_address)
for z, prob, y in zip(Z["cls"], Z["prob"], Y)
]
# if self.multiple_predictions:
return self.abduce((Z['cls'], Z['prob'], Y), max_address_num, require_more_address)
# else:
# return [self.abduce((z, prob, y), max_address_num, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]


def __call__(self, Z, Y, max_address_num=-1, require_more_address=0): def __call__(self, Z, Y, max_address_num=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address_num, require_more_address) return self.batch_abduce(Z, Y, max_address_num, require_more_address)


if __name__ == '__main__':
from kb import add_KB, prolog_KB, HWF_KB
prob1 = [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
prob2 = [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]


if __name__ == "__main__":
prob1 = [
[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]
prob2 = [
[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
]

kb = add_KB()
abd = AbducerBase(kb, "confidence")
kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -213,9 +178,23 @@ if __name__ == "__main__":
res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 20), max_address_num=2, require_more_address=0)
print(res) print(res)
print() print()
multiple_prob = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
kb = add_KB()
abd = AbducerBase(kb, 'confidence', multiple_predictions=True)
res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=0)
print(res)
res = abd.abduce(([[1, 1], [1, 2]], multiple_prob, [4, 8]), max_address_num=4, require_more_address=1)
print(res)
print()


kb = add_prolog_KB()
abd = AbducerBase(kb, "confidence")
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence')
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -228,8 +207,8 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


kb = add_prolog_KB()
abd = AbducerBase(kb, "confidence", zoopt=True)
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence', zoopt=True)
res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob1, 8), max_address_num=2, require_more_address=0)
print(res) print(res)
res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0) res = abd.abduce(([1, 1], prob2, 8), max_address_num=2, require_more_address=0)
@@ -242,49 +221,63 @@ if __name__ == "__main__":
print(res) print(res)
print() print()


kb = HWF_KB(len_list=[1, 3, 5])
abd = AbducerBase(kb, "hamming")
res = abd.abduce(
(["5", "+", "2"], None, 3), max_address_num=2, require_more_address=0
)
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res)
print()
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
print(res)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
print(res)
print()
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming', multiple_predictions=True)
res = abd.abduce(([['5', '+', '2'], ['5', '+', '9']], None, [3, 64]), max_address_num=6, require_more_address=0)
print(res)
print()
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '2'], None, 3), max_address_num=2, require_more_address=0)
print(res)
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res) print(res)
res = abd.abduce(
(["5", "+", "2"], None, 64), max_address_num=3, require_more_address=0
)
kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.abduce((['5', '+', '9'], None, 64), max_address_num=3, require_more_address=0)
print(res) print(res)
res = abd.abduce(
(["5", "+", "2"], None, 1.67), max_address_num=3, require_more_address=0
)
res = abd.abduce((['5', '+', '2'], None, 1.67), max_address_num=3, require_more_address=0)
print(res) print(res)
res = abd.abduce(
(["5", "8", "8", "8", "8"], None, 3.17),
max_address_num=5,
require_more_address=3,
)
res = abd.abduce((['5', '8', '8', '8', '8'], None, 3.17), max_address_num=5, require_more_address=3)
print(res) print(res)
print() print()


kb = HED_prolog_KB()
kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = AbducerBase(kb, zoopt=True, multiple_predictions=True) abd = AbducerBase(kb, zoopt=True, multiple_predictions=True)
consist_exs = [[1, "+", 0, "=", 0], [1, "+", 1, "=", 0], [0, "+", 0, "=", 1, 1]]
consist_exs2 = [
[1, "+", 0, "=", 0],
[1, "+", 1, "=", 0],
[0, "+", 1, "=", 1, 1],
] # not consistent with rules
inconsist_exs = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]]
consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
inconsist_exs = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
# inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']] # inconsist_exs = [[1, '+', 0, '=', 0], ['=', '=', '=', '=', 0], ['=', '=', 0, '=', '=', '=']]
rules = ["my_op([0], [0], [1, 1])", "my_op([1], [1], [0])", "my_op([1], [0], [0])"]
rules = ['my_op([0], [0], [0])', 'my_op([1], [1], [1, 0])']


print(kb.logic_forward(consist_exs), kb.logic_forward(inconsist_exs))
print(kb.consist_rule(consist_exs, rules), kb.consist_rule(consist_exs2, rules))
print(kb._logic_forward(consist_exs, True), kb._logic_forward(inconsist_exs, True))
print(kb.consist_rule([1, '+', 1, '=', 1, 0], rules), kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print() print()


res = abd.abduce((consist_exs, None, [1] * len(consist_exs)))
res = abd.abduce((consist_exs, None, [None] * len(consist_exs)))
print(res) print(res)
res = abd.abduce((inconsist_exs, None, [1] * len(consist_exs)))
res = abd.abduce((inconsist_exs, None, [None] * len(inconsist_exs)))
print(res) print(res)
print() print()


abduced_rules = abd.abduce_rules(consist_exs) abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)
print(abduced_rules)

abducer/kb.py → abl/abducer/kb.py View File

@@ -15,93 +15,26 @@ import bisect
import copy import copy
import numpy as np import numpy as np


import sys

sys.path.append("..")

from collections import defaultdict from collections import defaultdict
from itertools import product, combinations from itertools import product, combinations
from utils.utils import flatten, reform_idx, hamming_dist, check_equal
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal


from multiprocessing import Pool from multiprocessing import Pool


import pyswip import pyswip



class KBBase(ABC): class KBBase(ABC):
def __init__(self, pseudo_label_list=None):
pass

@abstractmethod
def logic_forward(self):
pass

@abstractmethod
def abduce_candidates(self):
pass
@abstractmethod
def address_by_idx(self):
pass

def _address(self, address_num, pred_res, key, multiple_predictions=False):
new_candidates = []
if not multiple_predictions:
address_idx_list = list(combinations(list(range(len(pred_res))), address_num))
else:
address_idx_list = list(combinations(list(range(len(flatten(pred_res)))), address_num))

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

if len(candidates) > 0:
min_address_num = address_num
break

if address_num >= max_address_num:
return [], 0, 0

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

return candidates, min_address_num, address_num

def __len__(self):
pass


class ClsKB(KBBase):
def __init__(self, GKB_flag=False, pseudo_label_list=None, len_list=None):
super().__init__()
self.GKB_flag = GKB_flag
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0):
self.pseudo_label_list = pseudo_label_list self.pseudo_label_list = pseudo_label_list
self.len_list = len_list self.len_list = len_list
self.GKB_flag = GKB_flag
self.max_err = max_err


if GKB_flag: if GKB_flag:
self.base = {} self.base = {}
X, Y = self._get_GKB() X, Y = self._get_GKB()
for x, y in zip(X, Y): for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x) self.base.setdefault(len(x), defaultdict(list))[y].append(x)
else:
self.all_address_candidate_dict = {}
for address_num in range(max(self.len_list) + 1):
self.all_address_candidate_dict[address_num] = list(product(self.pseudo_label_list, repeat=address_num))


# For parallel version of _get_GKB # For parallel version of _get_GKB
def _get_XY_list(self, args): def _get_XY_list(self, args):
@@ -130,39 +63,80 @@ class ClsKB(KBBase):
part_X, part_Y = zip(*XY_list) part_X, part_Y = zip(*XY_list)
X.extend(part_X) X.extend(part_X)
Y.extend(part_Y) Y.extend(part_Y)
if type(Y[0]) in (int, float):
sorted_XY = sorted(list(zip(Y, X)))
X = [x for y, x in sorted_XY]
Y = [y for y, x in sorted_XY]
return X, Y return X, Y


def logic_forward(self):
@abstractmethod
def logic_forward(self, pseudo_labels):
pass pass
def _logic_forward(self, xs, multiple_predictions=False):
if not multiple_predictions:
return self.logic_forward(xs)
else:
res = [self.logic_forward(x) for x in xs]
return res


def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False):
def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0, multiple_predictions=False):
if self.GKB_flag: if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address, multiple_predictions)
else: else:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
@abstractmethod
def _find_candidate_GKB(self, pred_res, key):
pass
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
if self.base == {}:
return [], 0, 0


def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
if self.base == {} or len(pred_res) not in self.len_list:
return []

all_candidates = self.base[len(pred_res)][key]

if len(all_candidates) == 0:
candidates = []
min_address_num = 0
address_num = 0
if not multiple_predictions:
if len(pred_res) not in self.len_list:
return [], 0, 0
all_candidates = self._find_candidate_GKB(pred_res, key)
if len(all_candidates) == 0:
return [], 0, 0
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates, min_address_num, address_num
else: else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
min_address_num = 0
all_candidates_save = []
cost_list_save = []
for p_res, k in zip(pred_res, key):
if len(p_res) not in self.len_list:
return [], 0, 0
all_candidates = self._find_candidate_GKB(p_res, k)
if len(all_candidates) == 0:
return [], 0, 0
else:
all_candidates_save.append(all_candidates)
cost_list = hamming_dist(p_res, all_candidates)
min_address_num += np.min(cost_list)
cost_list_save.append(cost_list)
multiple_all_candidates = [flatten(c) for c in product(*all_candidates_save)]
assert len(multiple_all_candidates[0]) == len(flatten(pred_res))
multiple_cost_list = np.array([sum(cost) for cost in product(*cost_list_save)])
assert len(multiple_all_candidates) == len(multiple_cost_list)
address_num = min(max_address_num, min_address_num + require_more_address) address_num = min(max_address_num, min_address_num + require_more_address)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]

return candidates, min_address_num, address_num

idxs = np.where(multiple_cost_list <= address_num)[0]
candidates = [reform_idx(multiple_all_candidates[idx], pred_res) for idx in idxs]
return candidates, min_address_num, address_num
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = [] candidates = []
abduce_c = self.all_address_candidate_dict[len(address_idx)]
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))


if multiple_predictions: if multiple_predictions:
save_pred_res = pred_res save_pred_res = pred_res
@@ -176,10 +150,48 @@ class ClsKB(KBBase):
if multiple_predictions: if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res) candidate = reform_idx(candidate, save_pred_res)


if self.logic_forward(candidate) == key:
if check_equal(self._logic_forward(candidate, multiple_predictions), key, self.max_err):
candidates.append(candidate) candidates.append(candidate)
return candidates return candidates


def _address(self, address_num, pred_res, key, multiple_predictions):
new_candidates = []
if not multiple_predictions:
address_idx_list = combinations(list(range(len(pred_res))), address_num)
else:
address_idx_list = combinations(list(range(len(flatten(pred_res)))), address_num)

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx, multiple_predictions)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
candidates = []

for address_num in range(len(flatten(pred_res)) + 1):
if address_num == 0:
if check_equal(self._logic_forward(pred_res, multiple_predictions), key, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

if len(candidates) > 0:
min_address_num = address_num
break

if address_num >= max_address_num:
return [], 0, 0

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
return candidates, min_address_num, address_num - 1
new_candidates = self._address(address_num, pred_res, key, multiple_predictions)
candidates += new_candidates

return candidates, min_address_num, address_num

def _dict_len(self, dic): def _dict_len(self, dic):
if not self.GKB_flag: if not self.GKB_flag:
return 0 return 0
@@ -193,130 +205,77 @@ class ClsKB(KBBase):
return sum(self._dict_len(v) for v in self.base.values()) return sum(self._dict_len(v) for v in self.base.values())




class add_KB(ClsKB):
def __init__(self, GKB_flag=False, pseudo_label_list=list(range(10)), len_list=[2]):
super().__init__(GKB_flag, pseudo_label_list, len_list)

def logic_forward(self, nums):
return sum(nums)
class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list, GKB_flag):
super().__init__(pseudo_label_list, len_list, GKB_flag)


def _find_candidate_GKB(self, pred_res, key):
return self.base[len(pred_res)][key]


class HWF_KB(ClsKB):
def __init__(
self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7]
):
super().__init__(GKB_flag, pseudo_label_list, len_list)


def valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
return False
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
return False
return True
class add_KB(ClsKB):
def __init__(self, pseudo_label_list=list(range(10)), len_list=[2], GKB_flag=False):
super().__init__(pseudo_label_list, len_list, GKB_flag)


def logic_forward(self, formula):
if not self.valid_candidate(formula):
return np.inf
mapping = {
'1': '1',
'2': '2',
'3': '3',
'4': '4',
'5': '5',
'6': '6',
'7': '7',
'8': '8',
'9': '9',
'+': '+',
'-': '-',
'times': '*',
'div': '/',
}
formula = [mapping[f] for f in formula]
return round(eval(''.join(formula)), 2)
def logic_forward(self, nums):
return sum(nums)




class prolog_KB(KBBase): class prolog_KB(KBBase):
def __init__(self, pseudo_label_list):
super().__init__()
self.pseudo_label_list = pseudo_label_list
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list)
self.prolog = pyswip.Prolog() self.prolog = pyswip.Prolog()
self.prolog.consult(pl_file)


def logic_forward(self):
pass

def abduce_candidates(self, pred_res, key, max_address_num, require_more_address, multiple_predictions):
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions)
def logic_forward(self, pseudo_labels):
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_labels))[0]['Res']
if result == 'true':
return True
elif result == 'false':
return False
return result
def _address_pred_res(self, pred_res, address_idx, multiple_predictions):
import re
address_pred_res = pred_res.copy()
if multiple_predictions:
address_pred_res = flatten(address_pred_res)
for idx in address_idx:
address_pred_res[idx] = 'P' + str(idx)
if multiple_predictions:
address_pred_res = reform_idx(address_pred_res, pred_res)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
def get_query_string(self, pred_res, key, address_idx, multiple_predictions):
query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx, multiple_predictions)
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
return query_string


def _find_candidate_GKB(self, pred_res, key):
pass
def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False): def address_by_idx(self, pred_res, key, address_idx, multiple_predictions=False):
candidates = [] candidates = []
# print(address_idx)
if not multiple_predictions:
query_string = self.get_query_string(pred_res, key, address_idx)
else:
query_string = self.get_query_string_need_flatten(pred_res, key, address_idx)

query_string = self.get_query_string(pred_res, key, address_idx, multiple_predictions)
if multiple_predictions: if multiple_predictions:
save_pred_res = pred_res save_pred_res = pred_res
pred_res = flatten(pred_res) pred_res = flatten(pred_res)

abduce_c = [list(z.values()) for z in list(self.prolog.query(query_string))]
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
for c in abduce_c: for c in abduce_c:
candidate = pred_res.copy() candidate = pred_res.copy()
for i, idx in enumerate(address_idx): for i, idx in enumerate(address_idx):
candidate[idx] = c[i] candidate[idx] = c[i]

if multiple_predictions: if multiple_predictions:
candidate = reform_idx(candidate, save_pred_res) candidate = reform_idx(candidate, save_pred_res)

candidates.append(candidate) candidates.append(candidate)
return candidates return candidates



class add_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list=list(range(10))):
super().__init__(pseudo_label_list)
for i in self.pseudo_label_list:
self.prolog.assertz("pseudo_label(%s)" % i)
self.prolog.assertz("addition(Z1, Z2, Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2")

def logic_forward(self, nums):
return list(self.prolog.query("addition(%s, %s, Res)." % (nums[0], nums[1])))[0]['Res']

def get_query_string(self, pred_res, key, address_idx):
query_string = "addition("
for idx, i in enumerate(pred_res):
tmp = 'Z' + str(idx) + ',' if idx in address_idx else str(i) + ','
query_string += tmp
query_string += "%s)." % key
return query_string


class HED_prolog_KB(prolog_KB):
def __init__(self, pseudo_label_list=[0, 1, '+', '=']):
super().__init__(pseudo_label_list)
self.prolog.consult('./datasets/hed/learn_add.pl')

# corresponding to `con_sol is not None` in `consistent_score_mapped` within `learn_add.py`
def logic_forward(self, exs):
return len(list(self.prolog.query("abduce_consistent_insts(%s)." % exs))) != 0

def get_query_string_need_flatten(self, pred_res, key, address_idx):
# flatten
flatten_pred_res = flatten(pred_res)
# add variables for prolog
for idx in range(len(flatten_pred_res)):
if idx in address_idx:
flatten_pred_res[idx] = 'X' + str(idx)
# unflatten
new_pred_res = reform_idx(flatten_pred_res, pred_res)

query_string = "abduce_consistent_insts(%s)." % new_pred_res
return query_string.replace("'", "").replace("+", "'+'").replace("=", "'='")

def consist_rule(self, exs, rules): def consist_rule(self, exs, rules):
rules = str(rules).replace("\'","") rules = str(rules).replace("\'","")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
@@ -327,92 +286,79 @@ class HED_prolog_KB(prolog_KB):
if len(prolog_result) == 0: if len(prolog_result) == 0:
return None return None
prolog_rules = prolog_result[0]['X'] prolog_rules = prolog_result[0]['X']
rules = []
for rule in prolog_rules:
rules.append(rule.value)
rules = [rule.value for rule in prolog_rules]
return rules return rules


# def consist_rules(self, pred_res, rules):



class RegKB(KBBase): class RegKB(KBBase):
def __init__(self, GKB_flag=False, X=None, Y=None):
super().__init__()
tmp_dict = {}
for x, y in zip(X, Y):
tmp_dict.setdefault(len(x), defaultdict(list))[y].append(np.array(x))
self.base = {}
for l in tmp_dict.keys():
data = sorted(list(zip(tmp_dict[l].keys(), tmp_dict[l].values())))
X = [x for y, x in data]
Y = [y for y, x in data]
self.base[l] = (X, Y)
def valid_candidate(self):
pass
def logic_forward(self):
pass
def abduce_candidates(self, key, length=None):
if key is None:
return self.get_all_candidates()
length = self._length(length)
def __init__(self, pseudo_label_list=None, len_list=None, GKB_flag=False, max_err=1e-3):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)
def _find_candidate_GKB(self, pred_res, key):
potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, key)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - key) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
return all_candidates


min_err = 999999
candidates = []
for l in length:
X, Y = self.base[l]

idx = bisect.bisect_left(Y, key)
begin = max(0, idx - 1)
end = min(idx + 2, len(X))

for idx in range(begin, end):
err = abs(Y[idx] - key)
if abs(err - min_err) < 1e-9:
candidates.extend(X[idx])
elif err < min_err:
candidates = copy.deepcopy(X[idx])
min_err = err
return candidates
class HWF_KB(RegKB):
def __init__(
self,
pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'],
len_list=[1, 3, 5, 7],
GKB_flag=False,
max_err=1e-3
):
super().__init__(pseudo_label_list, len_list, GKB_flag, max_err)


def get_all_candidates(self):
return sum([sum(D[0], []) for D in self.base.values()], [])
def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ['1', '2', '3', '4', '5', '6', '7', '8', '9']:
return False
if i % 2 != 0 and formula[i] not in ['+', '-', 'times', 'div']:
return False
return True


def __len__(self):
return sum([sum(len(x) for x in D[0]) for D in self.base.values()])
def logic_forward(self, formula):
if not self._valid_candidate(formula):
return np.inf
mapping = {
'1': '1',
'2': '2',
'3': '3',
'4': '4',
'5': '5',
'6': '6',
'7': '7',
'8': '8',
'9': '9',
'+': '+',
'-': '-',
'times': '*',
'div': '/',
}
formula = [mapping[f] for f in formula]
return eval(''.join(formula))




import time import time


if __name__ == "__main__": if __name__ == "__main__":
t1 = time.time()
kb = HWF_KB(True)
t2 = time.time()
print(t2 - t1)

# X = ["1+1", "0+1", "1+0", "2+0", "1+0+1"]
# Y = [2, 1, 1, 2, 2]
# kb = ClsKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(2, 5)
# print(res)
# res = kb.get_candidates(2, 3)
# print(res)
# res = kb.get_candidates(None)
# print(res)
# print()

# X = ["1+1", "0+1", "1+0", "2+0", "1+0.5", "0.75+0.75"]
# Y = [2, 1, 1, 2, 1.5, 1.5]
# kb = RegKB(X, Y)
# print('len(kb):', len(kb))
# res = kb.get_candidates(1.6)
# print(res)
# res = kb.get_candidates(1.6, length = 9)
# print(res)
# res = kb.get_candidates(None)
# print(res)
pass

framework.py → abl/framework.py View File

@@ -14,7 +14,7 @@ import pickle as pk


import numpy as np import numpy as np


from utils.plog import INFO, DEBUG, clocker
from .utils.plog import INFO, DEBUG, clocker


def block_sample(X, Z, Y, sample_num, epoch_idx): def block_sample(X, Z, Y, sample_num, epoch_idx):
part_num = (len(X) // sample_num) part_num = (len(X) // sample_num)

framework_hed.py → abl/framework_hed.py View File

@@ -16,12 +16,15 @@ import torch.nn as nn
import numpy as np import numpy as np
import os import os


from utils.plog import INFO, DEBUG, clocker
from utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res
from .utils.plog import INFO, DEBUG, clocker
from .utils.utils import flatten, reform_idx, block_sample, gen_mappings, mapping_res, remapping_res


from models.nn import MLP, SymbolNetAutoencoder
from models.basic_model import BasicModel, BasicDataset
from datasets.hed.get_hed import get_pretrain_data
from .models.nn import MLP, SymbolNetAutoencoder
from .models.basic_model import BasicModel, BasicDataset

import sys
sys.path.append("..")
from examples.datasets.hed.get_hed import get_pretrain_data


def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
result = {} result = {}
@@ -147,7 +150,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for m in mappings: for m in mappings:
pred_res = mapping_res(original_pred_res, m) pred_res = mapping_res(original_pred_res, m)
max_abduce_num = 20 max_abduce_num = 20
solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), max_abduce_num)
all_address_flag = reform_idx(solution, pred_res) all_address_flag = reform_idx(solution, pred_res)


consistent_idx_tmp = [] consistent_idx_tmp = []
@@ -155,7 +158,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
for idx in range(len(pred_res)): for idx in range(len(pred_res)):
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx)
if len(candidate) > 0: if len(candidate) > 0:
consistent_idx_tmp.append(idx) consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(candidate[0][0]) consistent_pred_res_tmp.append(candidate[0][0])
@@ -211,7 +214,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,
consistent_idx = [] consistent_idx = []
consistent_pred_res = [] consistent_pred_res = []
for idx in range(len(pred_res)): for idx in range(len(pred_res)):
if abducer.kb.logic_forward([pred_res[idx]]):
if abducer.kb.logic_forward(pred_res[idx]):
consistent_idx.append(idx) consistent_idx.append(idx)
consistent_pred_res.append(pred_res[idx]) consistent_pred_res.append(pred_res[idx])



+ 0
- 0
abl/models/__init__.py View File


models/basic_model.py → abl/models/basic_model.py View File


models/lenet5.py → abl/models/lenet5.py View File


models/nn.py → abl/models/nn.py View File

@@ -10,9 +10,6 @@
# #
# ================================================================# # ================================================================#


import sys

sys.path.append("..")


import torchvision import torchvision


@@ -23,8 +20,6 @@ from torch.autograd import Variable
import torchvision.transforms as transforms import torchvision.transforms as transforms
import numpy as np import numpy as np


from models.basic_model import BasicModel
import utils.plog as plog




class LeNet5(nn.Module): class LeNet5(nn.Module):

models/wabl_models.py → abl/models/wabl_models.py View File

@@ -21,7 +21,6 @@ from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF from sklearn.gaussian_process.kernels import RBF
from models.basic_model import BasicModel


import pickle as pk import pickle as pk
import random import random

utils/plog.py → abl/utils/plog.py View File


utils/utils.py → abl/utils/utils.py View File

@@ -1,30 +1,23 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from utils.plog import INFO
from .plog import INFO
from collections import OrderedDict from collections import OrderedDict
from itertools import chain


# for multiple predictions, modify from `learn_add.py`
# for multiple predictions
def flatten(l): def flatten(l):
return (
[item for sublist in l for item in flatten(sublist)]
if isinstance(l, list)
else [l]
)


# for multiple predictions, modify from `learn_add.py`
if not isinstance(l[0], (list, tuple)):
return l
return list(chain.from_iterable(l))
# for multiple predictions
def reform_idx(flatten_pred_res, save_pred_res): def reform_idx(flatten_pred_res, save_pred_res):
re = [] re = []
i = 0 i = 0
for e in save_pred_res: for e in save_pred_res:
j = 0
idx = []
while j < len(e):
idx.append(flatten_pred_res[i + j])
j += 1
re.append(idx)
i = i + j
re.append(flatten_pred_res[i:i + len(e)])
i += len(e)
return re return re




@@ -85,11 +78,10 @@ def remapping_res(pred_res, m):
remapping[value] = key remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res] return [[remapping[symbol] for symbol in formula] for formula in pred_res]



def check_equal(a, b):
def check_equal(a, b, max_err=0):
if isinstance(a, (int, float)) and isinstance(b, (int, float)): if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return abs(a - b) <= 1e-3
return abs(a - b) <= max_err
if isinstance(a, list) and isinstance(b, list): if isinstance(a, list) and isinstance(b, list):
if len(a) != len(b): if len(a) != len(b):
return False return False
@@ -119,4 +111,4 @@ def reduce_dimension(data):
[extract_feature(symbol_img) for symbol_img in equation] [extract_feature(symbol_img) for symbol_img in equation]
for equation in equations for equation in equations
] ]
data[truth_value][equation_len] = reduced_equations
data[truth_value][equation_len] = reduced_equations

datasets/data_generator.py → examples/datasets/data_generator.py View File


datasets/hed/BK.pl → examples/datasets/hed/BK.pl View File


datasets/hed/README.md → examples/datasets/hed/README.md View File


datasets/hed/get_hed.py → examples/datasets/hed/get_hed.py View File


datasets/hed/learn_add.pl → examples/datasets/hed/learn_add.pl View File

@@ -32,6 +32,9 @@ abduce_consistent_insts(Exs):-
% (Experimental) Uncomment to use parallel abduction % (Experimental) Uncomment to use parallel abduction
% abduce_consistent_exs_concurrent(Exs), !. % abduce_consistent_exs_concurrent(Exs), !.
logic_forward(Exs, X) :- abduce_consistent_insts([Exs]) -> X = true ; X = false.
logic_forward(Exs) :- abduce_consistent_insts(Exs).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce Delta_C given pseudo-labels %% Abduce Delta_C given pseudo-labels
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

datasets/hwf/README.md → examples/datasets/hwf/README.md View File


datasets/hwf/get_hwf.py → examples/datasets/hwf/get_hwf.py View File


+ 2
- 0
examples/datasets/mnist_add/add.pl View File

@@ -0,0 +1,2 @@
pseudo_label(N) :- between(0, 9, N).
logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2.

datasets/mnist_add/get_mnist_add.py → examples/datasets/mnist_add/get_mnist_add.py View File


datasets/mnist_add/test_data.txt → examples/datasets/mnist_add/test_data.txt View File


datasets/mnist_add/train_data.txt → examples/datasets/mnist_add/train_data.txt View File


example.py → examples/example.py View File

@@ -10,24 +10,24 @@
# #
# ================================================================# # ================================================================#


from utils.plog import logger, INFO
from utils.utils import reduce_dimension
import sys
sys.path.append("../")

from abl.utils.plog import logger, INFO
import torch.nn as nn import torch.nn as nn
import torch import torch


from models.nn import LeNet5, SymbolNet
from models.basic_model import BasicModel, BasicDataset
from models.wabl_models import DecisionTree, WABLBasicModel
from sklearn.neighbors import KNeighborsClassifier
from abl.models.nn import LeNet5, SymbolNet
from abl.models.basic_model import BasicModel, BasicDataset
from abl.models.wabl_models import DecisionTree, WABLBasicModel


from multiprocessing import Pool from multiprocessing import Pool
from abducer.abducer_base import AbducerBase
from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation from datasets.hed.get_hed import get_hed, split_equation
import framework_hed
import framework_hed_knn
from abl import framework_hed




def run_test(): def run_test():
@@ -36,7 +36,7 @@ def run_test():
# kb = HWF_KB(True) # kb = HWF_KB(True)
# abducer = AbducerBase(kb) # abducer = AbducerBase(kb)


kb = HED_prolog_KB()
kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)


recorder = logger() recorder = logger()

nonshare_example.py → examples/nonshare_example.py View File


share_example.py → examples/share_example.py View File


+ 0
- 0
examples/weights/all_weights_here.txt View File


Loading…
Cancel
Save