Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
15eda317bc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 10 deletions
  1. +10
    -10
      abducer/kb.py

+ 10
- 10
abducer/kb.py View File

@@ -19,7 +19,7 @@ from collections import defaultdict
from itertools import product

class KBBase(ABC):
def __init__(self):
def __init__(self, GKB_flag = False):
pass

@abstractmethod
@@ -46,13 +46,13 @@ class KBBase(ABC):


class ClsKB(KBBase):
def __init__(self, pseudo_label_list, len_list = None):
def __init__(self, GKB_flag = False, pseudo_label_list = None, len_list = None):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.base = {}
self.len_list = len_list
if(self.len_list != None):
if GKB_flag:
X = self.get_X(self.pseudo_label_list, self.len_list)
Y = self.get_Y(X, self.logic_forward)
for x, y in zip(X, Y):
@@ -94,18 +94,18 @@ class ClsKB(KBBase):


class add_KB(ClsKB):
def __init__(self, len_list = [2]):
def __init__(self, GKB_flag = False, len_list = [2]):
self.pseudo_label_list = list(range(10))
super().__init__(self.pseudo_label_list, len_list)
super().__init__(GKB_flag, self.pseudo_label_list, len_list)
def logic_forward(self, nums):
return sum(nums)
class hwf_KB(ClsKB):
def __init__(self, len_list = [1, 3, 5, 7]):
def __init__(self, GKB_flag = False, len_list = [1, 3, 5, 7]):
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/']
super().__init__(self.pseudo_label_list, len_list)
super().__init__(GKB_flag, self.pseudo_label_list, len_list)
def valid_formula(self, formula):
if(len(formula) % 2 == 0):
@@ -127,7 +127,7 @@ class hwf_KB(ClsKB):

class RegKB(KBBase):
def __init__(self, X, Y = None):
def __init__(self, GKB_flag = False, X = None, Y = None):
super().__init__()
tmp_dict = {}
for x, y in zip(X, Y):
@@ -176,7 +176,7 @@ class RegKB(KBBase):
import time
if __name__ == "__main__":
# With ground KB
kb = add_KB(len_list = [2])
kb = add_KB(GKB_flag = True)
print('len(kb):', len(kb))
res = kb.get_candidates(0)
print(res)
@@ -202,7 +202,7 @@ if __name__ == "__main__":
print()
start = time.time()
kb = hwf_KB(len_list = [1, 3, 5, 7])
kb = hwf_KB(GKB_flag = True)
print(time.time() - start)
print('len(kb):', len(kb))
res = kb.get_candidates(2, length = 1)


Loading…
Cancel
Save