import numpy as np
from utils.plog import INFO
from collections import OrderedDict

# For multiple predictions; modified from `learn_add.py`.
def flatten(l):
    # Recursively flatten an arbitrarily nested list into a flat list.
    return (
        [item for sublist in l for item in flatten(sublist)]
        if isinstance(l, list)
        else [l]
    )

def reform_idx(flatten_pred_res, save_pred_res):
    # Regroup flat predictions back into the nested shape of `save_pred_res`.
    re = []
    i = 0
    # NOTE: the loop below is reconstructed (assumption); the excerpt only shows its tail.
    for e in save_pred_res:
        j = len(e)
        re.append(flatten_pred_res[i:i + j])
        i = i + j
    return re

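# Illustrative example (not from the original file; the `reform_idx` output
# assumes the reconstructed loop above):
#   >>> flatten([[1, 2], [3]])
#   [1, 2, 3]
#   >>> reform_idx([1, 2, 3], [[1, 2], [3]])
#   [[1, 2], [3]]
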
def hamming_dist(A, B):
    # Hamming distance between one prediction sequence A and each candidate in B.
    B = np.array(B)
    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=len(B))
    return np.sum(A != B, axis=1)


def confidence_dist(A, B):
    # One minus the probability of each candidate sequence in B under the
    # per-position class-probability matrix A (shape: sequence length x classes).
    B = np.array(B)
    A = np.clip(A, 1e-9, 1)
    A = np.expand_dims(A, axis=0)
    A = A.repeat(axis=0, repeats=len(B))
    rows = np.array(range(len(B)))
    rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
    cols = np.array(range(len(B[0])))
    cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
    return 1 - np.prod(A[rows, cols, B], axis=1)

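# Illustrative examples (not from the original file):
#   >>> hamming_dist(np.array([1, 0, 1]), [[1, 0, 1], [0, 0, 0]])
#   array([0, 2])
#   >>> A = np.array([[0.9, 0.1], [0.2, 0.8]])  # probabilities per position
#   >>> confidence_dist(A, [[0, 1], [1, 1]])    # ~[1 - 0.9*0.8, 1 - 0.1*0.8]
#   array([0.28, 0.92])
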
def block_sample(X, Z, Y, sample_num, epoch_idx):
    # Select the current epoch's block of `sample_num` examples.
    part_num = len(X) // sample_num
    ...  # block-selection logic not shown in the source excerpt
    return X, Z, Y

def gen_mappings(chars, symbs):
    # Enumerate all one-to-one mappings from characters to symbols.
    n_char = len(chars)
    n_symbs = len(symbs)
    if n_char != n_symbs:
        print("Characters and symbols sizes don't match.")
        return
    from itertools import permutations
    mappings = []
    for p in permutations(symbs):
        mappings.append(dict(zip(chars, list(p))))
    return mappings

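# Illustrative example (not from the original file): two characters give
# 2! = 2 candidate mappings.
#   >>> gen_mappings(['a', 'b'], [0, 1])
#   [{'a': 0, 'b': 1}, {'a': 1, 'b': 0}]
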
def mapping_res(original_pred_res, m):
    # Apply mapping m to every symbol of every predicted formula.
    return [[m[symbol] for symbol in formula] for formula in original_pred_res]

def copy_state_dict(state_dict):
    # Keep only `base_model.*` entries and strip the `base_model.` prefix.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('base_model'):
            name = ".".join(k.split(".")[1:])
            new_state_dict[name] = v
    return new_state_dict

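# Illustrative example (hypothetical keys, not from the original file): only
# `base_model.*` entries survive, with the prefix stripped.
#   >>> dict(copy_state_dict({'base_model.fc.weight': 1, 'head.bias': 2}))
#   {'fc.weight': 1}
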
def remapping_res(pred_res, m):
    # Invert mapping m and map predictions back to the original characters.
    remapping = {}
    for key, value in m.items():
        remapping[value] = key
    return [[remapping[symbol] for symbol in formula] for formula in pred_res]

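# Illustrative round trip (not from the original file) with one mapping from
# gen_mappings:
#   >>> m = {'a': 1, 'b': 0}
#   >>> mapping_res([['a', 'b'], ['b']], m)
#   [[1, 0], [0]]
#   >>> remapping_res([[1, 0], [0]], m)
#   [['a', 'b'], ['b']]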