Browse Source

[MNT] add docstring for class ground_KB

pull/3/head
troyyyyy 2 years ago
parent
commit
ac675633f8
2 changed files with 85 additions and 57 deletions
  1. +78
    -50
      abl/reasoning/kb.py
  2. +7
    -7
      abl/reasoning/reasoner.py

+ 78
- 50
abl/reasoning/kb.py View File

@@ -5,7 +5,7 @@ import numpy as np
from collections import defaultdict
from itertools import product, combinations

from abl.utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -14,9 +14,9 @@ import pyswip

class KBBase(ABC):
"""
Base class for reasoner.
Base class for knowledge base.

Attributes
Parameters
----------
pseudo_label_list : list
List of possible pseudo labels.
@@ -30,10 +30,11 @@ class KBBase(ABC):
Notes
-----
Users creating there own KB should inherit from this base class. For the inherited
subclass, it's mandatory to provide `pseudo_label_list` and override the `logic_forward`
function. After that, other operations (e.g. how to perform abductive reasoning)
will be automatically set up.
Users should inherit from this base class to build their own knowledge base. For the
user-build KB (an inherited subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
"""
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):
if not isinstance(pseudo_label_list, list):
@@ -44,6 +45,9 @@ class KBBase(ABC):

@abstractmethod
def logic_forward(self, pseudo_labels):
"""
How to perform logical reasoning. Users are required to provide this.
"""
pass

def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
@@ -55,7 +59,7 @@ class KBBase(ABC):
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth.
Ground truth for the result (after passing through the logic part).
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int, optional
@@ -85,7 +89,7 @@ class KBBase(ABC):
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth.
Ground truth for the result (after passing through the logic part).
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
@@ -122,8 +126,8 @@ class KBBase(ABC):
----------
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth.
y : Any
Ground truth for the result (after passing through the logic part).
max_revision_num : int
The upper limit on the number of revisions.
require_more_revision : int
@@ -165,30 +169,58 @@ class KBBase(ABC):
pred_pseudo_label = hashable_to_list(pred_pseudo_label)
y = hashable_to_list(y)
return self._abduce_by_search(pred_pseudo_label, y, max_revision_num, require_more_revision)
class ground_KB(KBBase):
def __init__(self, pseudo_label_list, GKB_len_list=None, max_err=0):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt
upon class initialization, stroing all potential candidates along with
their respective results after passing through the logic part. Ground KB can
enhance the speed of abductive reasoning. For more on this, refer to the
`abduce_candidates` method in this class.

Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
GKB_len_list : list
List of possible lengths of pseudo label.
max_err : float, optional
Refer to class `KBBase`.
Notes
-----
Users can also inherit from this class to build their own knowledge base.
Similar to `KBBase`, users are only required to provide the `pseudo_label_list`
and override the `logic_forward` function. Additionally, users should provide
the `GKB_len_list`. After that, other operations (e.g. auto-construction of
GKB, and how to perform abductive reasoning) will be automatically set up.
"""
def __init__(self, pseudo_label_list, GKB_len_list, max_err=0):
super().__init__(pseudo_label_list, max_err)
if not isinstance(GKB_len_list, list):
raise TypeError("GKB_len_list should be list")
self.GKB_len_list = GKB_len_list
self.base = {}
self.GKB = {}
X, Y = self._get_GKB()
for x, y in zip(X, Y):
self.base.setdefault(len(x), defaultdict(list))[y].append(x)
self.GKB.setdefault(len(x), defaultdict(list))[y].append(x)
# For parallel version of _get_GKB
def _get_XY_list(self, args):
pre_x, post_x_it = args[0], args[1]
XY_list = []
for post_x in post_x_it:
x = (pre_x,) + post_x
y = self.logic_forward(x)
if y is not None:
if y is not np.inf:
XY_list.append((x, y))
return XY_list

# Parallel _get_GKB
def _get_GKB(self):
"""
Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`.
"""
X, Y = [], []
for length in self.GKB_len_list:
arg_list = []
@@ -208,13 +240,37 @@ class ground_KB(KBBase):
return X, Y
def abduce_candidates(self, pred_pseudo_label, y, max_revision_num, require_more_revision=0):
return self._abduce_by_GKB(pred_pseudo_label, y, max_revision_num, require_more_revision)
"""
Perform abductive reasoning by directly retrieving consistent candidates from
the prebuilt GKB. In this way, the time-consuming exhaustive search can be
avoided.
This is an overridden function. For more information about the parameters and
returns, refer to the function of the same name in class `KBBase`.
"""
if self.GKB == {} or len(pred_pseudo_label) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y)
if len(all_candidates) == 0:
return []

cost_list = hamming_dist(pred_pseudo_label, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates
def _find_candidate_GKB(self, pred_pseudo_label, y):
"""
Retrieve consistent candidates from the prebuilt GKB. If `max_err` is greater
than 0, return all candidates whose logical results fall within the
[y - max_err, y + max_err] range.
"""
if self.max_err == 0:
return self.base[len(pred_pseudo_label)][y]
return self.GKB[len(pred_pseudo_label)][y]
else:
potential_candidates = self.base[len(pred_pseudo_label)]
potential_candidates = self.GKB[len(pred_pseudo_label)]
key_list = list(potential_candidates.keys())
key_idx = bisect.bisect_left(key_list, y)
@@ -233,35 +289,7 @@ class ground_KB(KBBase):
else:
break
return all_candidates
def _abduce_by_GKB(self, pred_pseudo_label, y, max_revision_num, require_more_revision):
if self.base == {} or len(pred_pseudo_label) not in self.GKB_len_list:
return []
all_candidates = self._find_candidate_GKB(pred_pseudo_label, y)
if len(all_candidates) == 0:
return []

cost_list = hamming_dist(pred_pseudo_label, all_candidates)
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def _dict_len(self, dic):
if not self.GKB_flag:
return 0
else:
return sum(len(c) for c in dic.values())

def __len__(self):
if not self.GKB_flag:
return 0
else:
return sum(self._dict_len(v) for v in self.base.values())

class prolog_KB(KBBase):
def __init__(self, pseudo_label_list, pl_file, max_err=0):


+ 7
- 7
abl/reasoning/reasoner.py View File

@@ -1,6 +1,6 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from abl.utils.utils import (
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
@@ -13,7 +13,7 @@ class ReasonerBase:
"""
Base class for reasoner.

Attributes
Parameters
----------
kb :
The knowledge base to be used for reasoning.
@@ -115,7 +115,7 @@ class ReasonerBase:
Predicted probabilities of the prediction (Each sublist contains the probability
distribution over all pseudo labels).
y : Any
Ground truth.
Ground truth for the result (after passing through the logic part).
max_revision_num : int
Specifies the maximum number of revisions allowed.
"""
@@ -162,7 +162,7 @@ class ReasonerBase:
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : Any
Ground truth.
Ground truth for the result (after passing through the logic part).
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
@@ -181,8 +181,8 @@ class ReasonerBase:
distribution over all pseudo labels).
pred_pseudo_label : List[Any]
Predicted pseudo label.
y : any
Ground truth.
y : Any
Ground truth for the result (after passing through the logic part).
max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number
@@ -456,7 +456,7 @@ if __name__ == "__main__":
print()

print("HWF_KB with GKB, max_err=0.1")
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5], max_err=0.1)
kb = HWF_ground_KB(GKB_len_list=[1, 3, 5, 7], max_err=0.1)
reasoner = ReasonerBase(kb, "hamming")
test_hwf(reasoner)



Loading…
Cancel
Save