diff --git a/.gitignore b/.gitignore index cf8e7f0..931fddd 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ abl.egg-info/ examples/**/*.jpg .idea/ build/ -docs/API/generated/ \ No newline at end of file +docs/API/generated/ +.history \ No newline at end of file diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 0799605..082214d 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -118,7 +118,7 @@ class Reasoner: data_example : ListData Data example. candidates : List[List[Any]] - Multiple compatible candidates. + Multiple possible candidates. reasoning_results : List[Any] Corresponding reasoning results of the candidates. @@ -150,7 +150,7 @@ class Reasoner: data_example : ListData Data example. candidates : List[List[Any]] - Multiple compatible candidates. + Multiple possible candidates. reasoning_results : List[Any] Corresponding reasoning results of the candidates. @@ -162,8 +162,8 @@ class Reasoner: if self.dist_func == "hamming": return hamming_dist(data_example.pred_pseudo_label, candidates) elif self.dist_func == "confidence": - candidates = [[self.label_to_idx[x] for x in c] for c in candidates] - return confidence_dist(data_example.pred_prob, candidates) + candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] + return confidence_dist(data_example.pred_prob, candidates_idxs) else: candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 601c7a6..4fceec9 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -1,58 +1,54 @@ -from itertools import chain +from typing import List, Any, Union, Tuple import numpy as np -def flatten(nested_list): +def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[Any]: """ - Flattens a nested list. + Flattens a nested list at the first level. Parameters ---------- - nested_list : list - A list which might contain sublists or tuples. + nested_list : List[Union[Any, List[Any], Tuple[Any, ...]]] + A list which might contain sublists or tuples at the first level. Returns ------- - list - A flattened version of the input list. - - Raises - ------ - TypeError - If the input object is not a list. + List[Any] + A flattened version of the input list, where only the first + level of sublists and tuples are reduced. """ - # if not isinstance(nested_list, list): - # raise TypeError("Input must be of type list.") - - if isinstance(nested_list, list) and len(nested_list) == 0: + if not isinstance(nested_list, list): return nested_list - if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)): - return nested_list + flattened_list = [] + for item in nested_list: + if isinstance(item, (list, tuple)): + flattened_list.extend(item) + else: + flattened_list.append(item) - return list(chain.from_iterable(nested_list)) + return flattened_list - -def reform_list(flattened_list, structured_list): +def reform_list( + flattened_list: List[Any], + structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]] +) -> List[List[Any]]: """ - Reform the index based on structured_list structure. + Reform the list based on the structure of ``structured_list``. Parameters ---------- - flattened_list : list - A flattened list of predictions. - structured_list : list - A list containing saved predictions, which could be nested lists or tuples. + flattened_list : List[Any] + A flattened list of elements. + structured_list : List[Union[Any, List[Any], Tuple[Any, ...]]] + A list that reflects the desired structure, which may contain sublists or tuples. Returns ------- - list - A reformed list that mimics the structure of structured_list. + List[List[Any]] + A reformed list that mimics the structure of ``structured_list``. """ - # if not isinstance(flattened_list, list): - # raise TypeError("Input must be of type list.") - if not isinstance(structured_list[0], (list, tuple)): return flattened_list @@ -72,16 +68,15 @@ def hamming_dist(pred_pseudo_label, candidates): Parameters ---------- - pred_pseudo_label : list - First array to compare. - candidates : list - Second array to compare, expected to have shape (n, m) - where n is the number of rows, m is the length of pred_pseudo_label. + pred_pseudo_label : List[Any] + Pseudo-labels of an example. + candidates : List[List[Any]] + Multiple possible candidates. Returns ------- - numpy.ndarray - Hamming distances. + np.ndarray + Hamming distances computed for each candidate. """ pred_pseudo_label = np.array(pred_pseudo_label) candidates = np.array(candidates) @@ -92,27 +87,26 @@ def hamming_dist(pred_pseudo_label, candidates): return np.sum(pred_pseudo_label != candidates, axis=1) -def confidence_dist(pred_prob, candidates): +def confidence_dist(pred_prob, candidates_idxs): """ Compute the confidence distance between prediction probabilities and candidates. Parameters ---------- - pred_prob : list of numpy.ndarray + pred_prob : List[np.ndarray] Prediction probability distributions, each element is an ndarray representing the probability distribution of a particular prediction. - candidates : list of list of int - Index of candidate labels, each element is a list of indexes being considered - as a candidate correction. + candidates_idxs : List[List[Any]] + Multiple possible candidates' indices. Returns ------- - numpy.ndarray + np.ndarray Confidence distances computed for each candidate. """ pred_prob = np.clip(pred_prob, 1e-9, 1) - _, cols = np.indices((len(candidates), len(candidates[0]))) - return 1 - np.prod(pred_prob[cols, candidates], axis=1) + _, cols = np.indices((len(candidates_idxs), len(candidates_idxs[0]))) + return 1 - np.prod(pred_prob[cols, candidates_idxs], axis=1) def to_hashable(x): @@ -121,12 +115,12 @@ def to_hashable(x): Parameters ---------- - x : list or other type + x : Union[List[Any], Any] A potentially nested list to convert to a tuple. Returns ------- - tuple or other type + Union[Tuple[Any, ...], Any] The input converted to a tuple if it was a list, otherwise the original input. """ @@ -141,12 +135,12 @@ def restore_from_hashable(x): Parameters ---------- - x : tuple or other type + x : Union[Tuple[Any, ...], Any] A potentially nested tuple to convert to a list. Returns ------- - list or other type + Union[List[Any], Any] The input converted to a list if it was a tuple, otherwise the original input. """ @@ -156,14 +150,15 @@ def restore_from_hashable(x): def tab_data_to_tuple(X, y, reasoning_result = 0): ''' - Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. + Convert a tabular data to a tuple by adding a dimension to each element of + X and y. The tuple contains three elements: data, label, and reasoning result. If X is None, return None. Parameters ---------- - X : list or other type + X : Union[List[Any], Any] The data. - y : list or other type + y : Union[List[Any], Any] The label. reasoning_result : Any, optional The reasoning result, by default 0.