Browse Source

[MNT] Change variable names

pull/3/head
troyyyyy 3 years ago
parent
commit
1e2b574112
2 changed files with 140 additions and 153 deletions
  1. +31
    -44
      abl/reasoning/kb.py
  2. +109
    -109
      abl/reasoning/reasoner.py

+ 31
- 44
abl/reasoning/kb.py View File

@@ -1,18 +1,5 @@
# coding: utf-8
# ================================================================#
# Copyright (C) 2021 LAMDA All rights reserved.
#
# File Name :kb.py
# Author :freecss
# Email :karlfreecss@gmail.com
# Created Date :2021/06/03
# Description :
#
# ================================================================#

from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import bisect import bisect
import copy
import numpy as np import numpy as np


from collections import defaultdict from collections import defaultdict
@@ -79,103 +66,103 @@ class KBBase(ABC):
def logic_forward(self, pseudo_labels): def logic_forward(self, pseudo_labels):
pass pass


def abduce_candidates(self, pred_res, key, max_address_num, require_more_address=0):
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
if self.GKB_flag: if self.GKB_flag:
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision)
else: else:
if not self.use_cache: if not self.use_cache:
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
else: else:
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(key), max_address_num, require_more_address)
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision)
def _find_candidate_GKB(self, pred_res, key):
def _find_candidate_GKB(self, pred_res, y):
if self.max_err == 0: if self.max_err == 0:
return self.base[len(pred_res)][key]
return self.base[len(pred_res)][y]
else: else:
potential_candidates = self.base[len(pred_res)] potential_candidates = self.base[len(pred_res)]
key_list = list(potential_candidates.keys()) key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, key)
key_idx = bisect.bisect_left(key_list, y)
all_candidates = [] all_candidates = []
for idx in range(key_idx - 1, 0, -1): for idx in range(key_idx - 1, 0, -1):
k = key_list[idx] k = key_list[idx]
if abs(k - key) <= self.max_err:
if abs(k - y) <= self.max_err:
all_candidates += potential_candidates[k] all_candidates += potential_candidates[k]
else: else:
break break
for idx in range(key_idx, len(key_list)): for idx in range(key_idx, len(key_list)):
k = key_list[idx] k = key_list[idx]
if abs(k - key) <= self.max_err:
if abs(k - y) <= self.max_err:
all_candidates += potential_candidates[k] all_candidates += potential_candidates[k]
else: else:
break break
return all_candidates return all_candidates
def _abduce_by_GKB(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision):
if self.base == {}: if self.base == {}:
return [] return []
if len(pred_res) not in self.len_list: if len(pred_res) not in self.len_list:
return [] return []
all_candidates = self._find_candidate_GKB(pred_res, key)
all_candidates = self._find_candidate_GKB(pred_res, y)
if len(all_candidates) == 0: if len(all_candidates) == 0:
return [] return []
else: else:
cost_list = hamming_dist(pred_res, all_candidates) cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list) min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
address_num = min(max_revision_num, min_address_num + require_more_revision)
idxs = np.where(cost_list <= address_num)[0] idxs = np.where(cost_list <= address_num)[0]
candidates = [all_candidates[idx] for idx in idxs] candidates = [all_candidates[idx] for idx in idxs]
return candidates return candidates


def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
candidates = [] candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx)) abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
for c in abduce_c: for c in abduce_c:
candidate = pred_res.copy() candidate = pred_res.copy()
for i, idx in enumerate(address_idx): for i, idx in enumerate(address_idx):
candidate[idx] = c[i] candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), key, self.max_err):
if check_equal(self.logic_forward(candidate), y, self.max_err):
candidates.append(candidate) candidates.append(candidate)
return candidates return candidates


def _address(self, address_num, pred_res, key):
def _address(self, address_num, pred_res, y):
new_candidates = [] new_candidates = []
address_idx_list = combinations(list(range(len(pred_res))), address_num) address_idx_list = combinations(list(range(len(pred_res))), address_num)


for address_idx in address_idx_list: for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
new_candidates += candidates new_candidates += candidates
return new_candidates return new_candidates


def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
candidates = [] candidates = []
for address_num in range(len(pred_res) + 1): for address_num in range(len(pred_res) + 1):
if address_num == 0: if address_num == 0:
if check_equal(self.logic_forward(pred_res), key, self.max_err):
if check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res) candidates.append(pred_res)
else: else:
new_candidates = self._address(address_num, pred_res, key)
new_candidates = self._address(address_num, pred_res, y)
candidates += new_candidates candidates += new_candidates
if len(candidates) > 0: if len(candidates) > 0:
min_address_num = address_num min_address_num = address_num
break break
if address_num >= max_address_num:
if address_num >= max_revision_num:
return [] return []


for address_num in range(min_address_num + 1, min_address_num + require_more_address + 1):
if address_num > max_address_num:
for address_num in range(min_address_num + 1, min_address_num + require_more_revision + 1):
if address_num > max_revision_num:
return candidates return candidates
new_candidates = self._address(address_num, pred_res, key)
new_candidates = self._address(address_num, pred_res, y)
candidates += new_candidates candidates += new_candidates
return candidates return candidates
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, key, max_address_num, require_more_address):
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision):
pred_res = hashable_to_list(pred_res) pred_res = hashable_to_list(pred_res)
key = hashable_to_list(key)
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address)
y = hashable_to_list(y)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
def _dict_len(self, dic): def _dict_len(self, dic):
if not self.GKB_flag: if not self.GKB_flag:
@@ -217,16 +204,16 @@ class prolog_KB(KBBase):
regex = r"'P\d+'" regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res)) return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
def get_query_string(self, pred_res, key, address_idx):
def get_query_string(self, pred_res, y, address_idx):
query_string = "logic_forward(" query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx) query_string += self._address_pred_res(pred_res, address_idx)
key_is_none_flag = key is None or (type(key) == list and key[0] is None)
query_string += ",%s)." % key if not key_is_none_flag else ")."
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string return query_string
def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
candidates = [] candidates = []
query_string = self.get_query_string(pred_res, key, address_idx)
query_string = self.get_query_string(pred_res, y, address_idx)
save_pred_res = pred_res save_pred_res = pred_res
pred_res = flatten(pred_res) pred_res = flatten(pred_res)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)] abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]


+ 109
- 109
abl/reasoning/reasoner.py View File

@@ -4,7 +4,7 @@ from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist


class AbducerBase(abc.ABC):
class ReasonerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False): def __init__(self, kb, dist_func='hamming', zoopt=False):
self.kb = kb self.kb = kb
assert dist_func == 'hamming' or dist_func == 'confidence' assert dist_func == 'hamming' or dist_func == 'confidence'
@@ -66,15 +66,15 @@ class AbducerBase(abc.ABC):
candidate = candidates[np.argmin(cost_list)] candidate = candidates[np.argmin(cost_list)]
return candidate return candidate
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, key):
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, y):
address_idx = np.where(sol_x != 0)[0] address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
if len(candidates) > 0: if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else: else:
return len(pred_res) return len(pred_res)
def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
""" """
Get the address score for a single solution. Get the address score for a single solution.


@@ -86,8 +86,8 @@ class AbducerBase(abc.ABC):
List of predicted results. List of predicted results.
pred_res_prob : list pred_res_prob : list
List of probabilities for predicted results. List of probabilities for predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.


Returns Returns
------- -------
@@ -95,7 +95,7 @@ class AbducerBase(abc.ABC):
The address score for the given solution. The address score for the given solution.
""" """
address_idx = np.where(sol.get_x() != 0)[0] address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
if len(candidates) > 0: if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates)) return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else: else:
@@ -105,7 +105,7 @@ class AbducerBase(abc.ABC):
x = solution.get_x() x = solution.get_x()
return max_address_num - x.sum() return max_address_num - x.sum()


def zoopt_get_solution(self, pred_res, pred_res_prob, key, max_address_num):
def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_address_num):
"""Get the optimal solution using the Zoopt library. """Get the optimal solution using the Zoopt library.


Parameters Parameters
@@ -114,8 +114,8 @@ class AbducerBase(abc.ABC):
List of predicted results. List of predicted results.
pred_res_prob : list pred_res_prob : list
List of probabilities for predicted results. List of probabilities for predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.
max_address_num : int or float max_address_num : int or float
Maximum number of addresses to use. If float, represents the fraction of total addresses to use. Maximum number of addresses to use. If float, represents the fraction of total addresses to use.


@@ -127,7 +127,7 @@ class AbducerBase(abc.ABC):
length = len(flatten(pred_res)) length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
objective = Objective( objective = Objective(
lambda sol: self.zoopt_address_score(pred_res, pred_res_prob, key, sol),
lambda sol: self.zoopt_address_score(pred_res, pred_res_prob, y, sol),
dim=dimension, dim=dimension,
constraint=lambda sol: self._constrain_address_num(sol, max_address_num), constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
) )
@@ -135,15 +135,15 @@ class AbducerBase(abc.ABC):
solution = Opt.min(objective, parameter).get_x() solution = Opt.min(objective, parameter).get_x()
return solution return solution
def address_by_idx(self, pred_res, key, address_idx):
def address_by_idx(self, pred_res, y, address_idx):
"""Get the addresses corresponding to the given indices. """Get the addresses corresponding to the given indices.


Parameters Parameters
---------- ----------
pred_res : list pred_res : list
List of predicted results. List of predicted results.
key : str
Key for the predicted results.
y : str
Ground truth for the predicted results.
address_idx : array-like address_idx : array-like
Indices of the addresses to retrieve. Indices of the addresses to retrieve.


@@ -152,19 +152,19 @@ class AbducerBase(abc.ABC):
list list
The addresses corresponding to the given indices. The addresses corresponding to the given indices.
""" """
return self.kb.address_by_idx(pred_res, key, address_idx)
return self.kb.address_by_idx(pred_res, y, address_idx)


def abduce(self, data, max_address=-1, require_more_address=0):
def abduce(self, data, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data. """Perform abduction on the given data.


Parameters Parameters
---------- ----------
data : tuple data : tuple
Tuple containing the predicted results, predicted result probabilities, and key.
max_address : int or float, optional
Tuple containing the predicted results, predicted result probabilities, and y.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use. Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1. If -1, use all addresses. Defaults to -1.
require_more_address : int, optional
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0. Number of additional addresses to require. Defaults to 0.


Returns Returns
@@ -172,41 +172,41 @@ class AbducerBase(abc.ABC):
list list
The abduced addresses. The abduced addresses.
""" """
pred_res, pred_res_prob, key = data
pred_res, pred_res_prob, y = data
assert(type(max_address) in (int, float))
if max_address == -1:
assert(type(max_revision) in (int, float))
if max_revision == -1:
max_address_num = len(flatten(pred_res)) max_address_num = len(flatten(pred_res))
elif type(max_address) == float:
assert(max_address >= 0 and max_address <= 1)
max_address_num = round(len(flatten(pred_res)) * max_address)
elif type(max_revision) == float:
assert(max_revision >= 0 and max_revision <= 1)
max_address_num = round(len(flatten(pred_res)) * max_revision)
else: else:
assert(max_address >= 0)
max_address_num = max_address
assert(max_revision >= 0)
max_address_num = max_revision


if self.zoopt: if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, key, max_address_num)
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_address_num)
address_idx = np.where(solution != 0)[0] address_idx = np.where(solution != 0)[0]
candidates = self.address_by_idx(pred_res, key, address_idx)
candidates = self.address_by_idx(pred_res, y, address_idx)
else: else:
candidates = self.kb.abduce_candidates(pred_res, key, max_address_num, require_more_address)
candidates = self.kb.abduce_candidates(pred_res, y, max_address_num, require_more_revision)


candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates) candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
return candidate return candidate


def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0):
def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data in batches. """Perform abduction on the given data in batches.


Parameters Parameters
---------- ----------
Z : list Z : list
List of predicted results.
List of predicted results and result probablities.
Y : list Y : list
List of predicted result probabilities.
max_address : int or float, optional
List of ground truths.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use. Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1. If -1, use all addresses. Defaults to -1.
require_more_address : int, optional
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0. Number of additional addresses to require. Defaults to 0.


Returns Returns
@@ -214,19 +214,19 @@ class AbducerBase(abc.ABC):
list list
The abduced addresses. The abduced addresses.
""" """
return [self.abduce((z, prob, y), max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
# def _batch_abduce_helper(self, args): # def _batch_abduce_helper(self, args):
# z, prob, y, max_address, require_more_address = args
# return self.abduce((z, prob, y), max_address, require_more_address)
# z, prob, y, max_revision, require_more_revision = args
# return self.abduce((z, prob, y), max_revision, require_more_revision)


# def batch_abduce(self, Z, Y, max_address=-1, require_more_address=0):
# def batch_abduce(self, Z, Y, max_revision=-1, require_more_revision=0):
# with Pool(processes=os.cpu_count()) as pool: # with Pool(processes=os.cpu_count()) as pool:
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_address, require_more_address) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results # return results
def __call__(self, Z, Y, max_address=-1, require_more_address=0):
return self.batch_abduce(Z, Y, max_address, require_more_address)
def __call__(self, Z, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(Z, Y, max_revision, require_more_revision)


if __name__ == '__main__': if __name__ == '__main__':
@@ -245,76 +245,76 @@ if __name__ == '__main__':
print('add_KB with GKB:') print('add_KB with GKB:')
kb = add_KB(GKB_flag=True) kb = add_KB(GKB_flag=True)
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res) print(res)
print() print()
print('add_KB without GKB:') print('add_KB without GKB:')
kb = add_KB() kb = add_KB()
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res) print(res)
print() print()
print('add_KB without GKB:, no cache') print('add_KB without GKB:, no cache')
kb = add_KB(use_cache=False) kb = add_KB(use_cache=False)
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res) print(res)
print() print()
print('prolog_KB with add.pl:') print('prolog_KB with add.pl:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res) print(res)
print() print()


print('prolog_KB with add.pl using zoopt:') print('prolog_KB with add.pl using zoopt:')
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/datasets/mnist_add/add.pl')
abd = AbducerBase(kb, 'confidence', zoopt=True)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_address=2, require_more_address=0)
kb = prolog_KB(pseudo_label_list=list(range(10)), pl_file='../examples/mnist_add/datasets/add.pl')
abd = ReasonerBase(kb, 'confidence', zoopt=True)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob2}, [8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_address=1, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [17], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_address=2, require_more_address=0)
res = abd.batch_abduce({'cls':[[1, 1]], 'prob':prob1}, [20], max_revision=2, require_more_revision=0)
print(res) print(res)
print() print()
@@ -323,10 +323,10 @@ if __name__ == '__main__':
[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]]
kb = add_KB() kb = add_KB()
abd = AbducerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'confidence')
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_address=2, require_more_address=1)
res = abd.batch_abduce({'cls':[[1, 1], [1, 2]], 'prob':multiple_prob}, [4, 8], max_revision=2, require_more_revision=1)
print(res) print(res)
print() print()
@@ -361,62 +361,62 @@ if __name__ == '__main__':
print('HWF_KB with GKB, max_err=0.1') print('HWF_KB with GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1) kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res) print(res)
print() print()
print('HWF_KB without GKB, max_err=0.1') print('HWF_KB without GKB, max_err=0.1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_address=2, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [3], max_revision=2, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res) print(res)
print() print()
print('HWF_KB with GKB, max_err=1') print('HWF_KB with GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1) kb = HWF_KB(len_list=[1, 3, 5], GKB_flag=True, max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res) print(res)
print() print()
print('HWF_KB without GKB, max_err=1') print('HWF_KB without GKB, max_err=1')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 1) kb = HWF_KB(len_list=[1, 3, 5], max_err = 1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_address=3, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '9']], 'prob':[None]}, [65], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2']], 'prob':[None]}, [1.67], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_address=5, require_more_address=3)
res = abd.batch_abduce({'cls':[['5', '8', '8', '8', '8']], 'prob':[None]}, [3.17], max_revision=5, require_more_revision=3)
print(res) print(res)
print() print()
print('HWF_KB with multiple inputs at once:') print('HWF_KB with multiple inputs at once:')
kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1) kb = HWF_KB(len_list=[1, 3, 5], max_err = 0.1)
abd = AbducerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=1, require_more_address=0)
abd = ReasonerBase(kb, 'hamming')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=1, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=3, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_address=3, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 65], max_revision=3, require_more_revision=0)
print(res) print(res)
print() print()
print('max_address is float')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.5, require_more_address=0)
print('max_revision is float')
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.5, require_more_revision=0)
print(res) print(res)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_address=0.9, require_more_address=0)
res = abd.batch_abduce({'cls':[['5', '+', '2'], ['5', '+', '9']], 'prob':[None, None]}, [3, 64], max_revision=0.9, require_more_revision=0)
print(res) print(res)
print() print()
@@ -436,23 +436,23 @@ if __name__ == '__main__':
rules = [rule.value for rule in prolog_rules] rules = [rule.value for rule in prolog_rules]
return rules return rules
class HED_Abducer(AbducerBase):
class HED_Reasoner(ReasonerBase):
def __init__(self, kb, dist_func='hamming'): def __init__(self, kb, dist_func='hamming'):
super().__init__(kb, dist_func, zoopt=True) super().__init__(kb, dist_func, zoopt=True)
def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
def _address_by_idxs(self, pred_res, y, all_address_flag, idxs):
pred = [] pred = []
k = [] k = []
address_flag = [] address_flag = []
for idx in idxs: for idx in idxs:
pred.append(pred_res[idx]) pred.append(pred_res[idx])
k.append(key[idx])
k.append(y[idx])
address_flag += list(all_address_flag[idx]) address_flag += list(all_address_flag[idx])
address_idx = np.where(np.array(address_flag) != 0)[0] address_idx = np.where(np.array(address_flag) != 0)[0]
candidate = self.address_by_idx(pred, k, address_idx) candidate = self.address_by_idx(pred, k, address_idx)
return candidate return candidate
def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
all_address_flag = reform_idx(sol.get_x(), pred_res) all_address_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))] lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = [] candidate_size = []
@@ -464,7 +464,7 @@ if __name__ == '__main__':
for idx in range(-1, len(pred_res)): for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0): if (not idx in idxs) and (idx >= 0):
idxs.append(idx) idxs.append(idx)
candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)
candidate = self._address_by_idxs(pred_res, y, all_address_flag, idxs)
if len(candidate) == 0: if len(candidate) == 0:
if len(idxs) > 1: if len(idxs) > 1:
idxs.pop() idxs.pop()
@@ -486,8 +486,8 @@ if __name__ == '__main__':
def abduce_rules(self, pred_res): def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res) return self.kb.abduce_rules(pred_res)


kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abd = HED_Abducer(kb)
kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/hed/datasets/learn_add.pl')
abd = HED_Reasoner(kb)
consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]] consist_exs = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0]]
inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]] inconsist_exs1 = [[1, 1, '+', 0, '=', 1, 1], [1, '+', 1, '=', 1, 0], [0, '+', 0, '=', 0], [0, '+', 0, '=', 1]]
inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]] inconsist_exs2 = [[1, '+', 0, '=', 0], [1, '=', 1, '=', 0], [0, '=', 0, '=', 1, 1]]
@@ -502,7 +502,7 @@ if __name__ == '__main__':
print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules)) print(kb.consist_rule([1, '+', 1, '=', 1, 1], rules))
print() print()


print('HED_Abducer abduce')
print('HED_Reasoner abduce')
res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs))) res = abd.abduce((consist_exs, [[[None]]] * len(consist_exs), [None] * len(consist_exs)))
print(res) print(res)
res = abd.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1))) res = abd.abduce((inconsist_exs1, [[[None]]] * len(inconsist_exs1), [None] * len(inconsist_exs1)))
@@ -511,6 +511,6 @@ if __name__ == '__main__':
print(res) print(res)
print() print()


print('HED_Abducer abduce rules')
print('HED_Reasoner abduce rules')
abduced_rules = abd.abduce_rules(consist_exs) abduced_rules = abd.abduce_rules(consist_exs)
print(abduced_rules) print(abduced_rules)

Loading…
Cancel
Save