Browse Source

[ENH] move hed specific method in utils to examples/hed/utils.py

pull/3/head
Gao Enhao 3 years ago
parent
commit
57f31fb83c
2 changed files with 48 additions and 49 deletions
  1. +1
    -49
      abl/utils/utils.py
  2. +47
    -0
      examples/hed/utils.py

+ 1
- 49
abl/utils/utils.py View File

@@ -1,8 +1,5 @@
import torch
import torch.nn as nn
import numpy as np import numpy as np
from .plog import INFO from .plog import INFO
from collections import OrderedDict
from itertools import chain from itertools import chain


def flatten(l): def flatten(l):
@@ -51,31 +48,6 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):


return X, Z, Y return X, Z, Y


def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
print("Characters and symbols size dosen't match.")
return
from itertools import permutations

mappings = []
# returned mappings
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings


def mapping_res(original_pred_res, m):
return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
remapping = {}
for key, value in m.items():
remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res]


def check_equal(a, b, max_err=0): def check_equal(a, b, max_err=0):
if isinstance(a, (int, float)) and isinstance(b, (int, float)): if isinstance(a, (int, float)) and isinstance(b, (int, float)):
@@ -90,27 +62,7 @@ def check_equal(a, b, max_err=0):
return True return True


else: else:
return a == b


def extract_feature(img):
extractor = nn.AvgPool2d(2, stride=2)
feature_map = np.array(extractor(torch.Tensor(img)))
return feature_map.reshape((-1,))
return np.concatenate(
(np.squeeze(np.sum(img, axis=1)), np.squeeze(np.sum(img, axis=2))), axis=0
)


def reduce_dimension(data):
for truth_value in [0, 1]:
for equation_len in range(5, 27):
equations = data[truth_value][equation_len]
reduced_equations = [
[extract_feature(symbol_img) for symbol_img in equation]
for equation in equations
]
data[truth_value][equation_len] = reduced_equations
return a == b


def to_hashable(l): def to_hashable(l):


+ 47
- 0
examples/hed/utils.py View File

@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import numpy as np


def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
print("Characters and symbols size dosen't match.")
return
from itertools import permutations

mappings = []
# returned mappings
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings


def mapping_res(original_pred_res, m):
return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
remapping = {}
for key, value in m.items():
remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res]


def extract_feature(img):
extractor = nn.AvgPool2d(2, stride=2)
feature_map = np.array(extractor(torch.Tensor(img)))
return feature_map.reshape((-1,))


def reduce_dimension(data):
for truth_value in [0, 1]:
for equation_len in range(5, 27):
equations = data[truth_value][equation_len]
reduced_equations = [
[extract_feature(symbol_img) for symbol_img in equation]
for equation in equations
]
data[truth_value][equation_len] = reduced_equations

Loading…
Cancel
Save