Browse Source

[ENH] Move tab_data_to_tuple to utils

pull/1/head
Tony-HYX 2 years ago
parent
commit
eb242be48b
4 changed files with 17 additions and 14 deletions
  1. +2
    -1
      abl/utils/__init__.py
  2. +7
    -1
      abl/utils/utils.py
  3. +4
    -6
      examples/zoo/main.py
  4. +4
    -6
      examples/zoo/zoo.ipynb

+ 2
- 1
abl/utils/__init__.py View File

@@ -1,6 +1,6 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple

__all__ = [
"Cache",
@@ -12,4 +12,5 @@ __all__ = [
"reform_list",
"to_hashable",
"abl_cache",
"tab_data_to_tuple",
]

+ 7
- 1
abl/utils/utils.py View File

@@ -154,4 +154,10 @@ def restore_from_hashable(x):
return [restore_from_hashable(item) for item in x]
return x


def tab_data_to_tuple(X, y, reasoning_result = 0):
'''
Convert a tabular data to a tuple by adding a dimension to each element of X and y. The tuple is a list of three elements: data, label, and reasoning result.
'''
if len(X) != len(y):
raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)))
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

+ 4
- 6
examples/zoo/main.py View File

@@ -8,14 +8,12 @@ from abl.bridge import SimpleBridge
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.learning import ABLModel
from abl.reasoning import Reasoner
from abl.utils import ABLLogger, confidence_dist, print_log
from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple

from get_dataset import load_and_preprocess_dataset, split_dataset
from kb import ZooKB


def transform_tab_data(X, y):
return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))

def consitency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
@@ -39,9 +37,9 @@ def main():
X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
label_data = transform_tab_data(X_label, y_label)
test_data = transform_tab_data(X_test, y_test)
train_data = transform_tab_data(X_unlabel, y_unlabel)
label_data = tab_data_to_tuple(X_label, y_label)
test_data = tab_data_to_tuple(X_test, y_test)
train_data = tab_data_to_tuple(X_unlabel, y_unlabel)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")


+ 4
- 6
examples/zoo/zoo.ipynb View File

@@ -27,7 +27,7 @@
"from abl.data.evaluation import ReasoningMetric, SymbolAccuracy\n",
"from abl.learning import ABLModel\n",
"from abl.reasoning import Reasoner\n",
"from abl.utils import ABLLogger, confidence_dist, print_log\n",
"from abl.utils import ABLLogger, confidence_dist, print_log, tab_data_to_tuple\n",
"\n",
"from get_dataset import load_and_preprocess_dataset, split_dataset\n",
"from kb import ZooKB"
@@ -106,11 +106,9 @@
"metadata": {},
"outputs": [],
"source": [
"def transform_tab_data(X, y):\n",
" return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))\n",
"label_data = transform_tab_data(X_label, y_label)\n",
"test_data = transform_tab_data(X_test, y_test)\n",
"train_data = transform_tab_data(X_unlabel, y_unlabel)"
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n",
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n",
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)"
]
},
{


Loading…
Cancel
Save