From 0e6d829ed10cb327db083cf3621cf4c85e7fc360 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Tue, 7 Nov 2023 10:29:32 +0800 Subject: [PATCH] [MNT] remove unnecessary utils functions --- abl/reasoning/kb.py | 4 +-- abl/reasoning/reasoner.py | 23 ++++++++++++++--- abl/utils/utils.py | 53 +++------------------------------------ 3 files changed, 26 insertions(+), 54 deletions(-) diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index cc34440..84b11be 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -5,7 +5,7 @@ import numpy as np from collections import defaultdict from itertools import product, combinations -from ..utils.utils import flatten, reform_idx, hamming_dist, check_equal, to_hashable, hashable_to_list +from ..utils.utils import flatten, reform_idx, hamming_dist, to_hashable, hashable_to_list from multiprocessing import Pool @@ -100,7 +100,7 @@ class KBBase(ABC): candidate = pred_pseudo_label.copy() for i, idx in enumerate(revision_idx): candidate[idx] = c[i] - if check_equal(self.logic_forward(candidate), y, self.max_err): + if abs(self.logic_forward(candidate) - y) <= self.max_err: candidates.append(candidate) return candidates diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 65879c7..f8c8497 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,11 +1,10 @@ import numpy as np from zoopt import Dimension, Objective, Parameter, Opt -from abl.utils.utils import ( +from ..utils.utils import ( confidence_dist, flatten, reform_idx, hamming_dist, - calculate_revision_num, ) @@ -168,6 +167,24 @@ class ReasonerBase: """ return self.kb.revise_at_idx(pred_pseudo_label, y, revision_idx) + def _get_max_revision_num(max_revision, symbol_num): + """ + Get the maximum revision number according to input `max_revision`. + """ + if not isinstance(max_revision, (int, float)): + raise TypeError("Parameter must be of type int or float.") + + if max_revision == -1: + return symbol_num + elif isinstance(max_revision, float): + if not (0 <= max_revision <= 1): + raise ValueError("If max_revision is a float, it must be between 0 and 1.") + return round(symbol_num * max_revision) + else: + if max_revision < 0: + raise ValueError("If max_revision is an int, it must be non-negative.") + return max_revision + def abduce( self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0 ): @@ -198,7 +215,7 @@ class ReasonerBase: knowledge base. """ symbol_num = len(flatten(pred_pseudo_label)) - max_revision_num = calculate_revision_num(max_revision, symbol_num) + max_revision_num = self._get_max_revision_num(max_revision, symbol_num) if self.use_zoopt: solution = self.zoopt_get_solution( diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 9d1dc7a..8c1b4d4 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -15,11 +15,6 @@ def flatten(nested_list): ------- list A flattened version of the input list. - - Raises - ------ - TypeError - If the input object is not a list. """ if not isinstance(nested_list, list): raise TypeError("Input must be of type list.") @@ -46,9 +41,6 @@ def reform_idx(flattened_list, structured_list): list A reformed list that mimics the structure of structured_list. """ - # if not isinstance(flattened_list, list): - # raise TypeError("Input must be of type list.") - if not isinstance(structured_list[0], (list, tuple)): return flattened_list @@ -88,7 +80,7 @@ def hamming_dist(pred_pseudo_label, candidates): return np.sum(pred_pseudo_label != candidates, axis=1) -def confidence_dist(pred_prob, candidates): +def confidence_dist(pred_prob, candidates_idx): """ Compute the confidence distance between prediction probabilities and candidates. @@ -97,7 +89,7 @@ def confidence_dist(pred_prob, candidates): pred_prob : list of numpy.ndarray Prediction probability distributions, each element is an ndarray representing the probability distribution of a particular prediction. - candidates : list of list of int + candidates_idx : list of list of int Index of candidate labels, each element is a list of indexes being considered as a candidate correction. @@ -107,8 +99,8 @@ def confidence_dist(pred_prob, candidates): Confidence distances computed for each candidate. """ pred_prob = np.clip(pred_prob, 1e-9, 1) - _, cols = np.indices((len(candidates), len(candidates[0]))) - return 1 - np.prod(pred_prob[cols, candidates], axis=1) + _, cols = np.indices((len(candidates_idx), len(candidates_idx[0]))) + return 1 - np.prod(pred_prob[cols, candidates_idx], axis=1) def block_sample(X, Z, Y, sample_num, seg_idx): @@ -143,34 +135,6 @@ def block_sample(X, Z, Y, sample_num, seg_idx): return (data[start_idx:end_idx] for data in (X, Z, Y)) -def check_equal(a, b, max_err=0): - """ - Check whether two numbers a and b are equal within a maximum allowable error. - - Parameters - ---------- - a, b : int or float - The numbers to compare. - max_err : int or float, optional - The maximum allowable absolute difference between a and b for them to be considered equal. - Default is 0, meaning the numbers must be exactly equal. - - Returns - ------- - bool - True if a and b are equal within the allowable error, False otherwise. - - Raises - ------ - TypeError - If a or b are not of type int or float. - """ - if not (isinstance(a, (int, float)) and isinstance(b, (int, float))): - raise TypeError("Input values must be int or float.") - - return abs(a - b) <= max_err - - def to_hashable(x): """ Convert a nested list to a nested tuple so it is hashable. @@ -190,7 +154,6 @@ def to_hashable(x): return tuple(to_hashable(item) for item in x) return x - def hashable_to_list(x): """ Convert a nested tuple back to a nested list. @@ -227,13 +190,6 @@ def calculate_revision_num(parameter, total_length): ------- int The calculated parameter. - - Raises - ------ - TypeError - If parameter is not an int or a float. - ValueError - If parameter is a float not in [0, 1] or an int below 0. """ if not isinstance(parameter, (int, float)): raise TypeError("Parameter must be of type int or float.") @@ -303,5 +259,4 @@ if __name__ == "__main__": ) B = [[0, 9, 3], [0, 11, 4]] - print(ori_confidence_dist(A, B)) print(confidence_dist(A, B))