From 57f31fb83c4b58e49ab75bac466dab1ddcbbbabe Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 31 Mar 2023 15:52:57 +0800 Subject: [PATCH] [ENH] move hed specific method in utils to examples/hed/utils.py --- abl/utils/utils.py | 50 +------------------------------------------ examples/hed/utils.py | 47 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 49 deletions(-) create mode 100644 examples/hed/utils.py diff --git a/abl/utils/utils.py b/abl/utils/utils.py index fb55467..c0b867b 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -1,8 +1,5 @@ -import torch -import torch.nn as nn import numpy as np from .plog import INFO -from collections import OrderedDict from itertools import chain def flatten(l): @@ -51,31 +48,6 @@ def block_sample(X, Z, Y, sample_num, epoch_idx): return X, Z, Y -def gen_mappings(chars, symbs): - n_char = len(chars) - n_symbs = len(symbs) - if n_char != n_symbs: - print("Characters and symbols size dosen't match.") - return - from itertools import permutations - - mappings = [] - # returned mappings - perms = permutations(symbs) - for p in perms: - mappings.append(dict(zip(chars, list(p)))) - return mappings - - -def mapping_res(original_pred_res, m): - return [[m[symbol] for symbol in formula] for formula in original_pred_res] - - -def remapping_res(pred_res, m): - remapping = {} - for key, value in m.items(): - remapping[value] = key - return [[remapping[symbol] for symbol in formula] for formula in pred_res] def check_equal(a, b, max_err=0): if isinstance(a, (int, float)) and isinstance(b, (int, float)): @@ -90,27 +62,7 @@ def check_equal(a, b, max_err=0): return True else: - return a == b - - -def extract_feature(img): - extractor = nn.AvgPool2d(2, stride=2) - feature_map = np.array(extractor(torch.Tensor(img))) - return feature_map.reshape((-1,)) - return np.concatenate( - (np.squeeze(np.sum(img, axis=1)), np.squeeze(np.sum(img, axis=2))), axis=0 - ) - - -def reduce_dimension(data): - for truth_value in [0, 1]: - for equation_len in range(5, 27): - equations = data[truth_value][equation_len] - reduced_equations = [ - [extract_feature(symbol_img) for symbol_img in equation] - for equation in equations - ] - data[truth_value][equation_len] = reduced_equations + return a == b def to_hashable(l): diff --git a/examples/hed/utils.py b/examples/hed/utils.py new file mode 100644 index 0000000..cf35eaf --- /dev/null +++ b/examples/hed/utils.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import numpy as np + + +def gen_mappings(chars, symbs): + n_char = len(chars) + n_symbs = len(symbs) + if n_char != n_symbs: + print("Characters and symbols size dosen't match.") + return + from itertools import permutations + + mappings = [] + # returned mappings + perms = permutations(symbs) + for p in perms: + mappings.append(dict(zip(chars, list(p)))) + return mappings + + +def mapping_res(original_pred_res, m): + return [[m[symbol] for symbol in formula] for formula in original_pred_res] + + +def remapping_res(pred_res, m): + remapping = {} + for key, value in m.items(): + remapping[value] = key + return [[remapping[symbol] for symbol in formula] for formula in pred_res] + + +def extract_feature(img): + extractor = nn.AvgPool2d(2, stride=2) + feature_map = np.array(extractor(torch.Tensor(img))) + return feature_map.reshape((-1,)) + + +def reduce_dimension(data): + for truth_value in [0, 1]: + for equation_len in range(5, 27): + equations = data[truth_value][equation_len] + reduced_equations = [ + [extract_feature(symbol_img) for symbol_img in equation] + for equation in equations + ] + data[truth_value][equation_len] = reduced_equations \ No newline at end of file