import torch import torch.nn as nn import numpy as np from .plog import INFO from collections import OrderedDict from itertools import chain def flatten(l): if not isinstance(l[0], (list, tuple)): return l return list(chain.from_iterable(l)) def reform_idx(flatten_pred_res, save_pred_res): if not isinstance(save_pred_res[0], (list, tuple)): return flatten_pred_res re = [] i = 0 for e in save_pred_res: re.append(flatten_pred_res[i:i + len(e)]) i += len(e) return re def hamming_dist(A, B): A = np.array(A, dtype='