diff --git a/abl/utils/utils.py b/abl/utils/utils.py index d6e45db..8f29a02 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -158,6 +158,8 @@ def tab_data_to_tuple(X, y, reasoning_result = 0): ''' Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple contains three elements: data, label, and reasoning result. ''' + if X is None: + return None if len(X) != len(y): raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) \ No newline at end of file