From d5fbf3f8067e1d7d3ea12bd83f1db0b4a849e636 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Mon, 10 Apr 2023 11:09:30 +0800 Subject: [PATCH] [MNT] change block_sample --- abl/framework.py | 70 +++++++++++++--------- abl/reasoning/reasoner.py | 13 +--- abl/utils/utils.py | 20 ++++--- examples/hwf/hwf_example.ipynb | 2 +- examples/mnist_add/mnist_add_example.ipynb | 36 +++++++---- 5 files changed, 83 insertions(+), 58 deletions(-) diff --git a/abl/framework.py b/abl/framework.py index 61b2da4..e33e298 100644 --- a/abl/framework.py +++ b/abl/framework.py @@ -11,7 +11,7 @@ # ================================================================# from .utils.plog import INFO, clocker -from .utils.utils import block_sample +from .utils.utils import block_sample, float_parameter def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): @@ -47,18 +47,15 @@ def filter_data(X, abduced_Z): return finetune_X, finetune_Z -def train( - model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1 -): +def train(model, abducer, train_data, epochs=50, sample=-1, verbose=-1): train_X, train_Z, train_Y = train_data - test_X, test_Z, test_Y = test_data # Set default parameters - if sample_num == -1: - sample_num = len(train_X) + sample_num = float_parameter(sample, len(train_X)) + part_num = len(train_X) // sample_num + 1 if verbose < 1: - verbose = loop_num + verbose = epochs char_acc_flag = 1 if train_Z == None: @@ -68,27 +65,42 @@ def train( predict_func = clocker(model.predict) train_func = clocker(model.train) abduce_func = clocker(abducer.batch_abduce) - - for loop_idx in range(loop_num): - X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx) - preds_res = predict_func(X) - abduced_Z = abduce_func(preds_res, Y) - - if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1): - res = result_statistics( - preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag - ) - INFO("loop: ", loop_idx + 1, " ", res) - - finetune_X, finetune_Z = filter_data(X, abduced_Z) - if len(finetune_X) > 0: - # model.valid(finetune_X, finetune_Z) - train_func(finetune_X, finetune_Z) - else: - INFO("lack of data, all abduced failed", len(finetune_X)) - - return res - + + for epoch in range(epochs): + for seg_idx in range(part_num): + X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, seg_idx) + INFO("epoch:", epoch + 1, ", seg_idx:", seg_idx + 1, "/", part_num, ", data num:", len(X)) + + preds_res = predict_func(X) + abduced_Z = abduce_func(preds_res, Y) + + ## TODO: change verbose + if ((seg_idx + 1) % verbose == 0) or (seg_idx == epochs - 1): + res = result_statistics(preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag) + INFO("seg: ", seg_idx + 1, " ", res) + + finetune_X, finetune_Z = filter_data(X, abduced_Z) + if len(finetune_X) > 0: + # model.valid(finetune_X, finetune_Z) + train_func(finetune_X, finetune_Z) + else: + INFO("lack of data, all abduced failed", len(finetune_X)) + + return model + +## TODO: test +def test(model, abducer, test_data): + test_X, test_Z, test_Y = test_data + predict_func = clocker(model.predict) + preds_res = predict_func(test_X) + + char_acc_flag = 1 + if test_Z == None: + char_acc_flag = 0 + test_Z = [None] * len(test_X) + + res = result_statistics(preds_res["cls"], test_Z, test_Y, abducer.kb.logic_forward, char_acc_flag) + INFO(res) if __name__ == "__main__": pass diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index aa6cde4..f423d4c 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -2,7 +2,7 @@ import abc import numpy as np from multiprocessing import Pool from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist +from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter class ReasonerBase(abc.ABC): def __init__(self, kb, dist_func='hamming', zoopt=False): @@ -173,16 +173,7 @@ class ReasonerBase(abc.ABC): The abduced revisiones. """ pred_res, pred_res_prob, y = data - - assert(type(max_revision) in (int, float)) - if max_revision == -1: - max_revision_num = len(flatten(pred_res)) - elif type(max_revision) == float: - assert(max_revision >= 0 and max_revision <= 1) - max_revision_num = round(len(flatten(pred_res)) * max_revision) - else: - assert(max_revision >= 0) - max_revision_num = max_revision + max_revision_num = float_parameter(max_revision, len(flatten(pred_res))) if self.zoopt: solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index c0b867b..44383d5 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -36,16 +36,10 @@ def confidence_dist(A, B): cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B)) return 1 - np.prod(A[rows, cols, B], axis=1) -def block_sample(X, Z, Y, sample_num, epoch_idx): - part_num = len(X) // sample_num - if part_num == 0: - part_num = 1 - seg_idx = epoch_idx % part_num - INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X)) +def block_sample(X, Z, Y, sample_num, seg_idx): X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)] Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)] Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)] - return X, Z, Y @@ -78,3 +72,15 @@ def hashable_to_list(t): if type(t[0]) is not tuple: return list(t) return [list(subtuple) for subtuple in t] + + +def float_parameter(parameter, total_length): + assert(type(parameter) in (int, float)) + if parameter == -1: + return total_length + elif type(parameter) == float: + assert(parameter >= 0 and parameter <= 1) + return round(total_length * parameter) + else: + assert(parameter >= 0) + return parameter \ No newline at end of file diff --git a/examples/hwf/hwf_example.ipynb b/examples/hwf/hwf_example.ipynb index e11e67d..715b71f 100644 --- a/examples/hwf/hwf_example.ipynb +++ b/examples/hwf/hwf_example.ipynb @@ -177,7 +177,7 @@ "source": [ "# Train model\n", "framework.train(\n", - " model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n", + " model, abducer, train_data, epochs=15, sample=5000, verbose=1\n", ")\n", "\n", "# Save results\n", diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 7be49c1..f4cadff 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -136,8 +136,8 @@ "outputs": [], "source": [ "# Get training and testing data\n", - "train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n", - "test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)" + "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", + "test_data = get_mnist_add(train=False, get_pseudo_label=True)" ] }, { @@ -155,24 +155,40 @@ "outputs": [], "source": [ "# Train model\n", - "framework.train(\n", + "model = framework.train(\n", " model,\n", " abducer,\n", - " (train_X, train_Z, train_Y),\n", - " (test_X, test_Z, test_Y),\n", - " loop_num=15,\n", - " sample_num=5000,\n", + " train_data,\n", + " epochs=5,\n", + " sample=12000,\n", " verbose=1,\n", ")\n", "\n", "# Save results\n", "recorder.dump()" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### TODO: Test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "framework.test(model, abducer, test_data)" + ] } ], "metadata": { "kernelspec": { - "display_name": "ABL", + "display_name": "abl", "language": "python", "name": "python3" }, @@ -186,12 +202,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" + "hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e" } } },