Browse Source

[MNT] Change variable names

pull/3/head
troyyyyy 2 years ago
parent
commit
1e2b574112
2 changed files with 140 additions and 153 deletions
  1. +31
    -44
      abl/reasoning/kb.py
  2. +109
    -109
      abl/reasoning/reasoner.py

+ 31
- 44
abl/reasoning/kb.py View File

@@ -1,18 +1,5 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 LAMDA All rights reserved.
#
# File Name :kb.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/03
# Description :
#
# ================================================================#

from abc import ABC, abstractmethod
import bisect
import copy
import numpy as np

from collections import defaultdict
@@ -79,103 +66,103 @@ class KBBase(ABC):
def logic_forward(self, pseudo_labels):
pass

def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0):
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision)
else:
if not self.use_cache:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
else:
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision)
def _find_candidate_GKB(self, pred_res, key):
def _find_candidate_GKB(self, pred_res, y):
if self.max_err == 0:
return self.base[len(pred_res)][key]
return self.base[len(pred_res)][y]
else:
potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, key)
key_idx = bisect.bisect_left(key_list, y)
all_candidates = []
for idx in range(key_idx - 1, 0, -1):
k = key_list[idx]
if abs(k - key) <= self.max_err:
if abs(k - y) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
for idx in range(key_idx, len(key_list)):
k = key_list[idx]
if abs(k - key) <= self.max_err:
if abs(k - y) <= self.max_err:
all_candidates += potential_candidates[k]
else:
break
return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision):
if self.base == {}:
return []
if len(pred_res) not in self.len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, key)
all_candidates = self._find_candidate_GKB(pred_res, y)
if len(all_candidates) == 0:
return []
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
address_num = min(max_revision_num, min_address_num + require_more_revision)
idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), key, self.max_err):
if check_equal(self.logic_forward(candidate), y, self.max_err):
candidates.append(candidate)
return candidates

def _address(self, address_num, pred_res, key):
def _address(self, address_num, pred_res, y):
new_candidates = []
address_idx_list = combinations(list(range(len(pred_res))), address_num)

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
if check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, key)
new_candidates = self._address(address_num, pred_res, y)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
break
if address_num >= max_address_num:
if address_num >= max_revision_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
for address_num in range(min_address_num + 1, min_address_num + require_more_revision + 1):
if address_num > max_revision_num:
return candidates
new_candidates = self._address(address_num, pred_res, key)
new_candidates = self._address(address_num, pred_res, y)
candidates += new_candidates
return candidates
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision):
pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
y = hashable_to_list(y)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
def _dict_len(self, dic):
if not self.GKB_flag:
@@ -217,16 +204,16 @@ class prolog_KB(KBBase):
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
def get_query_string(self, pred_res, key, address_idx):
def get_query_string(self, pred_res, y, address_idx):
query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx)
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string
def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
candidates = []
query_string = self.get_query_string(pred_res, key, address_idx)
query_string = self.get_query_string(pred_res, y, address_idx)
save_pred_res = pred_res
pred_res = flatten(pred_res)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]


+ 109
- 109
abl/reasoning/reasoner.py View File

@@ -4,7 +4,7 @@ from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist

class AbducerBase(abc.ABC):
class ReasonerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False):
self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence'
@@ -66,15 +66,15 @@ class AbducerBase(abc.ABC):
candidate = candidates[np.argmin(cost_list)]
return candidate
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key):
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, y):
address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
"""
Get the address score for a single solution.

@@ -86,8 +86,8 @@ class AbducerBase(abc.ABC):
List of predicted results.
pred_res_prob : list
List of probabilities for predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.

Returns
-------
@@ -95,7 +95,7 @@ class AbducerBase(abc.ABC):
The address score for the given solution.
"""
address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
@@ -105,7 +105,7 @@ class AbducerBase(abc.ABC):
x = solution.get_x()
return max_address_num - x.sum()

def zoopt_get_solution(self, pred_res, pred_res_prob, key, max_address_num):
def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_address_num):
"""Get the optimal solution using the Zoopt library.

Parameters
@@ -114,8 +114,8 @@ class AbducerBase(abc.ABC):
List of predicted results.
pred_res_prob : list
List of probabilities for predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.
max_address_num : int or float
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.

@@ -127,7 +127,7 @@ class AbducerBase(abc.ABC):
length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
objective = Objective(
lambda sol: self.zoopt_address_score(pred_res, pred_res_prob, key, sol),
lambda sol: self.zoopt_address_score(pred_res, pred_res_prob, y, sol),
dim=dimension,
constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
)
@@ -135,15 +135,15 @@ class AbducerBase(abc.ABC):
solution = Opt.min(objective, parameter).get_x()
return solution
def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
"""Get the addresses corresponding to the given indices.

Parameters
----------
pred_res : list
List of predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.
address_idx : array-like
Indices of the addresses to retrieve.

@@ -152,19 +152,19 @@ class AbducerBase(abc.ABC):
list
The addresses corresponding to the given indices.
"""
return self.kb.address_by_idx(pred_res, key, address_idx)
return self.kb.address_by_idx(pred_res, y, address_idx)

def abduce(self, data, max_address=-1, require_more_address=0):
def abduce(self, data, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data.

Parameters
----------
data : tuple
Tuple containing the predicted results, predicted result probabilities, and key.
max_address : int or float, optional
Tuple containing the predicted results, predicted result probabilities, and y.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1.
require_more_address : int, optional
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0.

Returns
@@ -172,41 +172,41 @@ class AbducerBase(abc.ABC):
list
The abduced addresses.
"""
pred_res, pred_res_prob, key = data
pred_res, pred_res_prob, y = data
assert(type(max_address) in (int, float))
if max_address == -1:
assert(type(max_revision) in (int, float))
if max_revision == -1:
max_address_num = len(flatten(pred_res))
elif type(max_address) == float:
assert(max_address >= 0 and max_address <= 1)
max_address_num = round(len(flatten(pred_res)) * max_address)
elif type(max_revision) == float:
assert(max_revision >= 0 and max_revision <= 1)
max_address_num = round(len(flatten(pred_res)) * max_revision)
else:
assert(max_address >= 0)
max_address_num = max_address
assert(max_revision >= 0)
max_address_num = max_revision

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num)
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_address_num)
address_idx = np.where(solution != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
else:
candidates = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address)
candidates = self.kb.abduce_candidates(pred_res, y, max_address_num, require_more_revision)

candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
return candidate

def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0):
def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data in batches.

Parameters
----------
Z : list
List of predicted results.
List of predicted results and result probablities.
Y : list
List of predicted result probabilities.
max_address : int or float, optional
List of ground truths.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1.
require_more_address : int, optional
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0.

Returns
@@ -214,19 +214,19 @@ class AbducerBase(abc.ABC):
list
The abduced addresses.
"""
return [self.abduce((z, prob, y), max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
# def _batch_abduce_helper(self, args):
# z, prob, y, max_address, require_more_address = args
# return self.abduce((z, prob, y), max_address, require_more_address)
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)

# def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0):
# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
# with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results
def __call__(self, Z, Y, max_address=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address, require_more_address)
def __call__(self, Z, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(Z, Y, max_revision, require_more_revision)

if __name__ == '__main__':
@@ -245,76 +245,76 @@ if __name__ == '__main__':
print('add_KB with GKB:')
kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res)
print()
print('add_KB without GKB:')
kb = add_KB()
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res)
print()
print('add_KB without GKB:, no cache')
kb = add_KB(use_cache=False)
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res)
print()
print('prolog_KB with add.pl:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res)
print()

print('prolog_KB with add.pl using zoopt:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence', zoopt=True)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
abd = ReasonerBase(kb, 'confidence', zoopt=True)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res)
print()
@@ -323,10 +323,10 @@ if __name__ == '__main__':
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
kb = add_KB()
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_address=2, require_more_address=1)
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=1)
print(res)
print()
@@ -361,62 +361,62 @@ if __name__ == '__main__':
print('HWF_KB with GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res)
print()
print('HWF_KB without GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res)
print()
print('HWF_KB with GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res)
print()
print('HWF_KB without GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res)
print()
print('HWF_KB with multiple inputs at once:')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=1, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=1, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=3, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_revision=3, require_more_revision=0)
print(res)
print()
print('max_address is float')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.5, require_more_address=0)
print('max_revision is float')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.5, require_more_revision=0)
print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.9, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.9, require_more_revision=0)
print(res)
print()
@@ -436,23 +436,23 @@ if __name__ == '__main__':
rules = [rule.value for rule in prolog_rules]
return rules
class HED_Abducer(AbducerBase):
class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func='hamming'):
super().__init__(kb, dist_func, zoopt=True)
def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
def _address_by_idxs(self, pred_res, y, all_address_flag, idxs):
pred = []
k = []
address_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(key[idx])
k.append(y[idx])
address_flag += list(all_address_flag[idx])
address_idx = np.where(np.array(address_flag) != 0)[0]
candidate = self.address_by_idx(pred, k, address_idx)
return candidate
def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
all_address_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
@@ -464,7 +464,7 @@ if __name__ == '__main__':
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)
candidate = self._address_by_idxs(pred_res, y, all_address_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()
@@ -486,8 +486,8 @@ if __name__ == '__main__':
def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)

kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = HED_Abducer(kb)
kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/hed/datasets/learn_add.pl')
abd = HED_Reasoner(kb)
consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]]
inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
@@ -502,7 +502,7 @@ if __name__ == '__main__':
print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print()

print('HED_Abducer abduce')
print('HED_Reasoner abduce')
res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)))
print(res)
res = abd.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)))
@@ -511,6 +511,6 @@ if __name__ == '__main__':
print(res)
print()

print('HED_Abducer abduce rules')
print('HED_Reasoner abduce rules')
abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules)

Loading…
Cancel
Save