diff --git a/utils/utils.py b/utils/utils.py index f4db6f2..1138361 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -72,3 +72,18 @@ def remapping_res(pred_res, m): for key, value in m.items(): remapping[value] = key return [[remapping[symbol] for symbol in formula] for formula in pred_res] + +def check_equal(a, b): + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + return abs(a - b) <= 1e-3 + + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + return False + for i in range(len(a)): + if not check_equal(a[i], b[i]): + return False + return True + + else: + return a == b