Browse Source

[FIX] fix test-reasoning

pull/1/head
troyyyyy 2 years ago
parent
commit
4fa9bd522c
2 changed files with 51 additions and 48 deletions
  1. +2
    -2
      tests/conftest.py
  2. +49
    -46
      tests/test_reasoning.py

+ 2
- 2
tests/conftest.py View File

@@ -4,7 +4,7 @@ import torch.nn as nn
import torch.optim as optim

from abl.learning import BasicNN
from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner
from abl.structures import ListData
from examples.models.nn import LeNet5

@@ -138,4 +138,4 @@ def kb_hed():

@pytest.fixture
def reasoner_instance(kb_add):
return ReasonerBase(kb_add, "confidence")
return Reasoner(kb_add, "confidence")

+ 49
- 46
tests/test_reasoning.py View File

@@ -1,5 +1,5 @@
import pytest
from abl.reasoning import KBBase, GroundKB, PrologKB, ReasonerBase
from abl.reasoning import KBBase, GroundKB, PrologKB, Reasoner

class TestKBBase(object):
def test_init(self, kb_add):
@@ -18,7 +18,8 @@ class TestKBBase(object):
assert result == [[0, 2]]
def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2)
result = kb_add.abduce_candidates([1, 2], 1, max_revision_num=2,
require_more_revision=0)
assert result == [[1, 0]]
class TestGroundKB(object):
@@ -32,7 +33,8 @@ class TestGroundKB(object):
assert result == 3
def test_abduce_candidates_ground(self, kb_add_ground):
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2)
result = kb_add_ground.abduce_candidates([1, 2], 1, max_revision_num=2,
require_more_revision=0)
assert result == [(1, 0)]
class TestPrologKB(object):
@@ -58,8 +60,9 @@ class TestPrologKB(object):
def test_logic_forward_pl2(self, kb_hed):
consist_exs = [
[1, "+", 1, "=", 0],
[1, "+", 1, "=", 1],
[1, 1, "+", 0, "=", 1, 1],
[1, "+", 1, "=", 1, 0],
[0, "+", 0, "=", 0],
]
inconsist_exs = [
[1, 1, "+", 0, "=", 1, 1],
@@ -80,67 +83,67 @@ class TestReaonser(object):
def test_invalid_dist_funce(kb_add):
with pytest.raises(NotImplementedError) as excinfo:
ReasonerBase(kb_add, "invalid_dist_func")
Reasoner(kb_add, "invalid_dist_func")
assert "Valid options for dist_func include \"hamming\" and \"confidence\"" in str(excinfo.value)


class test_batch_abduce(object):
def test_batch_abduce_add(self, kb_add, data_samples_add):
reasoner = ReasonerBase(kb_add, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]
reasoner1 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]

def test_batch_abduce_ground(self, kb_add_ground, data_samples_add):
reasoner = ReasonerBase(kb_add_ground, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [(1, 7), (7, 1), [], (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [(1, 7), (7, 1), [], (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [(1, 7), (7, 1), (8, 9), (1, 9)]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [(1, 7), (7, 1), (8, 9), (7, 3)]
reasoner1 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add_ground, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add_ground, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner2.batch_abduce(data_samples_add) == [(1, 7), (7, 1), [], (1, 9)]
assert reasoner3.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (1, 9)]
assert reasoner4.batch_abduce(data_samples_add) == [(1, 7), (7, 1), (8, 9), (7, 3)]

def test_batch_abduce_prolog(self, kb_add_prolog, data_samples_add):
reasoner = ReasonerBase(kb_add_prolog, "confidence")
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0)
assert res == [[1, 7], [7, 1], [8, 9], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1)
reasoner3 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=0)
reasoner4 = Reasoner(kb_add_prolog, "confidence", max_revision=2, require_more_revision=1)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner3.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [1, 9]]
assert reasoner4.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]
def test_batch_abduce_zoopt(self, kb_add_prolog, data_samples_add):
reasoner = ReasonerBase(kb_add_prolog, "confidence", use_zoopt=True)
res = reasoner.batch_abduce(data_samples_add, max_revision=1)
assert res == [[1, 7], [7, 1], [], [1, 9]]
res = reasoner.batch_abduce(data_samples_add, max_revision=2)
assert res == [[1, 7], [7, 1], [8, 9], [7, 3]]
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)
assert reasoner1.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [], [1, 9]]
assert reasoner2.batch_abduce(data_samples_add) == [[1, 7], [7, 1], [8, 9], [7, 3]]

def test_batch_abduce_hwf1(self, kb_hwf1, data_samples_hwf):
reasoner = ReasonerBase(kb_hwf1, "hamming")
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
reasoner1 = Reasoner(kb_hwf1, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf1, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf1, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
res = reasoner2.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], [], [], []]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
res = reasoner3.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['8', 'times', '8'], [], ['4', '-', '6', 'div', '8']]

def test_batch_abduce_hwf2(self, kb_hwf2, data_samples_hwf):
reasoner = ReasonerBase(kb_hwf2, "hamming")
res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0)
reasoner1 = Reasoner(kb_hwf2, "hamming", max_revision=3, require_more_revision=0)
reasoner2 = Reasoner(kb_hwf2, "hamming", max_revision=0.5, require_more_revision=0)
reasoner3 = Reasoner(kb_hwf2, "hamming", max_revision=0.9, require_more_revision=0)
res = reasoner1.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3)
res = reasoner2.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], [], ['5', '-', '8', 'div', '8']]
res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0)
res = reasoner3.batch_abduce(data_samples_hwf)
assert res == [['1', '+', '2'], ['7', 'times', '9'], ['8', 'times', '8'], ['5', '-', '8', 'div', '8']]


Loading…
Cancel
Save