Browse Source

[MNT] resolve several comments

pull/1/head
troyyyyy 2 years ago
parent
commit
f1b964df58
4 changed files with 151 additions and 56 deletions
  1. +2
    -5
      abl/reasoning/kb.py
  2. +85
    -48
      abl/reasoning/reasoner.py
  3. +10
    -0
      tests/conftest.py
  4. +54
    -3
      tests/test_reasoning.py

+ 2
- 5
abl/reasoning/kb.py View File

@@ -243,10 +243,7 @@ class KBBase(ABC):
"""
candidates = []
for revision_num in range(len(pseudo_label) + 1):
if revision_num == 0 and self._check_equal(self.logic_forward(pseudo_label, *(x,) if self._num_args == 2 else ()), y):
candidates.append(pseudo_label)
elif revision_num > 0:
candidates.extend(self._revision(revision_num, pseudo_label, y, x))
candidates.extend(self._revision(revision_num, pseudo_label, y, x))
if len(candidates) > 0:
min_revision_num = revision_num
break
@@ -559,7 +556,7 @@ class PrologKB(KBBase):
knowledge base.
"""
candidates = []
query_string = self.get_query_string(pseudo_label, y, revision_idx)
query_string = self.get_query_string(pseudo_label, y, x, revision_idx)
save_pseudo_label = pseudo_label
pseudo_label = flatten(pseudo_label)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]


+ 85
- 48
abl/reasoning/reasoner.py View File

@@ -1,8 +1,10 @@
import inspect
from typing import Callable, Any, List, Optional

import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter
from typing import Callable, Any, List, Optional

from kb import KBBase
from ..reasoning import KBBase
from ..structures import ListData
from ..utils.utils import confidence_dist, hamming_dist

@@ -16,8 +18,18 @@ class Reasoner:
kb : class KBBase
The knowledge base to be used for reasoning.
dist_func : str or Callable, optional
The distance function to be used when determining the cost list between each
candidate and the given prediction. Defaults to "confidence".
The distance function used to determine the cost list between each
candidate and the given prediction. It can be either a string representing a
predefined distance function or a callable function. The available predefined
distance functions: 'hamming' | 'confidence'. 'hamming': directly calculates
the Hamming distance between the predicted pseudo label in the data sample
and each candidate, 'confidence': calculates the distance between the prediction
and each candidate based on confidence derived from the predicted probability
in the data sample. The callable function should have the signature
dist_func(data_sample, candidates) and must return a cost list. Each element
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
mapping : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default
order-based mapping is created. Defaults to None.
@@ -43,6 +55,7 @@ class Reasoner:
use_zoopt: bool = False,
):
self.kb = kb
self._check_valid_dist(dist_func)
self.dist_func = dist_func
self.use_zoopt = use_zoopt
self.max_revision = max_revision
@@ -55,18 +68,48 @@ class Reasoner:
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _check_valid_dist(self, dist_func):
if isinstance(dist_func, str):
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(
f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.'
)
return
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 2:
raise ValueError(f"User-defined dist_func must have exactly two parameters, but got {len(params)}.")
return
else:
raise TypeError(
f"dist_func must be a string or a callable function, but got {type(dist_func)}."
)

def _check_valid_dist_output(self, cost_list, candidate_num):
if not isinstance(cost_list, np.ndarray):
raise TypeError(f"Expected dist_func to return a numpy.ndarray, but got {type(cost_list)}.")
if not cost_list.dtype.kind in "biufc":
raise ValueError(f"Expected dist_func to return a numpy.ndarray with a numerical type, but got dtype {cost_list.dtype}.")
if len(cost_list) != candidate_num:
raise ValueError(
f"The length of the array returned by dist_func must be equal to the number of candidates. "
f"Expected length {candidate_num}, but got {len(cost_list)}."
)

def _check_valid_mapping(self, mapping):
if not isinstance(mapping, dict):
raise TypeError(f"mapping should be dict, got {type(mapping)}")
raise TypeError(f"mapping should be dict, but got {type(mapping)}.")
for key, value in mapping.items():
if not isinstance(key, int):
raise ValueError(f"All keys in the mapping must be integers, got {key}")
raise ValueError(f"All keys in the mapping must be integers, but got {key}.")
if value not in self.kb.pseudo_label_list:
raise ValueError(f"All values in the mapping must be in the pseudo_label_list, got {value}")
raise ValueError(
f"All values in the mapping must be in the pseudo_label_list, but got {value}."
)

def _get_one_candidate(
self,
data_sample: ListData,
self,
data_sample: ListData,
candidates: List[List[Any]],
) -> List[Any]:
"""
@@ -91,25 +134,17 @@ class Reasoner:
elif len(candidates) == 1:
return candidates[0]
else:
cost_array = self.get_cost_list(data_sample, candidates)
cost_array = self._get_cost_list(data_sample, candidates)
candidate = candidates[np.argmin(cost_array)]
return candidate

def get_cost_list(
self,
data_sample: ListData,
def _get_cost_list(
self,
data_sample: ListData,
candidates: List[List[Any]],
) -> np.ndarray:
"""
Get the list of costs between each candidate and the given data sample.
The list is
calculated based on one of the following distance functions:
- "hamming": Directly calculates the Hamming distance between the predicted pseudo
label in the data sample and candidate.
- "confidence": Calculates the distance between the prediction and candidate based
on confidence derived from the predicted probability in the data
sample.
Get the list of costs between each candidate and the given data sample.

Parameters
----------
@@ -117,7 +152,7 @@ class Reasoner:
Data sample.
candidates : List[List[Any]]
Multiple compatible candidates.
Returns
-------
np.ndarray
@@ -129,18 +164,16 @@ class Reasoner:
elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
elif callable(self.dist_func):
return self.dist_func(data_sample, candidates)

else:
raise ValueError("dist_func must be either a string or a callable function")

elif callable(self.dist_func):
cost_list = self.dist_func(data_sample, candidates)
self._check_valid_dist_output(cost_list, len(candidates))
return cost_list

def _zoopt_get_solution(
self,
symbol_num: int,
data_sample: ListData,
self,
symbol_num: int,
data_sample: ListData,
max_revision_num: int,
) -> List[bool]:
"""
@@ -155,7 +188,7 @@ class Reasoner:
Data sample.
max_revision_num : int
Specifies the maximum number of revisions allowed.
Returns
-------
List[bool]
@@ -172,15 +205,15 @@ class Reasoner:
return solution

def zoopt_revision_score(
self,
symbol_num: int,
data_sample: ListData,
self,
symbol_num: int,
data_sample: ListData,
sol: List[bool],
) -> int:
"""
Get the revision score for a solution. A lower score suggests that ZOOpt library
has a higher preference for this solution.
Parameters
----------
symbol_num : int
@@ -189,7 +222,7 @@ class Reasoner:
Data sample.
sol: List[bool]
The solution for ZOOpt library.
Returns
-------
int
@@ -200,7 +233,7 @@ class Reasoner:
data_sample.pred_pseudo_label, data_sample.Y, data_sample.X, revision_idx
)
if len(candidates) > 0:
return np.min(self.get_cost_list(data_sample, candidates))
return np.min(self._get_cost_list(data_sample, candidates))
else:
return symbol_num

@@ -217,17 +250,21 @@ class Reasoner:
Get the maximum revision number according to input `max_revision`.
"""
if not isinstance(max_revision, (int, float)):
raise TypeError(f"Parameter must be of type int or float, got {type(max_revision)}")
raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")

if max_revision == -1:
return symbol_num
elif isinstance(max_revision, float):
if not (0 <= max_revision <= 1):
raise ValueError(f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}")
raise ValueError(
f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}"
)
return round(symbol_num * max_revision)
else:
if max_revision < 0:
raise ValueError(f"If max_revision is an int, it must be non-negative, but got {max_revision}")
raise ValueError(
f"If max_revision is an int, it must be non-negative, but got {max_revision}"
)
return max_revision

def abduce(self, data_sample: ListData) -> List[Any]:
@@ -256,11 +293,11 @@ class Reasoner:
)
else:
candidates = self.kb.abduce_candidates(
pseudo_label = data_sample.pred_pseudo_label,
y = data_sample.Y,
x = data_sample.X,
max_revision_num = max_revision_num,
require_more_revision = self.require_more_revision,
data_sample.pred_pseudo_label,
data_sample.Y,
data_sample.X,
max_revision_num,
self.require_more_revision,
)

candidate = self._get_one_candidate(data_sample, candidates)


+ 10
- 0
tests/conftest.py View File

@@ -82,6 +82,7 @@ def data_samples_add():
]

data_samples_add = ListData()
data_samples_add.X = None
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
@@ -91,6 +92,7 @@ def data_samples_add():
@pytest.fixture
def data_samples_hwf():
data_samples_hwf = ListData()
data_samples_hwf.X = None
data_samples_hwf.pred_pseudo_label = [
["5", "+", "2"],
["5", "+", "9"],
@@ -200,6 +202,14 @@ def kb_add_prolog():
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/datasets/add.pl")
return kb

@pytest.fixture
def kb_hwf1():
return HwfKB(max_err=0.1)

@pytest.fixture
def kb_hwf2():
return HwfKB(max_err=1)


@pytest.fixture
def kb_hed():


+ 54
- 3
tests/test_reasoning.py View File

@@ -1,4 +1,5 @@
import pytest
import numpy as np

from abl.reasoning import PrologKB, Reasoner

@@ -93,15 +94,65 @@ class TestReaonser(object):
def test_reasoner_init(self, reasoner_instance):
assert reasoner_instance.dist_func == "confidence"

def test_invalid_dist_funce(kb_add):
class TestDistFunc(object):
def test_invalid_predefined_dist_func(self, kb_add):
with pytest.raises(NotImplementedError) as excinfo:
Reasoner(kb_add, "invalid_dist_func")
assert 'Valid options for dist_func include "hamming" and "confidence"' in str(
assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str(
excinfo.value
)
def random_dist(self, data_sample, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list
def test_user_defined_dist_func(self, kb_add):
reasoner = Reasoner(kb_add, self.random_dist)
assert reasoner.dist_func == self.random_dist
def invalid_dist1(self, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list
def invalid_dist2(self, data_sample, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
return np.append(cost_list, np.random.rand())
def invalid_dist3(self, data_sample, candidates):
cost_list = [np.random.rand() for _ in candidates]
return cost_list
def invalid_dist4(self, data_sample, candidates):
cost_list = np.array(["invalid" for _ in candidates])
return cost_list
def test_invalid_user_defined_dist_func(self, kb_add, data_samples_add):
with pytest.raises(ValueError) as excinfo:
Reasoner(kb_add, self.invalid_dist1)
assert 'User-defined dist_func must have exactly two parameters' in str(
excinfo.value
)
with pytest.raises(ValueError) as excinfo:
reasoner = Reasoner(kb_add, self.invalid_dist2)
reasoner.batch_abduce(data_samples_add)
assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str(
excinfo.value
)
with pytest.raises(TypeError) as excinfo:
reasoner = Reasoner(kb_add, self.invalid_dist3)
reasoner.batch_abduce(data_samples_add)
assert 'Expected dist_func to return a numpy.ndarray' in str(
excinfo.value
)
with pytest.raises(ValueError) as excinfo:
reasoner = Reasoner(kb_add, self.invalid_dist4)
reasoner.batch_abduce(data_samples_add)
assert 'Expected dist_func to return a numpy.ndarray with a numerical type' in str(
excinfo.value
)


class test_batch_abduce(object):
class TestBatchAbduce(object):
def test_batch_abduce_add(self, kb_add, data_samples_add):
reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)


Loading…
Cancel
Save