diff --git a/utils/utils.py b/utils/utils.py index 57dea31..f4db6f2 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -2,16 +2,10 @@ import numpy as np from utils.plog import INFO from collections import OrderedDict - # for multiple predictions, modify from `learn_add.py` def flatten(l): - return ( - [item for sublist in l for item in flatten(sublist)] - if isinstance(l, list) - else [l] - ) - - + return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] + # for multiple predictions, modify from `learn_add.py` def reform_idx(flatten_pred_res, save_pred_res): re = [] @@ -26,6 +20,22 @@ def reform_idx(flatten_pred_res, save_pred_res): i = i + j return re +def hamming_dist(A, B): + B = np.array(B) + A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B))) + return np.sum(A != B, axis = 1) + +def confidence_dist(A, B): + B = np.array(B) + A = np.clip(A, 1e-9, 1) + A = np.expand_dims(A, axis=0) + A = A.repeat(axis=0, repeats=(len(B))) + rows = np.array(range(len(B))) + rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0])) + cols = np.array(range(len(B[0]))) + cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B)) + return 1 - np.prod(A[rows, cols, B], axis = 1) + def block_sample(X, Z, Y, sample_num, epoch_idx): part_num = len(X) // sample_num @@ -40,28 +50,25 @@ def block_sample(X, Z, Y, sample_num, epoch_idx): return X, Z, Y -def hamming_dist(A, B): - B = np.array(B) - A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=(len(B))) - return np.sum(A != B, axis=1) - - -def confidence_dist(A, B): - B = np.array(B) - A = np.clip(A, 1e-9, 1) - A = np.expand_dims(A, axis=0) - A = A.repeat(axis=0, repeats=(len(B))) - rows = np.array(range(len(B))) - rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0])) - cols = np.array(range(len(B[0]))) - cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B)) - return 1 - np.prod(A[rows, cols, B], axis=1) +def gen_mappings(chars, symbs): + n_char = len(chars) + n_symbs = len(symbs) + if n_char != n_symbs: + print('Characters and symbols size dosen\'t match.') + return + from itertools import permutations + mappings = [] + # returned mappings + perms = permutations(symbs) + for p in perms: + mappings.append(dict(zip(chars, list(p)))) + return mappings +def mapping_res(original_pred_res, m): + return [[m[symbol] for symbol in formula] for formula in original_pred_res] -def copy_state_dict(state_dict): - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - if k.startswith('base_model'): - name = ".".join(k.split(".")[1:]) - new_state_dict[name] = v - return new_state_dict +def remapping_res(pred_res, m): + remapping = {} + for key, value in m.items(): + remapping[value] = key + return [[remapping[symbol] for symbol in formula] for formula in pred_res]