| @@ -46,9 +46,9 @@ def reform_idx(flattened_list, structured_list): | |||
| list | |||
| A reformed list that mimics the structure of structured_list. | |||
| """ | |||
| if not isinstance(flattened_list, list): | |||
| raise TypeError("Input must be of type list.") | |||
| # if not isinstance(flattened_list, list): | |||
| # raise TypeError("Input must be of type list.") | |||
| if not isinstance(structured_list[0], (list, tuple)): | |||
| return flattened_list | |||
| @@ -107,8 +107,8 @@ def confidence_dist(pred_prob, candidates): | |||
| Confidence distances computed for each candidate. | |||
| """ | |||
| pred_prob = np.clip(pred_prob, 1e-9, 1) | |||
| rows, cols = np.indices((len(candidates), len(candidates[0]))) | |||
| return 1 - np.prod(pred_prob[rows, cols, candidates], axis=1) | |||
| _, cols = np.indices((len(candidates), len(candidates[0]))) | |||
| return 1 - np.prod(pred_prob[cols, candidates], axis=1) | |||
| def block_sample(X, Z, Y, sample_num, seg_idx): | |||
| @@ -248,3 +248,60 @@ def calculate_revision_num(parameter, total_length): | |||
| if parameter < 0: | |||
| raise ValueError("If parameter is an int, it must be non-negative.") | |||
| return parameter | |||
| if __name__ == "__main__": | |||
| A = np.array( | |||
| [ | |||
| [ | |||
| 0.18401675, | |||
| 0.06797526, | |||
| 0.06797541, | |||
| 0.06801736, | |||
| 0.06797528, | |||
| 0.06797526, | |||
| 0.06818808, | |||
| 0.06797527, | |||
| 0.06800033, | |||
| 0.06797526, | |||
| 0.06797526, | |||
| 0.06797526, | |||
| 0.06797526, | |||
| ], | |||
| [ | |||
| 0.07223161, | |||
| 0.0685229, | |||
| 0.06852708, | |||
| 0.17227574, | |||
| 0.06852163, | |||
| 0.07018146, | |||
| 0.06860291, | |||
| 0.06852849, | |||
| 0.06852163, | |||
| 0.0685216, | |||
| 0.0685216, | |||
| 0.06852174, | |||
| 0.0685216, | |||
| ], | |||
| [ | |||
| 0.06794382, | |||
| 0.0679436, | |||
| 0.06794395, | |||
| 0.06794346, | |||
| 0.06794346, | |||
| 0.18467231, | |||
| 0.06794345, | |||
| 0.06794871, | |||
| 0.06794345, | |||
| 0.06794345, | |||
| 0.06794345, | |||
| 0.06794345, | |||
| 0.06794345, | |||
| ], | |||
| ], | |||
| dtype=np.float32, | |||
| ) | |||
| B = [[0, 9, 3], [0, 11, 4]] | |||
| print(ori_confidence_dist(A, B)) | |||
| print(confidence_dist(A, B)) | |||