Browse Source

[ENH] move hed specific method in utils to examples/hed/utils.py

pull/3/head
Gao Enhao 2 years ago
parent
commit
57f31fb83c
2 changed files with 48 additions and 49 deletions
  1. +1
    -49
      abl/utils/utils.py
  2. +47
    -0
      examples/hed/utils.py

+ 1
- 49
abl/utils/utils.py View File

@@ -1,8 +1,5 @@
import torch
import torch.nn as nn
import numpy as np
from .plog import INFO
from collections import OrderedDict
from itertools import chain

def flatten(l):
@@ -51,31 +48,6 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):

return X, Z, Y

def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
print("Characters and symbols size dosen't match.")
return
from itertools import permutations

mappings = []
# returned mappings
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings


def mapping_res(original_pred_res, m):
return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
remapping = {}
for key, value in m.items():
remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res]

def check_equal(a, b, max_err=0):
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
@@ -90,27 +62,7 @@ def check_equal(a, b, max_err=0):
return True

else:
return a == b


def extract_feature(img):
extractor = nn.AvgPool2d(2, stride=2)
feature_map = np.array(extractor(torch.Tensor(img)))
return feature_map.reshape((-1,))
return np.concatenate(
(np.squeeze(np.sum(img, axis=1)), np.squeeze(np.sum(img, axis=2))), axis=0
)


def reduce_dimension(data):
for truth_value in [0, 1]:
for equation_len in range(5, 27):
equations = data[truth_value][equation_len]
reduced_equations = [
[extract_feature(symbol_img) for symbol_img in equation]
for equation in equations
]
data[truth_value][equation_len] = reduced_equations
return a == b

def to_hashable(l):


+ 47
- 0
examples/hed/utils.py View File

@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import numpy as np


def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
print("Characters and symbols size dosen't match.")
return
from itertools import permutations

mappings = []
# returned mappings
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings


def mapping_res(original_pred_res, m):
return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
remapping = {}
for key, value in m.items():
remapping[value] = key
return [[remapping[symbol] for symbol in formula] for formula in pred_res]


def extract_feature(img):
extractor = nn.AvgPool2d(2, stride=2)
feature_map = np.array(extractor(torch.Tensor(img)))
return feature_map.reshape((-1,))


def reduce_dimension(data):
for truth_value in [0, 1]:
for equation_len in range(5, 27):
equations = data[truth_value][equation_len]
reduced_equations = [
[extract_feature(symbol_img) for symbol_img in equation]
for equation in equations
]
data[truth_value][equation_len] = reduced_equations

Loading…
Cancel
Save