Browse Source

Update kb.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
f6f92a5f08
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 13 deletions
  1. +10
    -13
      abducer/kb.py

+ 10
- 13
abducer/kb.py View File

@@ -45,9 +45,9 @@ class KBBase(ABC):
pass

class add_KB(KBBase):
def __init__(self, pseudo_label_list, kb_max_len = -1):
def __init__(self, kb_max_len = -1):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.pseudo_label_list = list(range(10))
self.base = {}
self.kb_max_len = kb_max_len
if(self.kb_max_len > 0):
@@ -92,9 +92,9 @@ class add_KB(KBBase):
return sum(self._dict_len(v) for v in self.base.values())
class hwf_KB(KBBase):
def __init__(self, pseudo_label_list, kb_max_len = -1):
def __init__(self, kb_max_len = -1):
super().__init__()
self.pseudo_label_list = pseudo_label_list
self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/']
self.base = {}
self.kb_max_len = kb_max_len
if(self.kb_max_len > 0):
@@ -157,13 +157,13 @@ class hwf_KB(KBBase):
def valid_formula(self, formula):
symbol_idx_list = []
first_minus_flag = 0
for idx, c in enumerate(formula):
if(idx == 0 and c == '-'):
first_minus_flag = 1
if(len(formula) == 1 or formula[1] in ['+', '-', '*', '/']):
return False
continue
if(c in ['+', '-', '*', '/']):
if(idx - 1 in symbol_idx_list or (idx == 1 and first_minus_flag == 1)):
if(idx - 1 in symbol_idx_list):
return False
symbol_idx_list.append(idx)
if(0 in symbol_idx_list or len(formula) - 1 in symbol_idx_list):
@@ -289,8 +289,7 @@ class reg_KB(KBBase):

if __name__ == "__main__":
# With ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list, kb_max_len = 5)
kb = add_KB(kb_max_len = 5)
print('len(kb):', len(kb))
res = kb.get_candidates(0)
print(res)
@@ -303,8 +302,7 @@ if __name__ == "__main__":
print()
# Without ground KB
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
kb = add_KB()
print('len(kb):', len(kb))
res = kb.get_candidates(0)
print(res)
@@ -316,8 +314,7 @@ if __name__ == "__main__":
print(res)
print()
pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '+', '-', '*', '/']
kb = hwf_KB(pseudo_label_list, kb_max_len = 5)
kb = hwf_KB(kb_max_len = 5)
print('len(kb):', len(kb))
res = kb.get_candidates(1, length = 3)
print(res)


Loading…
Cancel
Save