| @@ -11,7 +11,7 @@ | |||
| # ================================================================# | |||
| from .utils.plog import INFO, clocker | |||
| from .utils.utils import block_sample | |||
| from .utils.utils import block_sample, float_parameter | |||
| def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): | |||
| @@ -47,18 +47,15 @@ def filter_data(X, abduced_Z): | |||
| return finetune_X, finetune_Z | |||
| def train( | |||
| model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1 | |||
| ): | |||
| def train(model, abducer, train_data, epochs=50, sample=-1, verbose=-1): | |||
| train_X, train_Z, train_Y = train_data | |||
| test_X, test_Z, test_Y = test_data | |||
| # Set default parameters | |||
| if sample_num == -1: | |||
| sample_num = len(train_X) | |||
| sample_num = float_parameter(sample, len(train_X)) | |||
| part_num = len(train_X) // sample_num + 1 | |||
| if verbose < 1: | |||
| verbose = loop_num | |||
| verbose = epochs | |||
| char_acc_flag = 1 | |||
| if train_Z == None: | |||
| @@ -68,27 +65,42 @@ def train( | |||
| predict_func = clocker(model.predict) | |||
| train_func = clocker(model.train) | |||
| abduce_func = clocker(abducer.batch_abduce) | |||
| for loop_idx in range(loop_num): | |||
| X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx) | |||
| preds_res = predict_func(X) | |||
| abduced_Z = abduce_func(preds_res, Y) | |||
| if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1): | |||
| res = result_statistics( | |||
| preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag | |||
| ) | |||
| INFO("loop: ", loop_idx + 1, " ", res) | |||
| finetune_X, finetune_Z = filter_data(X, abduced_Z) | |||
| if len(finetune_X) > 0: | |||
| # model.valid(finetune_X, finetune_Z) | |||
| train_func(finetune_X, finetune_Z) | |||
| else: | |||
| INFO("lack of data, all abduced failed", len(finetune_X)) | |||
| return res | |||
| for epoch in range(epochs): | |||
| for seg_idx in range(part_num): | |||
| X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, seg_idx) | |||
| INFO("epoch:", epoch + 1, ", seg_idx:", seg_idx + 1, "/", part_num, ", data num:", len(X)) | |||
| preds_res = predict_func(X) | |||
| abduced_Z = abduce_func(preds_res, Y) | |||
| ## TODO: change verbose | |||
| if ((seg_idx + 1) % verbose == 0) or (seg_idx == epochs - 1): | |||
| res = result_statistics(preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag) | |||
| INFO("seg: ", seg_idx + 1, " ", res) | |||
| finetune_X, finetune_Z = filter_data(X, abduced_Z) | |||
| if len(finetune_X) > 0: | |||
| # model.valid(finetune_X, finetune_Z) | |||
| train_func(finetune_X, finetune_Z) | |||
| else: | |||
| INFO("lack of data, all abduced failed", len(finetune_X)) | |||
| return model | |||
| ## TODO: test | |||
| def test(model, abducer, test_data): | |||
| test_X, test_Z, test_Y = test_data | |||
| predict_func = clocker(model.predict) | |||
| preds_res = predict_func(test_X) | |||
| char_acc_flag = 1 | |||
| if test_Z == None: | |||
| char_acc_flag = 0 | |||
| test_Z = [None] * len(test_X) | |||
| res = result_statistics(preds_res["cls"], test_Z, test_Y, abducer.kb.logic_forward, char_acc_flag) | |||
| INFO(res) | |||
| if __name__ == "__main__": | |||
| pass | |||
| @@ -2,7 +2,7 @@ import abc | |||
| import numpy as np | |||
| from multiprocessing import Pool | |||
| from zoopt import Dimension, Objective, Parameter, Opt | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist | |||
| from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter | |||
| class ReasonerBase(abc.ABC): | |||
| def __init__(self, kb, dist_func='hamming', zoopt=False): | |||
| @@ -173,16 +173,7 @@ class ReasonerBase(abc.ABC): | |||
| The abduced revisiones. | |||
| """ | |||
| pred_res, pred_res_prob, y = data | |||
| assert(type(max_revision) in (int, float)) | |||
| if max_revision == -1: | |||
| max_revision_num = len(flatten(pred_res)) | |||
| elif type(max_revision) == float: | |||
| assert(max_revision >= 0 and max_revision <= 1) | |||
| max_revision_num = round(len(flatten(pred_res)) * max_revision) | |||
| else: | |||
| assert(max_revision >= 0) | |||
| max_revision_num = max_revision | |||
| max_revision_num = float_parameter(max_revision, len(flatten(pred_res))) | |||
| if self.zoopt: | |||
| solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num) | |||
| @@ -36,16 +36,10 @@ def confidence_dist(A, B): | |||
| cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B)) | |||
| return 1 - np.prod(A[rows, cols, B], axis=1) | |||
| def block_sample(X, Z, Y, sample_num, epoch_idx): | |||
| part_num = len(X) // sample_num | |||
| if part_num == 0: | |||
| part_num = 1 | |||
| seg_idx = epoch_idx % part_num | |||
| INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X)) | |||
| def block_sample(X, Z, Y, sample_num, seg_idx): | |||
| X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)] | |||
| Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)] | |||
| Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)] | |||
| return X, Z, Y | |||
| @@ -78,3 +72,15 @@ def hashable_to_list(t): | |||
| if type(t[0]) is not tuple: | |||
| return list(t) | |||
| return [list(subtuple) for subtuple in t] | |||
| def float_parameter(parameter, total_length): | |||
| assert(type(parameter) in (int, float)) | |||
| if parameter == -1: | |||
| return total_length | |||
| elif type(parameter) == float: | |||
| assert(parameter >= 0 and parameter <= 1) | |||
| return round(total_length * parameter) | |||
| else: | |||
| assert(parameter >= 0) | |||
| return parameter | |||
| @@ -177,7 +177,7 @@ | |||
| "source": [ | |||
| "# Train model\n", | |||
| "framework.train(\n", | |||
| " model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n", | |||
| " model, abducer, train_data, epochs=15, sample=5000, verbose=1\n", | |||
| ")\n", | |||
| "\n", | |||
| "# Save results\n", | |||
| @@ -136,8 +136,8 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Get training and testing data\n", | |||
| "train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n", | |||
| "test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)" | |||
| "train_data = get_mnist_add(train=True, get_pseudo_label=True)\n", | |||
| "test_data = get_mnist_add(train=False, get_pseudo_label=True)" | |||
| ] | |||
| }, | |||
| { | |||
| @@ -155,24 +155,40 @@ | |||
| "outputs": [], | |||
| "source": [ | |||
| "# Train model\n", | |||
| "framework.train(\n", | |||
| "model = framework.train(\n", | |||
| " model,\n", | |||
| " abducer,\n", | |||
| " (train_X, train_Z, train_Y),\n", | |||
| " (test_X, test_Z, test_Y),\n", | |||
| " loop_num=15,\n", | |||
| " sample_num=5000,\n", | |||
| " train_data,\n", | |||
| " epochs=5,\n", | |||
| " sample=12000,\n", | |||
| " verbose=1,\n", | |||
| ")\n", | |||
| "\n", | |||
| "# Save results\n", | |||
| "recorder.dump()" | |||
| ] | |||
| }, | |||
| { | |||
| "attachments": {}, | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### TODO: Test" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "framework.test(model, abducer, test_data)" | |||
| ] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "ABL", | |||
| "display_name": "abl", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| @@ -186,12 +202,12 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.13" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| "interpreter": { | |||
| "hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58" | |||
| "hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e" | |||
| } | |||
| } | |||
| }, | |||