Browse Source

Add test_hed

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
ae472fccc9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 38 additions and 31 deletions
  1. +38
    -31
      utils/utils.py

+ 38
- 31
utils/utils.py View File

@@ -2,16 +2,10 @@ import numpy as np
from utils.plog import INFO
from collections import OrderedDict


# for multiple predictions, modify from `learn_add.py`
def flatten(l):
return (
[item for sublist in l for item in flatten(sublist)]
if isinstance(l, list)
else [l]
)


return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l]
# for multiple predictions, modify from `learn_add.py`
def reform_idx(flatten_pred_res, save_pred_res):
re = []
@@ -26,6 +20,22 @@ def reform_idx(flatten_pred_res, save_pred_res):
i = i + j
return re

def hamming_dist(A, B):
B = np.array(B)
A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis = 1)

def confidence_dist(A, B):
B = np.array(B)
A = np.clip(A, 1e-9, 1)
A = np.expand_dims(A, axis=0)
A = A.repeat(axis=0, repeats=(len(B)))
rows = np.array(range(len(B)))
rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
cols = np.array(range(len(B[0])))
cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
return 1 - np.prod(A[rows, cols, B], axis = 1)


def block_sample(X, Z, Y, sample_num, epoch_idx):
part_num = len(X) // sample_num
@@ -40,28 +50,25 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):
return X, Z, Y


def hamming_dist(A, B):
B = np.array(B)
A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B)))
return np.sum(A != B, axis=1)


def confidence_dist(A, B):
B = np.array(B)
A = np.clip(A, 1e-9, 1)
A = np.expand_dims(A, axis=0)
A = A.repeat(axis=0, repeats=(len(B)))
rows = np.array(range(len(B)))
rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
cols = np.array(range(len(B[0])))
cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
return 1 - np.prod(A[rows, cols, B], axis=1)
def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
print('Characters and symbols size dosen\'t match.')
return
from itertools import permutations
mappings = []
# returned mappings
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings

def mapping_res(original_pred_res, m):
return [[m[symbol] for symbol in formula] for formula in original_pred_res]

def copy_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('base_model'):
name = ".".join(k.split(".")[1:])
new_state_dict[name] = v
return new_state_dict
def remapping_res(pred_res, m):
remapping = {}
for key, value in m.items():
remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res]

Loading…
Cancel
Save