Browse Source

[FIX] fix bug in confidence dist calculation

pull/3/head
Gao Enhao 2 years ago
parent
commit
71a59777b9
1 changed files with 62 additions and 5 deletions
  1. +62
    -5
      abl/utils/utils.py

+ 62
- 5
abl/utils/utils.py View File

@@ -46,9 +46,9 @@ def reform_idx(flattened_list, structured_list):
list
A reformed list that mimics the structure of structured_list.
"""
if not isinstance(flattened_list, list):
raise TypeError("Input must be of type list.")
# if not isinstance(flattened_list, list):
# raise TypeError("Input must be of type list.")
if not isinstance(structured_list[0], (list, tuple)):
return flattened_list

@@ -107,8 +107,8 @@ def confidence_dist(pred_prob, candidates):
Confidence distances computed for each candidate.
"""
pred_prob = np.clip(pred_prob, 1e-9, 1)
rows, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[rows, cols, candidates], axis=1)
_, cols = np.indices((len(candidates), len(candidates[0])))
return 1 - np.prod(pred_prob[cols, candidates], axis=1)


def block_sample(X, Z, Y, sample_num, seg_idx):
@@ -248,3 +248,60 @@ def calculate_revision_num(parameter, total_length):
if parameter < 0:
raise ValueError("If parameter is an int, it must be non-negative.")
return parameter


if __name__ == "__main__":
A = np.array(
[
[
0.18401675,
0.06797526,
0.06797541,
0.06801736,
0.06797528,
0.06797526,
0.06818808,
0.06797527,
0.06800033,
0.06797526,
0.06797526,
0.06797526,
0.06797526,
],
[
0.07223161,
0.0685229,
0.06852708,
0.17227574,
0.06852163,
0.07018146,
0.06860291,
0.06852849,
0.06852163,
0.0685216,
0.0685216,
0.06852174,
0.0685216,
],
[
0.06794382,
0.0679436,
0.06794395,
0.06794346,
0.06794346,
0.18467231,
0.06794345,
0.06794871,
0.06794345,
0.06794345,
0.06794345,
0.06794345,
0.06794345,
],
],
dtype=np.float32,
)
B = [[0, 9, 3], [0, 11, 4]]

print(ori_confidence_dist(A, B))
print(confidence_dist(A, B))

Loading…
Cancel
Save