Browse Source

Create utils.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
83a4258e86
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 0 deletions
  1. +35
    -0
      utils/utils.py

+ 35
- 0
utils/utils.py View File

@@ -0,0 +1,35 @@
import numpy as np

# for multiple predictions, modify from `learn_add.py`
def _flatten(l):
return [item for sublist in l for item in _flatten(sublist)] if isinstance(l, list) else [l]
# for multiple predictions, modify from `learn_add.py`
def _reform_ids(flatten_pred_res, save_pred_res):
re = []
i = 0
for e in save_pred_res:
j = 0
ids = []
while j < len(e):
ids.append(flatten_pred_res[i + j])
j += 1
re.append(ids)
i = i + j
return re

def _hamming_dist(A, B):
B = np.array(B)
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)

def _confidence_dist(A, B):
B = np.array(B)
A = np.clip(A, 1e-9, 1)
A = np.expand_dims(A, axis=0)
A = A.repeat(axis=0, repeats=(len(B)))
rows = np.array(range(len(B)))
rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
cols = np.array(range(len(B[0])))
cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
return 1 - np.prod(A[rows, cols, B], axis = 1)

Loading…
Cancel
Save