Browse Source

[MNT] reformat code and add doc string to utils.py

pull/3/head
Gao Enhao 2 years ago
parent
commit
736d11c03e
2 changed files with 236 additions and 74 deletions
  1. +2
    -2
      abl/reasoning/reasoner.py
  2. +234
    -72
      abl/utils/utils.py

+ 2
- 2
abl/reasoning/reasoner.py View File

@@ -5,7 +5,7 @@ from ..utils.utils import (
flatten,
reform_idx,
hamming_dist,
float_parameter,
calculate_revision_num,
)


@@ -214,7 +214,7 @@ class ReasonerBase:
The abduced revisions.
"""
symbol_num = len(flatten(pred_pseudo_label))
max_revision_num = float_parameter(max_revision, symbol_num)
max_revision_num = calculate_revision_num(max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(


+ 234
- 72
abl/utils/utils.py View File

@@ -1,85 +1,247 @@
import numpy as np
from itertools import chain

def flatten(l):
if not isinstance(l[0], (list, tuple)):
return l
return list(chain.from_iterable(l))
def reform_idx(flatten_pred_res, save_pred_res):
if not isinstance(save_pred_res[0], (list, tuple)):
return flatten_pred_res
re = []
i = 0
for e in save_pred_res:
re.append(flatten_pred_res[i:i + len(e)])
i += len(e)
return re


def hamming_dist(A, B):
A = np.array(A, dtype='<U')
B = np.array(B, dtype='<U')
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)

def confidence_dist(A, B):
B = np.array(B)
A = np.clip(A, 1e-9, 1)
A = np.expand_dims(A, axis=0)
A = A.repeat(axis=0, repeats=(len(B)))
rows = np.array(range(len(B)))
rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
cols = np.array(range(len(B[0])))
cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
return 1 - np.prod(A[rows, cols, B], axis=1)

def flatten(nested_list):
"""
Flattens a nested list.

Parameters
----------
nested_list : list
A list which might contain sublists or tuples.

Returns
-------
list
A flattened version of the input list.

Raises
------
TypeError
If the input object is not a list.
"""
if not isinstance(nested_list, list):
raise TypeError("Input must be of type list.")

if not nested_list or not isinstance(nested_list[0], (list, tuple)):
return nested_list

return list(chain.from_iterable(nested_list))


def reform_idx(flattened_pred, saved_pred):
"""
Reform the index based on saved_pred structure.

Parameters
----------
flattened_pred : list
A flattened list of predictions.
saved_pred : list
A list containing saved predictions, which could be nested lists or tuples.

Returns
-------
list
A reformed list that mimics the structure of saved_pred.
"""
if not isinstance(saved_pred[0], (list, tuple)):
return flattened_pred

reformed_pred = []
idx_start = 0
for elem in saved_pred:
idx_end = idx_start + len(elem)
reformed_pred.append(flattened_pred[idx_start:idx_end])
idx_start = idx_end

return reformed_pred


def hamming_dist(pred_pseudo_label, candidates):
"""
Compute the Hamming distance between two arrays.

Parameters
----------
pred_pseudo_label : list
First array to compare.
candidates : list
Second array to compare, expected to have shape (n, m)
where n is the number of rows, m is the length of pred_pseudo_label.

Returns
-------
numpy.ndarray
Hamming distances.
"""
pred_pseudo_label = np.array(pred_pseudo_label)
candidates = np.array(candidates)

# Ensuring that pred_pseudo_label is broadcastable to the shape of candidates
pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)

return np.sum(pred_pseudo_label != candidates, axis=1)


def confidence_dist(pred_prob, candidates):
"""
Compute the confidence distance between prediction probabilities and candidates.

Parameters
----------
pred_prob : list of numpy.ndarray
Prediction probability distributions, each element is an ndarray
representing the probability distribution of a particular prediction.
candidates : list of list of int
Index of candidate labels, each element is a list of indexes being considered
as a candidate correction.

Returns
-------
numpy.ndarray
Confidence distances computed for each candidate.
"""
pred_prob = np.clip(pred_prob, 1e-9, 1)
rows, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[rows, cols, candidates], axis=1)


def block_sample(X, Z, Y, sample_num, seg_idx):
X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]
return X, Z, Y
"""
Extract a block of samples from lists X, Z, and Y.

Parameters
----------
X, Z, Y : list
Input lists from which to extract the samples.
sample_num : int
The number of samples per block.
seg_idx : int
The block index to extract.

Returns
-------
tuple of lists
The extracted block samples from X, Z, and Y.

Example
-------
>>> X = [1, 2, 3, 4, 5, 6]
>>> Z = ['a', 'b', 'c', 'd', 'e', 'f']
>>> Y = [10, 11, 12, 13, 14, 15]
>>> block_sample(X, Z, Y, 2, 1)
([3, 4], ['c', 'd'], [12, 13])
"""
start_idx = sample_num * seg_idx
end_idx = sample_num * (seg_idx + 1)

return (data[start_idx:end_idx] for data in (X, Z, Y))


def check_equal(a, b, max_err=0):
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
return abs(a - b) <= max_err
if isinstance(a, list) and isinstance(b, list):
if len(a) != len(b):
return False
for i in range(len(a)):
if not check_equal(a[i], b[i]):
return False
return True
"""
Check whether two numbers a and b are equal within a maximum allowable error.

Parameters
----------
a, b : int or float
The numbers to compare.
max_err : int or float, optional
The maximum allowable absolute difference between a and b for them to be considered equal.
Default is 0, meaning the numbers must be exactly equal.

Returns
-------
bool
True if a and b are equal within the allowable error, False otherwise.

Raises
------
TypeError
If a or b are not of type int or float.
"""
if not (isinstance(a, (int, float)) and isinstance(b, (int, float))):
raise TypeError("Input values must be int or float.")

return abs(a - b) <= max_err


def to_hashable(x):
"""
Convert a nested list to a nested tuple so it is hashable.

Parameters
----------
x : list or other type
A potentially nested list to convert to a tuple.

Returns
-------
tuple or other type
The input converted to a tuple if it was a list,
otherwise the original input.
"""
if isinstance(x, list):
return tuple(to_hashable(item) for item in x)
return x


def hashable_to_list(x):
"""
Convert a nested tuple back to a nested list.

Parameters
----------
x : tuple or other type
A potentially nested tuple to convert to a list.

Returns
-------
list or other type
The input converted to a list if it was a tuple,
otherwise the original input.
"""
if isinstance(x, tuple):
return [hashable_to_list(item) for item in x]
return x


def calculate_revision_num(parameter, total_length):
"""
Convert a float parameter to an integer, based on a total length.

Parameters
----------
parameter : int or float
The parameter to convert. If float, it should be between 0 and 1.
If int, it should be non-negative. If -1, it will be replaced with total_length.
total_length : int
The total length to calculate the parameter from if it's a fraction.

Returns
-------
int
The calculated parameter.

Raises
------
TypeError
If parameter is not an int or a float.
ValueError
If parameter is a float not in [0, 1] or an int below 0.
"""
if not isinstance(parameter, (int, float)):
raise TypeError("Parameter must be of type int or float.")

else:
return a == b

def to_hashable(l):
if type(l) is not list:
return l
if type(l[0]) is not list:
return tuple(l)
return tuple(tuple(sublist) for sublist in l)

def hashable_to_list(t):
if type(t) is not tuple:
return t
if type(t[0]) is not tuple:
return list(t)
return [list(subtuple) for subtuple in t]


def float_parameter(parameter, total_length):
assert(type(parameter) in (int, float))
if parameter == -1:
return total_length
elif type(parameter) == float:
assert(parameter >= 0 and parameter <= 1)
elif isinstance(parameter, float):
if not (0 <= parameter <= 1):
raise ValueError("If parameter is a float, it must be between 0 and 1.")
return round(total_length * parameter)
else:
assert(parameter >= 0)
return parameter
if parameter < 0:
raise ValueError("If parameter is an int, it must be non-negative.")
return parameter

Loading…
Cancel
Save