Browse Source

[ENH] Move tab_data_to_tuple to utils

pull/1/head
Tony-HYX 2 years ago
parent
commit
eb242be48b
4 changed files with 17 additions and 14 deletions
  1. +2
    -1
      abl/utils/__init__.py
  2. +7
    -1
      abl/utils/utils.py
  3. +4
    -6
      examples/zoo/main.py
  4. +4
    -6
      examples/zoo/zoo.ipynb

+ 2
- 1
abl/utils/__init__.py View File

@@ -1,6 +1,6 @@
from .cache import Cache, abl_cache from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log from .logger import ABLLogger, print_log
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple


__all__ = [ __all__ = [
"Cache", "Cache",
@@ -12,4 +12,5 @@ __all__ = [
"reform_list", "reform_list",
"to_hashable", "to_hashable",
"abl_cache", "abl_cache",
"tab_data_to_tuple",
] ]

+ 7
- 1
abl/utils/utils.py View File

@@ -154,4 +154,10 @@ def restore_from_hashable(x):
return [restore_from_hashable(item) for item in x] return [restore_from_hashable(item) for item in x]
return x return x



def tab_data_to_tuple(X, y, reasoning_result = 0):
'''
Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple is a list of three elements: data, label, and reasoning result.
'''
if len(X) != len(y):
raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)))
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

+ 4
- 6
examples/zoo/main.py View File

@@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.learning import ABLModel from abl.learning import ABLModel
from abl.reasoning import Reasoner from abl.reasoning import Reasoner
from abl.utils import ABLLogger, confidence_dist, print_log
from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple


from get_dataset import load_and_preprocess_dataset, split_dataset from get_dataset import load_and_preprocess_dataset, split_dataset
from kb import ZooKB from kb import ZooKB




def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))


def consitency(data_example, candidates, candidate_idxs, reasoning_results): def consitency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob pred_prob = data_example.pred_prob
@@ -39,9 +37,9 @@ def main():
X, y = load_and_preprocess_dataset(dataset_id=62) X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3) X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)
label_data = tab_data_to_tuple(X_label, y_label)
test_data = tab_data_to_tuple(X_test, y_test)
train_data = tab_data_to_tuple(X_unlabel, y_unlabel)


### Building the Learning Part ### Building the Learning Part
print_log("Building the Learning Part.", logger="current") print_log("Building the Learning Part.", logger="current")


+ 4
- 6
examples/zoo/zoo.ipynb View File

@@ -27,7 +27,7 @@
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n", "from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
"from abl.learning import ABLModel\n", "from abl.learning import ABLModel\n",
"from abl.reasoning import Reasoner\n", "from abl.reasoning import Reasoner\n",
"from abl.utils import ABLLogger, confidence_dist, print_log\n",
"from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n",
"\n", "\n",
"from get_dataset import load_and_preprocess_dataset, split_dataset\n", "from get_dataset import load_and_preprocess_dataset, split_dataset\n",
"from kb import ZooKB" "from kb import ZooKB"
@@ -106,11 +106,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n",
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n",
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)"
] ]
}, },
{ {


Loading…
Cancel
Save