From 71a59777b9759dcbee63cc4de491fddfdd74e717 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 15 Oct 2023 20:41:20 +0800 Subject: [PATCH] [FIX] fix bug in confidence dist calculation --- abl/utils/utils.py | 67 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index c9d00ce..9d1dc7a 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -46,9 +46,9 @@ def reform_idx(flattened_list, structured_list): list A reformed list that mimics the structure of structured_list. """ - if not isinstance(flattened_list, list): - raise TypeError("Input must be of type list.") - + # if not isinstance(flattened_list, list): + # raise TypeError("Input must be of type list.") + if not isinstance(structured_list[0], (list, tuple)): return flattened_list @@ -107,8 +107,8 @@ def confidence_dist(pred_prob, candidates): Confidence distances computed for each candidate. """ pred_prob = np.clip(pred_prob, 1e-9, 1) - rows, cols = np.indices((len(candidates), len(candidates[0]))) - return 1 - np.prod(pred_prob[rows, cols, candidates], axis=1) + _, cols = np.indices((len(candidates), len(candidates[0]))) + return 1 - np.prod(pred_prob[cols, candidates], axis=1) def block_sample(X, Z, Y, sample_num, seg_idx): @@ -248,3 +248,60 @@ def calculate_revision_num(parameter, total_length): if parameter < 0: raise ValueError("If parameter is an int, it must be non-negative.") return parameter + + +if __name__ == "__main__": + A = np.array( + [ + [ + 0.18401675, + 0.06797526, + 0.06797541, + 0.06801736, + 0.06797528, + 0.06797526, + 0.06818808, + 0.06797527, + 0.06800033, + 0.06797526, + 0.06797526, + 0.06797526, + 0.06797526, + ], + [ + 0.07223161, + 0.0685229, + 0.06852708, + 0.17227574, + 0.06852163, + 0.07018146, + 0.06860291, + 0.06852849, + 0.06852163, + 0.0685216, + 0.0685216, + 0.06852174, + 0.0685216, + ], + [ + 0.06794382, + 0.0679436, + 0.06794395, + 0.06794346, + 0.06794346, + 0.18467231, + 0.06794345, + 0.06794871, + 0.06794345, + 0.06794345, + 0.06794345, + 0.06794345, + 0.06794345, + ], + ], + dtype=np.float32, + ) + B = [[0, 9, 3], [0, 11, 4]] + + print(ori_confidence_dist(A, B)) + print(confidence_dist(A, B))