|
|
|
@@ -1,8 +1,10 @@ |
|
|
|
import inspect |
|
|
|
from typing import Callable, Any, List, Optional |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from zoopt import Dimension, Objective, Opt, Parameter |
|
|
|
from typing import Callable, Any, List, Optional |
|
|
|
|
|
|
|
from kb import KBBase |
|
|
|
from ..reasoning import KBBase |
|
|
|
from ..structures import ListData |
|
|
|
from ..utils.utils import confidence_dist, hamming_dist |
|
|
|
|
|
|
|
@@ -16,8 +18,18 @@ class Reasoner: |
|
|
|
kb : class KBBase |
|
|
|
The knowledge base to be used for reasoning. |
|
|
|
dist_func : str or Callable, optional |
|
|
|
The distance function to be used when determining the cost list between each |
|
|
|
candidate and the given prediction. Defaults to "confidence". |
|
|
|
The distance function used to determine the cost list between each |
|
|
|
candidate and the given prediction. It can be either a string representing a |
|
|
|
predefined distance function or a callable function. The available predefined |
|
|
|
distance functions: 'hamming' | 'confidence'. 'hamming': directly calculates |
|
|
|
the Hamming distance between the predicted pseudo label in the data sample |
|
|
|
and each candidate, 'confidence': calculates the distance between the prediction |
|
|
|
and each candidate based on confidence derived from the predicted probability |
|
|
|
in the data sample. The callable function should have the signature |
|
|
|
dist_func(data_sample, candidates) and must return a cost list. Each element |
|
|
|
in this cost list should be a numerical value representing the cost for each |
|
|
|
candidate, and the list should have the same length as candidates. |
|
|
|
Defaults to 'confidence'. |
|
|
|
mapping : Optional[dict], optional |
|
|
|
A mapping from index in the base model to label. If not provided, a default |
|
|
|
order-based mapping is created. Defaults to None. |
|
|
|
@@ -43,6 +55,7 @@ class Reasoner: |
|
|
|
use_zoopt: bool = False, |
|
|
|
): |
|
|
|
self.kb = kb |
|
|
|
self._check_valid_dist(dist_func) |
|
|
|
self.dist_func = dist_func |
|
|
|
self.use_zoopt = use_zoopt |
|
|
|
self.max_revision = max_revision |
|
|
|
@@ -55,18 +68,48 @@ class Reasoner: |
|
|
|
self.mapping = mapping |
|
|
|
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) |
|
|
|
|
|
|
|
def _check_valid_dist(self, dist_func): |
|
|
|
if isinstance(dist_func, str): |
|
|
|
if dist_func not in ["hamming", "confidence"]: |
|
|
|
raise NotImplementedError( |
|
|
|
f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.' |
|
|
|
) |
|
|
|
return |
|
|
|
elif callable(dist_func): |
|
|
|
params = inspect.signature(dist_func).parameters.values() |
|
|
|
if len(params) != 2: |
|
|
|
raise ValueError(f"User-defined dist_func must have exactly two parameters, but got {len(params)}.") |
|
|
|
return |
|
|
|
else: |
|
|
|
raise TypeError( |
|
|
|
f"dist_func must be a string or a callable function, but got {type(dist_func)}." |
|
|
|
) |
|
|
|
|
|
|
|
def _check_valid_dist_output(self, cost_list, candidate_num): |
|
|
|
if not isinstance(cost_list, np.ndarray): |
|
|
|
raise TypeError(f"Expected dist_func to return a numpy.ndarray, but got {type(cost_list)}.") |
|
|
|
if not cost_list.dtype.kind in "biufc": |
|
|
|
raise ValueError(f"Expected dist_func to return a numpy.ndarray with a numerical type, but got dtype {cost_list.dtype}.") |
|
|
|
if len(cost_list) != candidate_num: |
|
|
|
raise ValueError( |
|
|
|
f"The length of the array returned by dist_func must be equal to the number of candidates. " |
|
|
|
f"Expected length {candidate_num}, but got {len(cost_list)}." |
|
|
|
) |
|
|
|
|
|
|
|
def _check_valid_mapping(self, mapping): |
|
|
|
if not isinstance(mapping, dict): |
|
|
|
raise TypeError(f"mapping should be dict, got {type(mapping)}") |
|
|
|
raise TypeError(f"mapping should be dict, but got {type(mapping)}.") |
|
|
|
for key, value in mapping.items(): |
|
|
|
if not isinstance(key, int): |
|
|
|
raise ValueError(f"All keys in the mapping must be integers, got {key}") |
|
|
|
raise ValueError(f"All keys in the mapping must be integers, but got {key}.") |
|
|
|
if value not in self.kb.pseudo_label_list: |
|
|
|
raise ValueError(f"All values in the mapping must be in the pseudo_label_list, got {value}") |
|
|
|
|
|
|
|
raise ValueError( |
|
|
|
f"All values in the mapping must be in the pseudo_label_list, but got {value}." |
|
|
|
) |
|
|
|
|
|
|
|
def _get_one_candidate( |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
candidates: List[List[Any]], |
|
|
|
) -> List[Any]: |
|
|
|
""" |
|
|
|
@@ -91,25 +134,17 @@ class Reasoner: |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_array = self.get_cost_list(data_sample, candidates) |
|
|
|
cost_array = self._get_cost_list(data_sample, candidates) |
|
|
|
candidate = candidates[np.argmin(cost_array)] |
|
|
|
return candidate |
|
|
|
|
|
|
|
def get_cost_list( |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
def _get_cost_list( |
|
|
|
self, |
|
|
|
data_sample: ListData, |
|
|
|
candidates: List[List[Any]], |
|
|
|
) -> np.ndarray: |
|
|
|
""" |
|
|
|
Get the list of costs between each candidate and the given data sample. |
|
|
|
|
|
|
|
The list is |
|
|
|
calculated based on one of the following distance functions: |
|
|
|
- "hamming": Directly calculates the Hamming distance between the predicted pseudo |
|
|
|
label in the data sample and candidate. |
|
|
|
- "confidence": Calculates the distance between the prediction and candidate based |
|
|
|
on confidence derived from the predicted probability in the data |
|
|
|
sample. |
|
|
|
Get the list of costs between each candidate and the given data sample. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
@@ -117,7 +152,7 @@ class Reasoner: |
|
|
|
Data sample. |
|
|
|
candidates : List[List[Any]] |
|
|
|
Multiple compatible candidates. |
|
|
|
|
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
np.ndarray |
|
|
|
@@ -129,18 +164,16 @@ class Reasoner: |
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(data_sample.pred_prob, candidates) |
|
|
|
|
|
|
|
elif callable(self.dist_func): |
|
|
|
return self.dist_func(data_sample, candidates) |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError("dist_func must be either a string or a callable function") |
|
|
|
|
|
|
|
elif callable(self.dist_func): |
|
|
|
cost_list = self.dist_func(data_sample, candidates) |
|
|
|
self._check_valid_dist_output(cost_list, len(candidates)) |
|
|
|
return cost_list |
|
|
|
|
|
|
|
def _zoopt_get_solution( |
|
|
|
self, |
|
|
|
symbol_num: int, |
|
|
|
data_sample: ListData, |
|
|
|
self, |
|
|
|
symbol_num: int, |
|
|
|
data_sample: ListData, |
|
|
|
max_revision_num: int, |
|
|
|
) -> List[bool]: |
|
|
|
""" |
|
|
|
@@ -155,7 +188,7 @@ class Reasoner: |
|
|
|
Data sample. |
|
|
|
max_revision_num : int |
|
|
|
Specifies the maximum number of revisions allowed. |
|
|
|
|
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
List[bool] |
|
|
|
@@ -172,15 +205,15 @@ class Reasoner: |
|
|
|
return solution |
|
|
|
|
|
|
|
def zoopt_revision_score( |
|
|
|
self, |
|
|
|
symbol_num: int, |
|
|
|
data_sample: ListData, |
|
|
|
self, |
|
|
|
symbol_num: int, |
|
|
|
data_sample: ListData, |
|
|
|
sol: List[bool], |
|
|
|
) -> int: |
|
|
|
""" |
|
|
|
Get the revision score for a solution. A lower score suggests that ZOOpt library |
|
|
|
has a higher preference for this solution. |
|
|
|
|
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
symbol_num : int |
|
|
|
@@ -189,7 +222,7 @@ class Reasoner: |
|
|
|
Data sample. |
|
|
|
sol: List[bool] |
|
|
|
The solution for ZOOpt library. |
|
|
|
|
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
int |
|
|
|
@@ -200,7 +233,7 @@ class Reasoner: |
|
|
|
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx |
|
|
|
) |
|
|
|
if len(candidates) > 0: |
|
|
|
return np.min(self.get_cost_list(data_sample, candidates)) |
|
|
|
return np.min(self._get_cost_list(data_sample, candidates)) |
|
|
|
else: |
|
|
|
return symbol_num |
|
|
|
|
|
|
|
@@ -217,17 +250,21 @@ class Reasoner: |
|
|
|
Get the maximum revision number according to input `max_revision`. |
|
|
|
""" |
|
|
|
if not isinstance(max_revision, (int, float)): |
|
|
|
raise TypeError(f"Parameter must be of type int or float, got {type(max_revision)}") |
|
|
|
raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}") |
|
|
|
|
|
|
|
if max_revision == -1: |
|
|
|
return symbol_num |
|
|
|
elif isinstance(max_revision, float): |
|
|
|
if not (0 <= max_revision <= 1): |
|
|
|
raise ValueError(f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}") |
|
|
|
raise ValueError( |
|
|
|
f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}" |
|
|
|
) |
|
|
|
return round(symbol_num * max_revision) |
|
|
|
else: |
|
|
|
if max_revision < 0: |
|
|
|
raise ValueError(f"If max_revision is an int, it must be non-negative, but got {max_revision}") |
|
|
|
raise ValueError( |
|
|
|
f"If max_revision is an int, it must be non-negative, but got {max_revision}" |
|
|
|
) |
|
|
|
return max_revision |
|
|
|
|
|
|
|
def abduce(self, data_sample: ListData) -> List[Any]: |
|
|
|
@@ -256,11 +293,11 @@ class Reasoner: |
|
|
|
) |
|
|
|
else: |
|
|
|
candidates = self.kb.abduce_candidates( |
|
|
|
pseudo_label = data_sample.pred_pseudo_label, |
|
|
|
y = data_sample.Y, |
|
|
|
x = data_sample.X, |
|
|
|
max_revision_num = max_revision_num, |
|
|
|
require_more_revision = self.require_more_revision, |
|
|
|
data_sample.pred_pseudo_label, |
|
|
|
data_sample.Y, |
|
|
|
data_sample.X, |
|
|
|
max_revision_num, |
|
|
|
self.require_more_revision, |
|
|
|
) |
|
|
|
|
|
|
|
candidate = self._get_one_candidate(data_sample, candidates) |
|
|
|
|