Browse Source

[FIX] add reasoner init parameters

pull/1/head
troyyyyy 2 years ago
parent
commit
cd5c577a50
2 changed files with 46 additions and 78 deletions
  1. +7
    -21
      abl/bridge/simple_bridge.py
  2. +39
    -57
      abl/reasoning/reasoner.py

+ 7
- 21
abl/bridge/simple_bridge.py View File

@@ -21,39 +21,25 @@ class SimpleBridge(BaseBridge):
super().__init__(model, reasoner) super().__init__(model, reasoner)
self.metric_list = metric_list self.metric_list = metric_list


# TODO: add reasoner.mapping to the property of SimpleBridge

def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]: def predict(self, data_samples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
self.model.predict(data_samples) self.model.predict(data_samples)
return data_samples.pred_idx, data_samples.pred_prob return data_samples.pred_idx, data_samples.pred_prob


def abduce_pseudo_label(
self,
data_samples: ListData,
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
self.reasoner.batch_abduce(data_samples, max_revision, require_more_revision)
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
self.reasoner.batch_abduce(data_samples)
return data_samples.abduced_pseudo_label return data_samples.abduced_pseudo_label


def idx_to_pseudo_label(
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.reasoner.mapping
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
pred_idx = data_samples.pred_idx pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [ data_samples.pred_pseudo_label = [
[mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
[self.reasoner.mapping[_idx] for _idx in sub_list]
for sub_list in pred_idx
] ]
return data_samples.pred_pseudo_label return data_samples.pred_pseudo_label


def pseudo_label_to_idx(
self, data_samples: ListData, mapping: Optional[Dict] = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.reasoner.remapping
def pseudo_label_to_idx(self, data_samples: ListData) -> List[List[Any]]:
abduced_idx = [ abduced_idx = [
[mapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label for sub_list in data_samples.abduced_pseudo_label
] ]
data_samples.abduced_idx = abduced_idx data_samples.abduced_idx = abduced_idx


+ 39
- 57
abl/reasoning/reasoner.py View File

@@ -24,17 +24,35 @@ class ReasonerBase:
mapping : dict, optional mapping : dict, optional
A mapping from index in the base model to label. If not provided, a default A mapping from index in the base model to label. If not provided, a default
order-based mapping is created. order-based mapping is created.
max_revision : int or float, optional
The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total
length that can be revised. A value of -1 implies no restriction on the
number of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required
when performing abductive reasoning. Defaults to 0.
use_zoopt : bool, optional use_zoopt : bool, optional
Whether to use the Zoopt library during abductive reasoning. Defaults to False. Whether to use the Zoopt library during abductive reasoning. Defaults to False.
""" """
def __init__(self, kb, dist_func="confidence", mapping=None, use_zoopt=False):
def __init__(self,
kb,
dist_func="confidence",
mapping=None,
max_revision=-1,
require_more_revision=0,
use_zoopt=False,
):
if dist_func not in ["hamming", "confidence"]: if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"") raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"")


self.kb = kb self.kb = kb
self.dist_func = dist_func self.dist_func = dist_func
self.use_zoopt = use_zoopt self.use_zoopt = use_zoopt
self.max_revision = max_revision
self.require_more_revision = require_more_revision
if mapping is None: if mapping is None:
self.mapping = { self.mapping = {
index: label for index, label in enumerate(self.kb.pseudo_label_list) index: label for index, label in enumerate(self.kb.pseudo_label_list)
@@ -117,9 +135,7 @@ class ReasonerBase:
size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num
) )
objective = Objective( objective = Objective(
lambda sol: self.zoopt_revision_score(
symbol_num, data_sample, sol
),
lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol),
dim=dimension, dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
) )
@@ -133,7 +149,9 @@ class ReasonerBase:
has a higher preference for this solution. has a higher preference for this solution.
""" """
revision_idx = np.where(sol.get_x() != 0)[0] revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
if len(candidates) > 0: if len(candidates) > 0:
return np.min(self._get_cost_list(data_sample, candidates)) return np.min(self._get_cost_list(data_sample, candidates))
else: else:
@@ -146,27 +164,6 @@ class ReasonerBase:
""" """
x = solution.get_x() x = solution.get_x()
return max_revision_num - x.sum() return max_revision_num - x.sum()
def revise_at_idx(self, data_sample, revision_idx):
"""
Revise the pseudo label in the data sample at specified index positions.

Parameters
----------
data_sample : ListData
Data sample.
revision_idx : array-like
Indices of where revisions should be made to the predicted pseudo label.
"""
return self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
def abduce_candidates(self, data_sample, max_revision_num, require_more_revision):
return self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
require_more_revision)


def _get_max_revision_num(self, max_revision, symbol_num): def _get_max_revision_num(self, max_revision, symbol_num):
""" """
@@ -186,9 +183,7 @@ class ReasonerBase:
raise ValueError("If max_revision is an int, it must be non-negative.") raise ValueError("If max_revision is an int, it must be non-negative.")
return max_revision return max_revision
def abduce(
self, data_sample, max_revision=-1, require_more_revision=0
):
def abduce(self, data_sample):
""" """
Perform abductive reasoning on the given data sample. Perform abductive reasoning on the given data sample.


@@ -196,14 +191,7 @@ class ReasonerBase:
---------- ----------
data_sample : ListData data_sample : ListData
Data sample. Data sample.
max_revision : int or float, optional
The upper limit on the number of revisions. If float, denotes the fraction of the
total length that can be revised. A value of -1 implies no restriction on the number
of revisions. Defaults to -1.
require_more_revision : int, optional
Specifies additional number of revisions permitted beyond the minimum required.
Defaults to 0.

Returns Returns
------- -------
List[Any] List[Any]
@@ -211,39 +199,33 @@ class ReasonerBase:
knowledge base. knowledge base.
""" """
symbol_num = data_sample.elements_num("pred_pseudo_label") symbol_num = data_sample.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
if self.use_zoopt: if self.use_zoopt:
solution = self.zoopt_get_solution(
symbol_num, data_sample, max_revision_num
)
solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)
revision_idx = np.where(solution != 0)[0] revision_idx = np.where(solution != 0)[0]
candidates = self.revise_at_idx(data_sample, revision_idx)
candidates = self.self.kb.revise_at_idx(data_sample.pred_pseudo_label,
data_sample.Y,
revision_idx)
else: else:
candidates = self.abduce_candidates(
data_sample, max_revision_num, require_more_revision
)

candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label,
data_sample.Y,
max_revision_num,
self.require_more_revision)
candidate = self._get_one_candidate(data_sample, candidates) candidate = self._get_one_candidate(data_sample, candidates)
return candidate return candidate


def batch_abduce(
self, data_samples, max_revision=-1, require_more_revision=0
):
def batch_abduce(self, data_samples):
""" """
Perform abductive reasoning on the given prediction data samples. Perform abductive reasoning on the given prediction data samples.
For detailed information, refer to `abduce`. For detailed information, refer to `abduce`.
""" """
abduced_pseudo_label = [ abduced_pseudo_label = [
self.abduce(data_sample, max_revision, require_more_revision)
for data_sample in data_samples
self.abduce(data_sample) for data_sample in data_samples
] ]
data_samples.abduced_pseudo_label = abduced_pseudo_label data_samples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label return abduced_pseudo_label


def __call__(
self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0
):
return self.batch_abduce(
pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision
)
def __call__(self, data_samples):
return self.batch_abduce(data_samples)

Loading…
Cancel
Save