|
|
|
@@ -15,11 +15,6 @@ def flatten(nested_list): |
|
|
|
------- |
|
|
|
list |
|
|
|
A flattened version of the input list. |
|
|
|
|
|
|
|
Raises |
|
|
|
------ |
|
|
|
TypeError |
|
|
|
If the input object is not a list. |
|
|
|
""" |
|
|
|
if not isinstance(nested_list, list): |
|
|
|
raise TypeError("Input must be of type list.") |
|
|
|
@@ -46,9 +41,6 @@ def reform_idx(flattened_list, structured_list): |
|
|
|
list |
|
|
|
A reformed list that mimics the structure of structured_list. |
|
|
|
""" |
|
|
|
# if not isinstance(flattened_list, list): |
|
|
|
# raise TypeError("Input must be of type list.") |
|
|
|
|
|
|
|
if not isinstance(structured_list[0], (list, tuple)): |
|
|
|
return flattened_list |
|
|
|
|
|
|
|
@@ -88,7 +80,7 @@ def hamming_dist(pred_pseudo_label, candidates): |
|
|
|
return np.sum(pred_pseudo_label != candidates, axis=1) |
|
|
|
|
|
|
|
|
|
|
|
def confidence_dist(pred_prob, candidates): |
|
|
|
def confidence_dist(pred_prob, candidates_idx): |
|
|
|
""" |
|
|
|
Compute the confidence distance between prediction probabilities and candidates. |
|
|
|
|
|
|
|
@@ -97,7 +89,7 @@ def confidence_dist(pred_prob, candidates): |
|
|
|
pred_prob : list of numpy.ndarray |
|
|
|
Prediction probability distributions, each element is an ndarray |
|
|
|
representing the probability distribution of a particular prediction. |
|
|
|
candidates : list of list of int |
|
|
|
candidates_idx : list of list of int |
|
|
|
Index of candidate labels, each element is a list of indexes being considered |
|
|
|
as a candidate correction. |
|
|
|
|
|
|
|
@@ -107,8 +99,8 @@ def confidence_dist(pred_prob, candidates): |
|
|
|
Confidence distances computed for each candidate. |
|
|
|
""" |
|
|
|
pred_prob = np.clip(pred_prob, 1e-9, 1) |
|
|
|
_, cols = np.indices((len(candidates), len(candidates[0]))) |
|
|
|
return 1 - np.prod(pred_prob[cols, candidates], axis=1) |
|
|
|
_, cols = np.indices((len(candidates_idx), len(candidates_idx[0]))) |
|
|
|
return 1 - np.prod(pred_prob[cols, candidates_idx], axis=1) |
|
|
|
|
|
|
|
|
|
|
|
def block_sample(X, Z, Y, sample_num, seg_idx): |
|
|
|
@@ -143,34 +135,6 @@ def block_sample(X, Z, Y, sample_num, seg_idx): |
|
|
|
return (data[start_idx:end_idx] for data in (X, Z, Y)) |
|
|
|
|
|
|
|
|
|
|
|
def check_equal(a, b, max_err=0): |
|
|
|
""" |
|
|
|
Check whether two numbers a and b are equal within a maximum allowable error. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
a, b : int or float |
|
|
|
The numbers to compare. |
|
|
|
max_err : int or float, optional |
|
|
|
The maximum allowable absolute difference between a and b for them to be considered equal. |
|
|
|
Default is 0, meaning the numbers must be exactly equal. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
bool |
|
|
|
True if a and b are equal within the allowable error, False otherwise. |
|
|
|
|
|
|
|
Raises |
|
|
|
------ |
|
|
|
TypeError |
|
|
|
If a or b are not of type int or float. |
|
|
|
""" |
|
|
|
if not (isinstance(a, (int, float)) and isinstance(b, (int, float))): |
|
|
|
raise TypeError("Input values must be int or float.") |
|
|
|
|
|
|
|
return abs(a - b) <= max_err |
|
|
|
|
|
|
|
|
|
|
|
def to_hashable(x): |
|
|
|
""" |
|
|
|
Convert a nested list to a nested tuple so it is hashable. |
|
|
|
@@ -190,7 +154,6 @@ def to_hashable(x): |
|
|
|
return tuple(to_hashable(item) for item in x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
def hashable_to_list(x): |
|
|
|
""" |
|
|
|
Convert a nested tuple back to a nested list. |
|
|
|
@@ -227,13 +190,6 @@ def calculate_revision_num(parameter, total_length): |
|
|
|
------- |
|
|
|
int |
|
|
|
The calculated parameter. |
|
|
|
|
|
|
|
Raises |
|
|
|
------ |
|
|
|
TypeError |
|
|
|
If parameter is not an int or a float. |
|
|
|
ValueError |
|
|
|
If parameter is a float not in [0, 1] or an int below 0. |
|
|
|
""" |
|
|
|
if not isinstance(parameter, (int, float)): |
|
|
|
raise TypeError("Parameter must be of type int or float.") |
|
|
|
@@ -303,5 +259,4 @@ if __name__ == "__main__": |
|
|
|
) |
|
|
|
B = [[0, 9, 3], [0, 11, 4]] |
|
|
|
|
|
|
|
print(ori_confidence_dist(A, B)) |
|
|
|
print(confidence_dist(A, B)) |