From 736d11c03e0369bc52c1c5c0a99e89e7158c8271 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Fri, 13 Oct 2023 15:21:43 +0800 Subject: [PATCH] [MNT] reformat code and add doc string to utils.py --- abl/reasoning/reasoner.py | 4 +- abl/utils/utils.py | 306 +++++++++++++++++++++++++++++--------- 2 files changed, 236 insertions(+), 74 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 2dc1cab..4ff0a75 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -5,7 +5,7 @@ from ..utils.utils import ( flatten, reform_idx, hamming_dist, - float_parameter, + calculate_revision_num, ) @@ -214,7 +214,7 @@ class ReasonerBase: The abduced revisions. """ symbol_num = len(flatten(pred_pseudo_label)) - max_revision_num = float_parameter(max_revision, symbol_num) + max_revision_num = calculate_revision_num(max_revision, symbol_num) if self.use_zoopt: solution = self.zoopt_get_solution( diff --git a/abl/utils/utils.py b/abl/utils/utils.py index c7a4bee..475e5de 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -1,85 +1,247 @@ import numpy as np from itertools import chain -def flatten(l): - if not isinstance(l[0], (list, tuple)): - return l - return list(chain.from_iterable(l)) - -def reform_idx(flatten_pred_res, save_pred_res): - if not isinstance(save_pred_res[0], (list, tuple)): - return flatten_pred_res - - re = [] - i = 0 - for e in save_pred_res: - re.append(flatten_pred_res[i:i + len(e)]) - i += len(e) - return re - - -def hamming_dist(A, B): - A = np.array(A, dtype='>> X = [1, 2, 3, 4, 5, 6] + >>> Z = ['a', 'b', 'c', 'd', 'e', 'f'] + >>> Y = [10, 11, 12, 13, 14, 15] + >>> block_sample(X, Z, Y, 2, 1) + ([3, 4], ['c', 'd'], [12, 13]) + """ + start_idx = sample_num * seg_idx + end_idx = sample_num * (seg_idx + 1) + + return (data[start_idx:end_idx] for data in (X, Z, Y)) def check_equal(a, b, max_err=0): - if isinstance(a, (int, float)) and isinstance(b, (int, float)): - return abs(a - b) <= max_err - - if isinstance(a, list) and isinstance(b, list): - if len(a) != len(b): - return False - for i in range(len(a)): - if not check_equal(a[i], b[i]): - return False - return True + """ + Check whether two numbers a and b are equal within a maximum allowable error. + + Parameters + ---------- + a, b : int or float + The numbers to compare. + max_err : int or float, optional + The maximum allowable absolute difference between a and b for them to be considered equal. + Default is 0, meaning the numbers must be exactly equal. + + Returns + ------- + bool + True if a and b are equal within the allowable error, False otherwise. + + Raises + ------ + TypeError + If a or b are not of type int or float. + """ + if not (isinstance(a, (int, float)) and isinstance(b, (int, float))): + raise TypeError("Input values must be int or float.") + + return abs(a - b) <= max_err + + +def to_hashable(x): + """ + Convert a nested list to a nested tuple so it is hashable. + + Parameters + ---------- + x : list or other type + A potentially nested list to convert to a tuple. + + Returns + ------- + tuple or other type + The input converted to a tuple if it was a list, + otherwise the original input. + """ + if isinstance(x, list): + return tuple(to_hashable(item) for item in x) + return x + + +def hashable_to_list(x): + """ + Convert a nested tuple back to a nested list. + + Parameters + ---------- + x : tuple or other type + A potentially nested tuple to convert to a list. + + Returns + ------- + list or other type + The input converted to a list if it was a tuple, + otherwise the original input. + """ + if isinstance(x, tuple): + return [hashable_to_list(item) for item in x] + return x + + +def calculate_revision_num(parameter, total_length): + """ + Convert a float parameter to an integer, based on a total length. + + Parameters + ---------- + parameter : int or float + The parameter to convert. If float, it should be between 0 and 1. + If int, it should be non-negative. If -1, it will be replaced with total_length. + total_length : int + The total length to calculate the parameter from if it's a fraction. + + Returns + ------- + int + The calculated parameter. + + Raises + ------ + TypeError + If parameter is not an int or a float. + ValueError + If parameter is a float not in [0, 1] or an int below 0. + """ + if not isinstance(parameter, (int, float)): + raise TypeError("Parameter must be of type int or float.") - else: - return a == b - - -def to_hashable(l): - if type(l) is not list: - return l - if type(l[0]) is not list: - return tuple(l) - return tuple(tuple(sublist) for sublist in l) - -def hashable_to_list(t): - if type(t) is not tuple: - return t - if type(t[0]) is not tuple: - return list(t) - return [list(subtuple) for subtuple in t] - - -def float_parameter(parameter, total_length): - assert(type(parameter) in (int, float)) if parameter == -1: return total_length - elif type(parameter) == float: - assert(parameter >= 0 and parameter <= 1) + elif isinstance(parameter, float): + if not (0 <= parameter <= 1): + raise ValueError("If parameter is a float, it must be between 0 and 1.") return round(total_length * parameter) else: - assert(parameter >= 0) - return parameter \ No newline at end of file + if parameter < 0: + raise ValueError("If parameter is an int, it must be non-negative.") + return parameter