Browse Source

[MNT] add docstring for class KBBase

pull/3/head
troyyyyy 2 years ago
parent
commit
233fc9738d
2 changed files with 156 additions and 65 deletions
  1. +147
    -57
      abl/reasoning/kb.py
  2. +9
    -8
      abl/reasoning/reasoner.py

+ 147
- 57
abl/reasoning/kb.py View File

@@ -5,7 +5,7 @@ import numpy as np
from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from abl.utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -13,11 +13,31 @@ from functools import lru_cache
import pyswip

class KBBase(ABC):
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
# TODO:添加一下类型检查,比如
# if not isinstance(X, (np.ndarray, spmatrix)):
# raise TypeError("X should be numpy array or sparse matrix")
"""
Base class for reasoner.

Attributes
----------
pseudo_label_list : list
List of possible pseudo labels.
max_err : float, optional
The upper tolerance limit when comparing the similarity between a candidate result
and the ground truth. Especially relevant for regression problems where exact matches
might not be feasible. Default to 0.
use_cache : bool, optional
Whether to use a cache for previously abduced candidates to speed up subsequent
operations. Defaults to True.
Notes
-----
Users creating there own KB should inherit from this base class. For the inherited
subclass, it's mandatory to provide `pseudo_label_list` and override the `logic_forward`
function. After that, other operations (e.g. how to perform abductive reasoning)
will be automatically set up.
"""
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
if not isinstance(pseudo_label_list, list):
raise TypeError("pseudo_label_list should be list")
self.pseudo_label_list = pseudo_label_list
self.max_err = max_err
self.use_cache = use_cache
@@ -26,39 +46,105 @@ class KBBase(ABC):
def logic_forward(self, pseudo_labels):
pass

def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
"""
Perform abductive reasoning to get a candidate consistent with the knowledge base.

Parameters
----------
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth.
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
knowledge base.
"""
if not self.use_cache:
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
return self._abduce_by_search(pred_pseudo_label, y,
max_revision_num, require_more_revision)
else:
return self._abduce_by_search_cache(to_hashable(pred_res), to_hashable(y), max_revision_num, require_more_revision)
return self._abduce_by_search_cache(to_hashable(pred_pseudo_label),
to_hashable(y),
max_revision_num, require_more_revision)
def revise_by_idx(self, pred_res, y, revision_idx):
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
"""
Revise the predicted pseudo label at specified index positions.

Parameters
----------
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth.
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
for c in abduce_c:
candidate = pred_res.copy()
candidate = pred_pseudo_label.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), y, self.max_err):
candidates.append(candidate)
return candidates

def _revision(self, revision_num, pred_res, y):
def _revision(self, revision_num, pred_pseudo_label, y):
"""
For a specified number of pseudo label to revise, iterate through all possible
indices to find any candidates that are consistent with the knowledge base.
"""
new_candidates = []
revision_idx_list = combinations(range(len(pred_res)), revision_num)
revision_idx_list = combinations(range(len(pred_pseudo_label)), revision_num)

for revision_idx in revision_idx_list:
candidates = self.revise_by_idx(pred_res, y, revision_idx)
candidates = self.revise_at_idx(pred_pseudo_label, y, revision_idx)
new_candidates.extend(candidates)
return new_candidates

def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
def _abduce_by_search(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
Perform abductive reasoning by exhastive search. Specifically, begin with 0 and
continuously increase the number of pseudo labels to revise, until candidates
that are consistent with the knowledge base are found.
Parameters
----------
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth.
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
If larger than 0, then after having found any candidates consistent with the
knowledge base, continue to increase the number pseudo labels to revise to
get more possible consistent candidates.

Returns
-------
List[List[Any]]
A list of candidates, i.e. revised pseudo label that are consistent with the
knowledge base.
"""
candidates = []
for revision_num in range(len(pred_res) + 1):
if revision_num == 0 and check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
for revision_num in range(len(pred_pseudo_label) + 1):
if revision_num == 0 and check_equal(self.logic_forward(pred_pseudo_label),
y,
self.max_err):
candidates.append(pred_pseudo_label)
elif revision_num > 0:
candidates.extend(self._revision(revision_num, pred_res, y))
candidates.extend(self._revision(revision_num, pred_pseudo_label, y))
if len(candidates) > 0:
min_revision_num = revision_num
break
@@ -68,26 +154,17 @@ class KBBase(ABC):
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
if revision_num > max_revision_num:
return candidates
candidates.extend(self._revision(revision_num, pred_res, y))
candidates.extend(self._revision(revision_num, pred_pseudo_label, y))
return candidates
@lru_cache(maxsize=None)
def _abduce_by_search_cache(self, pred_res, y, max_revision_num, require_more_revision):
pred_res = hashable_to_list(pred_res)
def _abduce_by_search_cache(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
"""
`_abduce_by_search` with cache.
"""
pred_pseudo_label = hashable_to_list(pred_pseudo_label)
y = hashable_to_list(y)
return self._abduce_by_search(pred_res, y, max_revision_num, require_more_revision)
def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)
class ground_KB(KBBase):
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0):
@@ -130,14 +207,14 @@ class ground_KB(KBBase):
X, Y = zip(*sorted(zip(X, Y), key=lambda pair: pair[1]))
return X, Y
def abduce_candidates(self, pred_res, y, max_revision_num, require_more_revision=0):
return self._abduce_by_GKB(pred_res, y, max_revision_num, require_more_revision)
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
return self._abduce_by_GKB(pred_pseudo_label, y, max_revision_num, require_more_revision)
def _find_candidate_GKB(self, pred_res, y):
def _find_candidate_GKB(self, pred_pseudo_label, y):
if self.max_err == 0:
return self.base[len(pred_res)][y]
return self.base[len(pred_pseudo_label)][y]
else:
potential_candidates = self.base[len(pred_res)]
potential_candidates = self.base[len(pred_pseudo_label)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y)
@@ -157,21 +234,34 @@ class ground_KB(KBBase):
break
return all_candidates
def _abduce_by_GKB(self, pred_res, y, max_revision_num, require_more_revision):
if self.base == {} or len(pred_res) not in self.GKB_len_list:
def _abduce_by_GKB(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
if self.base == {} or len(pred_pseudo_label) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_res, y)
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y)
if len(all_candidates) == 0:
return []

cost_list = hamming_dist(pred_res, all_candidates)
cost_list = hamming_dist(pred_pseudo_label, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())

class prolog_KB(KBBase):
def __init__(self, pseudo_label_list, pl_file, max_err=0):
@@ -187,36 +277,36 @@ class prolog_KB(KBBase):
return False
return result
def _revision_pred_res(self, pred_res, revision_idx):
def _revision_pred_pseudo_label(self, pred_pseudo_label, revision_idx):
import re
revision_pred_res = pred_res.copy()
revision_pred_res = flatten(revision_pred_res)
revision_pred_pseudo_label = pred_pseudo_label.copy()
revision_pred_pseudo_label = flatten(revision_pred_pseudo_label)
for idx in revision_idx:
revision_pred_res[idx] = 'P' + str(idx)
revision_pred_res = reform_idx(revision_pred_res, pred_res)
revision_pred_pseudo_label[idx] = 'P' + str(idx)
revision_pred_pseudo_label = reform_idx(revision_pred_pseudo_label, pred_pseudo_label)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res))
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_pseudo_label))
def get_query_string(self, pred_res, y, revision_idx):
def get_query_string(self, pred_pseudo_label, y, revision_idx):
query_string = "logic_forward("
query_string += self._revision_pred_res(pred_res, revision_idx)
query_string += self._revision_pred_pseudo_label(pred_pseudo_label, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string
def revise_by_idx(self, pred_res, y, revision_idx):
def revise_at_idx(self, pred_pseudo_label, y, revision_idx):
candidates = []
query_string = self.get_query_string(pred_res, y, revision_idx)
save_pred_res = pred_res
pred_res = flatten(pred_res)
query_string = self.get_query_string(pred_pseudo_label, y, revision_idx)
save_pred_pseudo_label = pred_pseudo_label
pred_pseudo_label = flatten(pred_pseudo_label)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
for c in abduce_c:
candidate = pred_res.copy()
candidate = pred_pseudo_label.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_idx(candidate, save_pred_res)
candidate = reform_idx(candidate, save_pred_pseudo_label)
candidates.append(candidate)
return candidates

+ 9
- 8
abl/reasoning/reasoner.py View File

@@ -1,6 +1,6 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
from abl.utils.utils import (
confidence_dist,
flatten,
reform_idx,
@@ -23,7 +23,7 @@ class ReasonerBase:
| `"confidence"`. Any other options will raise a `NotImplementedError`. For
detailed explanations of these options, refer to `_get_cost_list`.
mapping : dict, optional
A mapping from label to index. If not provided, a default order-based mapping is
A mapping from index to label. If not provided, a default order-based mapping is
created.
use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning. Default to False.
@@ -44,6 +44,7 @@ class ReasonerBase:
if not isinstance(mapping, dict):
raise TypeError("mapping should be dict")
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
"""
@@ -57,7 +58,7 @@ class ReasonerBase:
Predicted pseudo label to be used for selecting a candidate.
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels).
distribution over all pseudo labels).
candidates : List[List[Any]]
Multiple candidate abduction results.
"""
@@ -85,7 +86,7 @@ class ReasonerBase:
Predicted pseudo label.
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels). Used when distance function is "confidence".
distribution over all pseudo labels). Used when distance function is "confidence".
candidates : List[List[Any]]
Multiple candidate abduction results.
"""
@@ -93,7 +94,7 @@ class ReasonerBase:
return hamming_dist(pred_pseudo_label, candidates)

elif self.dist_func == "confidence":
candidates = [[self.mapping[x] for x in c] for c in candidates]
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_prob, candidates)


@@ -112,7 +113,7 @@ class ReasonerBase:
Predicted pseudo label.
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels).
distribution over all pseudo labels).
y : Any
Ground truth.
max_revision_num : int
@@ -177,7 +178,7 @@ class ReasonerBase:
----------
pred_prob : List[List[Any]]
Predicted probabilities of the prediction (Each sublist contains the probability
values of all pseudo labels).
distribution over all pseudo labels).
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
@@ -193,7 +194,7 @@ class ReasonerBase:
Returns
-------
List[Any]
The revised pseudo label through abductive reasoning, which is consistent with the
A revised pseudo label through abductive reasoning, which is consistent with the
knowledge base.
"""
symbol_num = len(flatten(pred_pseudo_label))


Loading…
Cancel
Save