| @@ -46,9 +46,9 @@ def reform_idx(flattened_list, structured_list): | |||||
| list | list | ||||
| A reformed list that mimics the structure of structured_list. | A reformed list that mimics the structure of structured_list. | ||||
| """ | """ | ||||
| if not isinstance(flattened_list, list): | |||||
| raise TypeError("Input must be of type list.") | |||||
| # if not isinstance(flattened_list, list): | |||||
| # raise TypeError("Input must be of type list.") | |||||
| if not isinstance(structured_list[0], (list, tuple)): | if not isinstance(structured_list[0], (list, tuple)): | ||||
| return flattened_list | return flattened_list | ||||
| @@ -107,8 +107,8 @@ def confidence_dist(pred_prob, candidates): | |||||
| Confidence distances computed for each candidate. | Confidence distances computed for each candidate. | ||||
| """ | """ | ||||
| pred_prob = np.clip(pred_prob, 1e-9, 1) | pred_prob = np.clip(pred_prob, 1e-9, 1) | ||||
| rows, cols = np.indices((len(candidates), len(candidates[0]))) | |||||
| return 1 - np.prod(pred_prob[rows, cols, candidates], axis=1) | |||||
| _, cols = np.indices((len(candidates), len(candidates[0]))) | |||||
| return 1 - np.prod(pred_prob[cols, candidates], axis=1) | |||||
| def block_sample(X, Z, Y, sample_num, seg_idx): | def block_sample(X, Z, Y, sample_num, seg_idx): | ||||
| @@ -248,3 +248,60 @@ def calculate_revision_num(parameter, total_length): | |||||
| if parameter < 0: | if parameter < 0: | ||||
| raise ValueError("If parameter is an int, it must be non-negative.") | raise ValueError("If parameter is an int, it must be non-negative.") | ||||
| return parameter | return parameter | ||||
| if __name__ == "__main__": | |||||
| A = np.array( | |||||
| [ | |||||
| [ | |||||
| 0.18401675, | |||||
| 0.06797526, | |||||
| 0.06797541, | |||||
| 0.06801736, | |||||
| 0.06797528, | |||||
| 0.06797526, | |||||
| 0.06818808, | |||||
| 0.06797527, | |||||
| 0.06800033, | |||||
| 0.06797526, | |||||
| 0.06797526, | |||||
| 0.06797526, | |||||
| 0.06797526, | |||||
| ], | |||||
| [ | |||||
| 0.07223161, | |||||
| 0.0685229, | |||||
| 0.06852708, | |||||
| 0.17227574, | |||||
| 0.06852163, | |||||
| 0.07018146, | |||||
| 0.06860291, | |||||
| 0.06852849, | |||||
| 0.06852163, | |||||
| 0.0685216, | |||||
| 0.0685216, | |||||
| 0.06852174, | |||||
| 0.0685216, | |||||
| ], | |||||
| [ | |||||
| 0.06794382, | |||||
| 0.0679436, | |||||
| 0.06794395, | |||||
| 0.06794346, | |||||
| 0.06794346, | |||||
| 0.18467231, | |||||
| 0.06794345, | |||||
| 0.06794871, | |||||
| 0.06794345, | |||||
| 0.06794345, | |||||
| 0.06794345, | |||||
| 0.06794345, | |||||
| 0.06794345, | |||||
| ], | |||||
| ], | |||||
| dtype=np.float32, | |||||
| ) | |||||
| B = [[0, 9, 3], [0, 11, 4]] | |||||
| print(ori_confidence_dist(A, B)) | |||||
| print(confidence_dist(A, B)) | |||||