|
|
|
@@ -2,19 +2,19 @@ import numpy as np |
|
|
|
|
|
|
|
# for multiple predictions, modify from `learn_add.py` |
|
|
|
def flatten(l): |
|
|
|
return [item for sublist in l for item in _flatten(sublist)] if isinstance(l, list) else [l] |
|
|
|
return [item for sublist in l for item in flatten(sublist)] if isinstance(l, list) else [l] |
|
|
|
|
|
|
|
# for multiple predictions, modify from `learn_add.py` |
|
|
|
def reform_ids(flatten_pred_res, save_pred_res): |
|
|
|
def reform_idx(flatten_pred_res, save_pred_res): |
|
|
|
re = [] |
|
|
|
i = 0 |
|
|
|
for e in save_pred_res: |
|
|
|
j = 0 |
|
|
|
ids = [] |
|
|
|
idx = [] |
|
|
|
while j < len(e): |
|
|
|
ids.append(flatten_pred_res[i + j]) |
|
|
|
idx.append(flatten_pred_res[i + j]) |
|
|
|
j += 1 |
|
|
|
re.append(ids) |
|
|
|
re.append(idx) |
|
|
|
i = i + j |
|
|
|
return re |
|
|
|
|
|
|
|
|