diff --git a/abl/utils/__init__.py b/abl/utils/__init__.py index d69e09b..9cfd590 100644 --- a/abl/utils/__init__.py +++ b/abl/utils/__init__.py @@ -1,6 +1,6 @@ from .cache import Cache, abl_cache from .logger import ABLLogger, print_log -from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable +from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple __all__ = [ "Cache", @@ -12,4 +12,5 @@ __all__ = [ "reform_list", "to_hashable", "abl_cache", + "tab_data_to_tuple", ] diff --git a/abl/utils/utils.py b/abl/utils/utils.py index bbeb58b..f2ef808 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -154,4 +154,10 @@ def restore_from_hashable(x): return [restore_from_hashable(item) for item in x] return x - +def tab_data_to_tuple(X, y, reasoning_result = 0): + ''' + Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple is a list of three elements: data, label, and reasoning result. + ''' + if len(X) != len(y): + raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) + return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) \ No newline at end of file diff --git a/examples/zoo/main.py b/examples/zoo/main.py index 4ece65f..b4da2d1 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.learning import ABLModel from abl.reasoning import Reasoner -from abl.utils import ABLLogger, confidence_dist, print_log +from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple from get_dataset import load_and_preprocess_dataset, split_dataset from kb import ZooKB -def transform_tab_data(X, y): - return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) def consitency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob @@ -39,9 +37,9 @@ def main(): X, y = load_and_preprocess_dataset(dataset_id=62) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) - label_data = transform_tab_data(X_label, y_label) - test_data = transform_tab_data(X_test, y_test) - train_data = transform_tab_data(X_unlabel, y_unlabel) + label_data = tab_data_to_tuple(X_label, y_label) + test_data = tab_data_to_tuple(X_test, y_test) + train_data = tab_data_to_tuple(X_unlabel, y_unlabel) ### Building the Learning Part print_log("Building the Learning Part.", logger="current") diff --git a/examples/zoo/zoo.ipynb b/examples/zoo/zoo.ipynb index 4596a55..bf21f43 100644 --- a/examples/zoo/zoo.ipynb +++ b/examples/zoo/zoo.ipynb @@ -27,7 +27,7 @@ "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.learning import ABLModel\n", "from abl.reasoning import Reasoner\n", - "from abl.utils import ABLLogger, confidence_dist, print_log\n", + "from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n", "\n", "from get_dataset import load_and_preprocess_dataset, split_dataset\n", "from kb import ZooKB" @@ -106,11 +106,9 @@ "metadata": {}, "outputs": [], "source": [ - "def transform_tab_data(X, y):\n", - " return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", - "label_data = transform_tab_data(X_label, y_label)\n", - "test_data = transform_tab_data(X_test, y_test)\n", - "train_data = transform_tab_data(X_unlabel, y_unlabel)" + "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", + "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", + "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" ] }, {