diff --git a/utils/utils.py b/utils/utils.py index d4cf947..44ba6a1 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -2,19 +2,19 @@ import numpy as np # for multiple predictions, modify from `learn_add.py` def flatten(l): - return [item for sublist in l for item in _flatten(sublist)] if isinstance(l, list) else [l] + return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] # for multiple predictions, modify from `learn_add.py` -def reform_ids(flatten_pred_res, save_pred_res): +def reform_idx(flatten_pred_res, save_pred_res): re = [] i = 0 for e in save_pred_res: j = 0 - ids = [] + idx = [] while j < len(e): - ids.append(flatten_pred_res[i + j]) + idx.append(flatten_pred_res[i + j]) j += 1 - re.append(ids) + re.append(idx) i = i + j return re