Browse Source

[MNT] remove unnecessary utils functions

pull/3/head
troyyyyy 2 years ago
parent
commit
0e6d829ed1
3 changed files with 26 additions and 54 deletions
  1. +2
    -2
      abl/reasoning/kb.py
  2. +20
    -3
      abl/reasoning/reasoner.py
  3. +4
    -49
      abl/utils/utils.py

+ 2
- 2
abl/reasoning/kb.py View File

@@ -5,7 +5,7 @@ import numpy as np
from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list
from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list

from multiprocessing import Pool

@@ -100,7 +100,7 @@ class KBBase(ABC):
candidate = pred_pseudo_label.copy()
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), y, self.max_err):
if abs(self.logic_forward(candidate) - y) <= self.max_err:
candidates.append(candidate)
return candidates



+ 20
- 3
abl/reasoning/reasoner.py View File

@@ -1,11 +1,10 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from abl.utils.utils import (
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
hamming_dist,
calculate_revision_num,
)


@@ -168,6 +167,24 @@ class ReasonerBase:
"""
return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx)

def _get_max_revision_num(max_revision, symbol_num):
"""
Get the maximum revision number according to input `max_revision`.
"""
if not isinstance(max_revision, (int, float)):
raise TypeError("Parameter must be of type int or float.")

if max_revision == -1:
return symbol_num
elif isinstance(max_revision, float):
if not (0 <= max_revision <= 1):
raise ValueError("If max_revision is a float, it must be between 0 and 1.")
return round(symbol_num * max_revision)
else:
if max_revision < 0:
raise ValueError("If max_revision is an int, it must be non-negative.")
return max_revision
def abduce(
self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0
):
@@ -198,7 +215,7 @@ class ReasonerBase:
knowledge base.
"""
symbol_num = len(flatten(pred_pseudo_label))
max_revision_num = calculate_revision_num(max_revision, symbol_num)
max_revision_num = self._get_max_revision_num(max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(


+ 4
- 49
abl/utils/utils.py View File

@@ -15,11 +15,6 @@ def flatten(nested_list):
-------
list
A flattened version of the input list.

Raises
------
TypeError
If the input object is not a list.
"""
if not isinstance(nested_list, list):
raise TypeError("Input must be of type list.")
@@ -46,9 +41,6 @@ def reform_idx(flattened_list, structured_list):
list
A reformed list that mimics the structure of structured_list.
"""
# if not isinstance(flattened_list, list):
# raise TypeError("Input must be of type list.")

if not isinstance(structured_list[0], (list, tuple)):
return flattened_list

@@ -88,7 +80,7 @@ def hamming_dist(pred_pseudo_label, candidates):
return np.sum(pred_pseudo_label != candidates, axis=1)


def confidence_dist(pred_prob, candidates):
def confidence_dist(pred_prob, candidates_idx):
"""
Compute the confidence distance between prediction probabilities and candidates.

@@ -97,7 +89,7 @@ def confidence_dist(pred_prob, candidates):
pred_prob : list of numpy.ndarray
Prediction probability distributions, each element is an ndarray
representing the probability distribution of a particular prediction.
candidates : list of list of int
candidates_idx : list of list of int
Index of candidate labels, each element is a list of indexes being considered
as a candidate correction.

@@ -107,8 +99,8 @@ def confidence_dist(pred_prob, candidates):
Confidence distances computed for each candidate.
"""
pred_prob = np.clip(pred_prob, 1e-9, 1)
_, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[cols, candidates], axis=1)
_, cols = np.indices((len(candidates_idx), len(candidates_idx[0])))
return 1 - np.prod(pred_prob[cols, candidates_idx], axis=1)


def block_sample(X, Z, Y, sample_num, seg_idx):
@@ -143,34 +135,6 @@ def block_sample(X, Z, Y, sample_num, seg_idx):
return (data[start_idx:end_idx] for data in (X, Z, Y))


def check_equal(a, b, max_err=0):
"""
Check whether two numbers a and b are equal within a maximum allowable error.

Parameters
----------
a, b : int or float
The numbers to compare.
max_err : int or float, optional
The maximum allowable absolute difference between a and b for them to be considered equal.
Default is 0, meaning the numbers must be exactly equal.

Returns
-------
bool
True if a and b are equal within the allowable error, False otherwise.

Raises
------
TypeError
If a or b are not of type int or float.
"""
if not (isinstance(a, (int, float)) and isinstance(b, (int, float))):
raise TypeError("Input values must be int or float.")

return abs(a - b) <= max_err


def to_hashable(x):
"""
Convert a nested list to a nested tuple so it is hashable.
@@ -190,7 +154,6 @@ def to_hashable(x):
return tuple(to_hashable(item) for item in x)
return x


def hashable_to_list(x):
"""
Convert a nested tuple back to a nested list.
@@ -227,13 +190,6 @@ def calculate_revision_num(parameter, total_length):
-------
int
The calculated parameter.

Raises
------
TypeError
If parameter is not an int or a float.
ValueError
If parameter is a float not in [0, 1] or an int below 0.
"""
if not isinstance(parameter, (int, float)):
raise TypeError("Parameter must be of type int or float.")
@@ -303,5 +259,4 @@ if __name__ == "__main__":
)
B = [[0, 9, 3], [0, 11, 4]]

print(ori_confidence_dist(A, B))
print(confidence_dist(A, B))

Loading…
Cancel
Save