Browse Source

[ENH] add reasoning test

pull/1/head
troyyyyy 2 years ago
parent
commit
bbc9ef46bb
2 changed files with 255 additions and 1 deletions
  1. +108
    -1
      tests/conftest.py
  2. +147
    -0
      tests/test_reasoning.py

+ 108
- 1
tests/conftest.py View File

@@ -4,10 +4,10 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim


from abl.learning import BasicNN from abl.learning import BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase
from abl.structures import ListData from abl.structures import ListData
from examples.models.nn import LeNet5 from examples.models.nn import LeNet5



# Fixture for BasicNN instance # Fixture for BasicNN instance
@pytest.fixture @pytest.fixture
def basic_nn_instance(): def basic_nn_instance():
@@ -32,3 +32,110 @@ def list_data_instance():
data_samples.Y = [1, 2, 3] data_samples.Y = [1, 2, 3]
data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] data_samples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
return data_samples return data_samples


@pytest.fixture
def data_samples_add():
# favor 1 in first one
prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
# favor 7 in first one
prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0],
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]
data_samples_add = ListData()
data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]]
data_samples_add.pred_prob = [prob1, prob2, prob1, prob2]
data_samples_add.Y = [8, 8, 17, 10]
return data_samples_add

@pytest.fixture
def data_samples_hwf():
data_samples_hwf = ListData()
data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]]
data_samples_hwf.pred_prob = [None, None, None, None]
data_samples_hwf.Y = [3, 64, 65, 3.17]
return data_samples_hwf

class AddKB(KBBase):
def __init__(self, pseudo_label_list=list(range(10)),
use_cache=False):
super().__init__(pseudo_label_list, use_cache=use_cache)

def logic_forward(self, nums):
return sum(nums)
class AddGroundKB(GroundKB):
def __init__(self, pseudo_label_list=list(range(10)),
GKB_len_list=[2]):
super().__init__(pseudo_label_list, GKB_len_list)
def logic_forward(self, nums):
return sum(nums)
class HwfKB(KBBase):
def __init__(
self,
pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9",
"+", "-", "times", "div"],
max_err=1e-3,
use_cache=False,
):
super().__init__(pseudo_label_list, max_err, use_cache)

def _valid_candidate(self, formula):
if len(formula) % 2 == 0:
return False
for i in range(len(formula)):
if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
return False
if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]:
return False
return True

def logic_forward(self, formula):
if not self._valid_candidate(formula):
return None
mapping = {str(i): str(i) for i in range(1, 10)}
mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"})
formula = [mapping[f] for f in formula]
return eval("".join(formula))
class HedKB(PrologKB):
def __init__(self, pseudo_label_list, pl_file):
super().__init__(pseudo_label_list, pl_file)

def consist_rule(self, exs, rules):
rules = str(rules).replace("'", "")
pl_query = "eval_inst_feature(%s, %s)." % (exs, rules)
return len(list(self.prolog.query(pl_query))) != 0

@pytest.fixture
def kb_add():
return AddKB()

@pytest.fixture
def kb_add_cache():
return AddKB(use_cache=True)

@pytest.fixture
def kb_add_ground():
return AddGroundKB()

@pytest.fixture
def kb_add_prolog():
kb = PrologKB(pseudo_label_list=list(range(10)),
pl_file="examples/mnist_add/datasets/add.pl")
return kb

@pytest.fixture
def kb_hed():
kb = HedKB(
pseudo_label_list=[1, 0, "+", "="],
pl_file="examples/hed/datasets/learn_add.pl",
)
return kb

@pytest.fixture
def reasoner_instance(kb_add):
return ReasonerBase(kb_add, "confidence")

+ 147
- 0
tests/test_reasoning.py View File

@@ -0,0 +1,147 @@
import pytest
from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase

class TestKBBase(object):
def test_init(self, kb_add):
assert kb_add.pseudo_label_list == list(range(10))
def test_init_cache(self, kb_add_cache):
assert kb_add_cache.pseudo_label_list == list(range(10))
assert kb_add_cache.use_cache == True
def test_logic_forward(self, kb_add):
result = kb_add.logic_forward([1, 2])
assert result == 3
def test_revise_at_idx(self, kb_add):
result = kb_add.revise_at_idx([1, 2], 2, [0])
assert result == [[0, 2]]
def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2)
assert result == [[1, 0]]
class TestGroundKB(object):
def test_init(self, kb_add_ground):
assert kb_add_ground.pseudo_label_list == list(range(10))
assert kb_add_ground.GKB_len_list == [2]
assert kb_add_ground.GKB
def test_logic_forward_ground(self, kb_add_ground):
result = kb_add_ground.logic_forward([1, 2])
assert result == 3
def test_abduce_candidates_ground(self, kb_add_ground):
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2)
assert result == [(1, 0)]
class TestPrologKB(object):
def test_init_pl1(self, kb_add_prolog):
assert kb_add_prolog.pseudo_label_list == list(range(10))
assert kb_add_prolog.pl_file == "examples/mnist_add/datasets/add.pl"
def test_init_pl2(self, kb_hed):
assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
assert kb_hed.pl_file == "examples/hed/datasets/learn_add.pl"
def test_prolog_file_not_exist(self):
pseudo_label_list = [1, 2]
non_existing_file = "path/to/non_existing_file.pl"
with pytest.raises(FileNotFoundError) as excinfo:
PrologKB(pseudo_label_list=pseudo_label_list,
pl_file=non_existing_file)
assert non_existing_file in str(excinfo.value)
def test_logic_forward_pl1(self, kb_add_prolog):
result = kb_add_prolog.logic_forward([1, 2])
assert result == 3
def test_logic_forward_pl2(self, kb_hed):
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs = [
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
[0, "+", 0, "=", 1],
]
assert kb_hed.logic_forward(consist_exs) == True
assert kb_hed.logic_forward(inconsist_exs) == False

def test_revise_at_idx(self, kb_add_prolog):
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0])
assert result == [[0, 2]]

class TestReaonser(object):
def test_reasoner_init(self, reasoner_instance):
assert reasoner_instance.dist_func == "confidence"
def test_invalid_dist_funce(kb_add):
with pytest.raises(NotImplementedError) as excinfo:
ReasonerBase(kb_add, "invalid_dist_func")
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value)


class test_batch_abduce(object):
def test_batch_abduce_add(self, kb_add, data_samples_add):
reasoner = ReasonerBase(kb_add, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]

def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
reasoner = ReasonerBase(kb_add_ground, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [(1, 7), (7, 1), [], (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [(1, 7), (7, 1), [], (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [(1, 7), (7, 1), (8, 9), (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [(1, 7), (7, 1), (8, 9), (7, 3)]

def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
reasoner = ReasonerBase(kb_add_prolog, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
reasoner = ReasonerBase(kb_add_prolog, "confidence", use_zoopt=True)
res = reasoner.batch_abduce(data_samples_add, max_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]

def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
reasoner = ReasonerBase(kb_hwf1, "hamming")
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
assert res == [['1', '+', '2'], [], [], []]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]

def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
reasoner = ReasonerBase(kb_hwf2, "hamming")
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]


Loading…
Cancel
Save