From 83a4258e8681c71ee9ecff2abf2757e1fafaa75c Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 7 Dec 2022 18:51:21 +0800 Subject: [PATCH] Create utils.py --- utils/utils.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 utils/utils.py diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..742dc03 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,35 @@ +import numpy as np + +# for multiple predictions, modify from `learn_add.py` +def _flatten(l): + return [item for sublist in l for item in _flatten(sublist)] if isinstance(l, list) else [l] + +# for multiple predictions, modify from `learn_add.py` +def _reform_ids(flatten_pred_res, save_pred_res): + re = [] + i = 0 + for e in save_pred_res: + j = 0 + ids = [] + while j < len(e): + ids.append(flatten_pred_res[i + j]) + j += 1 + re.append(ids) + i = i + j + return re + +def _hamming_dist(A, B): + B = np.array(B) + A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) + return np.sum(A != B, axis = 1) + +def _confidence_dist(A, B): + B = np.array(B) + A = np.clip(A, 1e-9, 1) + A = np.expand_dims(A, axis=0) + A = A.repeat(axis=0, repeats=(len(B))) + rows = np.array(range(len(B))) + rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) + cols = np.array(range(len(B[0]))) + cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) + return 1 - np.prod(A[rows, cols, B], axis = 1)