Browse Source

[MNT] change block_sample

pull/3/head
troyyyyy 3 years ago
parent
commit
d5fbf3f806
5 changed files with 83 additions and 58 deletions
  1. +41
    -29
      abl/framework.py
  2. +2
    -11
      abl/reasoning/reasoner.py
  3. +13
    -7
      abl/utils/utils.py
  4. +1
    -1
      examples/hwf/hwf_example.ipynb
  5. +26
    -10
      examples/mnist_add/mnist_add_example.ipynb

+ 41
- 29
abl/framework.py View File

@@ -11,7 +11,7 @@
# ================================================================#

from .utils.plog import INFO, clocker
from .utils.utils import block_sample
from .utils.utils import block_sample, float_parameter


def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
@@ -47,18 +47,15 @@ def filter_data(X, abduced_Z):
return finetune_X, finetune_Z


def train(
model, abducer, train_data, test_data, loop_num=50, sample_num=-1, verbose=-1
):
def train(model, abducer, train_data, epochs=50, sample=-1, verbose=-1):
train_X, train_Z, train_Y = train_data
test_X, test_Z, test_Y = test_data

# Set default parameters
if sample_num == -1:
sample_num = len(train_X)
sample_num = float_parameter(sample, len(train_X))
part_num = len(train_X) // sample_num + 1

if verbose < 1:
verbose = loop_num
verbose = epochs

char_acc_flag = 1
if train_Z == None:
@@ -68,27 +65,42 @@ def train(
predict_func = clocker(model.predict)
train_func = clocker(model.train)
abduce_func = clocker(abducer.batch_abduce)

for loop_idx in range(loop_num):
X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, loop_idx)
preds_res = predict_func(X)
abduced_Z = abduce_func(preds_res, Y)

if ((loop_idx + 1) % verbose == 0) or (loop_idx == loop_num - 1):
res = result_statistics(
preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag
)
INFO("loop: ", loop_idx + 1, " ", res)

finetune_X, finetune_Z = filter_data(X, abduced_Z)
if len(finetune_X) > 0:
# model.valid(finetune_X, finetune_Z)
train_func(finetune_X, finetune_Z)
else:
INFO("lack of data, all abduced failed", len(finetune_X))

return res

for epoch in range(epochs):
for seg_idx in range(part_num):
X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, seg_idx)
INFO("epoch:", epoch + 1, ", seg_idx:", seg_idx + 1, "/", part_num, ", data num:", len(X))
preds_res = predict_func(X)
abduced_Z = abduce_func(preds_res, Y)

## TODO: change verbose
if ((seg_idx + 1) % verbose == 0) or (seg_idx == epochs - 1):
res = result_statistics(preds_res["cls"], Z, Y, abducer.kb.logic_forward, char_acc_flag)
INFO("seg: ", seg_idx + 1, " ", res)

finetune_X, finetune_Z = filter_data(X, abduced_Z)
if len(finetune_X) > 0:
# model.valid(finetune_X, finetune_Z)
train_func(finetune_X, finetune_Z)
else:
INFO("lack of data, all abduced failed", len(finetune_X))

return model

## TODO: test
def test(model, abducer, test_data):
test_X, test_Z, test_Y = test_data
predict_func = clocker(model.predict)
preds_res = predict_func(test_X)
char_acc_flag = 1
if test_Z == None:
char_acc_flag = 0
test_Z = [None] * len(test_X)
res = result_statistics(preds_res["cls"], test_Z, test_Y, abducer.kb.logic_forward, char_acc_flag)
INFO(res)

if __name__ == "__main__":
pass

+ 2
- 11
abl/reasoning/reasoner.py View File

@@ -2,7 +2,7 @@ import abc
import numpy as np
from multiprocessing import Pool
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist
from ..utils.utils import confidence_dist, flatten, reform_idx, hamming_dist, float_parameter

class ReasonerBase(abc.ABC):
def __init__(self, kb, dist_func='hamming', zoopt=False):
@@ -173,16 +173,7 @@ class ReasonerBase(abc.ABC):
The abduced revisiones.
"""
pred_res, pred_res_prob, y = data
assert(type(max_revision) in (int, float))
if max_revision == -1:
max_revision_num = len(flatten(pred_res))
elif type(max_revision) == float:
assert(max_revision >= 0 and max_revision <= 1)
max_revision_num = round(len(flatten(pred_res)) * max_revision)
else:
assert(max_revision >= 0)
max_revision_num = max_revision
max_revision_num = float_parameter(max_revision, len(flatten(pred_res)))

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num)


+ 13
- 7
abl/utils/utils.py View File

@@ -36,16 +36,10 @@ def confidence_dist(A, B):
cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
return 1 - np.prod(A[rows, cols, B], axis=1)

def block_sample(X, Z, Y, sample_num, epoch_idx):
part_num = len(X) // sample_num
if part_num == 0:
part_num = 1
seg_idx = epoch_idx % part_num
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
def block_sample(X, Z, Y, sample_num, seg_idx):
X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]

return X, Z, Y


@@ -78,3 +72,15 @@ def hashable_to_list(t):
if type(t[0]) is not tuple:
return list(t)
return [list(subtuple) for subtuple in t]


def float_parameter(parameter, total_length):
assert(type(parameter) in (int, float))
if parameter == -1:
return total_length
elif type(parameter) == float:
assert(parameter >= 0 and parameter <= 1)
return round(total_length * parameter)
else:
assert(parameter >= 0)
return parameter

+ 1
- 1
examples/hwf/hwf_example.ipynb View File

@@ -177,7 +177,7 @@
"source": [
"# Train model\n",
"framework.train(\n",
" model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n",
" model, abducer, train_data, epochs=15, sample=5000, verbose=1\n",
")\n",
"\n",
"# Save results\n",


+ 26
- 10
examples/mnist_add/mnist_add_example.ipynb View File

@@ -136,8 +136,8 @@
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)"
"train_data = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_data = get_mnist_add(train=False, get_pseudo_label=True)"
]
},
{
@@ -155,24 +155,40 @@
"outputs": [],
"source": [
"# Train model\n",
"framework.train(\n",
"model = framework.train(\n",
" model,\n",
" abducer,\n",
" (train_X, train_Z, train_Y),\n",
" (test_X, test_Z, test_Y),\n",
" loop_num=15,\n",
" sample_num=5000,\n",
" train_data,\n",
" epochs=5,\n",
" sample=12000,\n",
" verbose=1,\n",
")\n",
"\n",
"# Save results\n",
"recorder.dump()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### TODO: Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"framework.test(model, abducer, test_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ABL",
"display_name": "abl",
"language": "python",
"name": "python3"
},
@@ -186,12 +202,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb6f4ceeabb9a733f366948eb80109f83aedf798cc984df1e68fb411adb27d58"
"hash": "9c8d454494e49869a4ee4046edcac9a39ff683f7d38abf0769f648402670238e"
}
}
},


Loading…
Cancel
Save