diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 65c85fc..05b97a4 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -4,7 +4,10 @@ from collections import OrderedDict # for multiple predictions, modify from `learn_add.py` def flatten(l): - return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] + # return [item for sublist in l for item in flatten(sublist)] if isinstance(l, (list, tuple)) else [l] + if not isinstance(l[0], (list, tuple)): + return l + return [item for sublist in l for item in sublist] if isinstance(l, (list, tuple)) else [l] # for multiple predictions, modify from `learn_add.py` def reform_idx(flatten_pred_res, save_pred_res):