| @@ -1,6 +1,6 @@ | |||||
| from .cache import Cache, abl_cache | from .cache import Cache, abl_cache | ||||
| from .logger import ABLLogger, print_log | from .logger import ABLLogger, print_log | ||||
| from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable | |||||
| from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple | |||||
| __all__ = [ | __all__ = [ | ||||
| "Cache", | "Cache", | ||||
| @@ -12,4 +12,5 @@ __all__ = [ | |||||
| "reform_list", | "reform_list", | ||||
| "to_hashable", | "to_hashable", | ||||
| "abl_cache", | "abl_cache", | ||||
| "tab_data_to_tuple", | |||||
| ] | ] | ||||
| @@ -154,4 +154,10 @@ def restore_from_hashable(x): | |||||
| return [restore_from_hashable(item) for item in x] | return [restore_from_hashable(item) for item in x] | ||||
| return x | return x | ||||
| def tab_data_to_tuple(X, y, reasoning_result = 0): | |||||
| ''' | |||||
| Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple is a list of three elements: data, label, and reasoning result. | |||||
| ''' | |||||
| if len(X) != len(y): | |||||
| raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))) | |||||
| return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y)) | |||||
| @@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge | |||||
| from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | from abl.data.evaluation import ReasoningMetric, SymbolAccuracy | ||||
| from abl.learning import ABLModel | from abl.learning import ABLModel | ||||
| from abl.reasoning import Reasoner | from abl.reasoning import Reasoner | ||||
| from abl.utils import ABLLogger, confidence_dist, print_log | |||||
| from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple | |||||
| from get_dataset import load_and_preprocess_dataset, split_dataset | from get_dataset import load_and_preprocess_dataset, split_dataset | ||||
| from kb import ZooKB | from kb import ZooKB | ||||
| def transform_tab_data(X, y): | |||||
| return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y)) | |||||
| def consitency(data_example, candidates, candidate_idxs, reasoning_results): | def consitency(data_example, candidates, candidate_idxs, reasoning_results): | ||||
| pred_prob = data_example.pred_prob | pred_prob = data_example.pred_prob | ||||
| @@ -39,9 +37,9 @@ def main(): | |||||
| X, y = load_and_preprocess_dataset(dataset_id=62) | X, y = load_and_preprocess_dataset(dataset_id=62) | ||||
| X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) | ||||
| label_data = transform_tab_data(X_label, y_label) | |||||
| test_data = transform_tab_data(X_test, y_test) | |||||
| train_data = transform_tab_data(X_unlabel, y_unlabel) | |||||
| label_data = tab_data_to_tuple(X_label, y_label) | |||||
| test_data = tab_data_to_tuple(X_test, y_test) | |||||
| train_data = tab_data_to_tuple(X_unlabel, y_unlabel) | |||||
| ### Building the Learning Part | ### Building the Learning Part | ||||
| print_log("Building the Learning Part.", logger="current") | print_log("Building the Learning Part.", logger="current") | ||||
| @@ -27,7 +27,7 @@ | |||||
| "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", | ||||
| "from abl.learning import ABLModel\n", | "from abl.learning import ABLModel\n", | ||||
| "from abl.reasoning import Reasoner\n", | "from abl.reasoning import Reasoner\n", | ||||
| "from abl.utils import ABLLogger, confidence_dist, print_log\n", | |||||
| "from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n", | |||||
| "\n", | "\n", | ||||
| "from get_dataset import load_and_preprocess_dataset, split_dataset\n", | "from get_dataset import load_and_preprocess_dataset, split_dataset\n", | ||||
| "from kb import ZooKB" | "from kb import ZooKB" | ||||
| @@ -106,11 +106,9 @@ | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| "def transform_tab_data(X, y):\n", | |||||
| " return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n", | |||||
| "label_data = transform_tab_data(X_label, y_label)\n", | |||||
| "test_data = transform_tab_data(X_test, y_test)\n", | |||||
| "train_data = transform_tab_data(X_unlabel, y_unlabel)" | |||||
| "label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n", | |||||
| "test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n", | |||||
| "train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)" | |||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||