Browse Source

Update abducer_base.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
70fa5cf56a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 27 additions and 25 deletions
  1. +27
    -25
      abducer/abducer_base.py

+ 27
- 25
abducer/abducer_base.py View File

@@ -11,7 +11,8 @@
#================================================================#

import abc
from kb import add_KB, hwf_KB
# from kb import add_KB, hwf_KB
from abducer.kb import add_KB, hwf_KB
import numpy as np

from itertools import product, combinations
@@ -37,18 +38,19 @@ def confidence_dist(A, B):


class AbducerBase(abc.ABC):
def __init__(self, kb, dist_func = "hamming", pred_res_parse = None, cache = True):
def __init__(self, kb, dist_func = 'hamming', pred_res_parse = None, cache = True):
self.kb = kb
if dist_func == "hamming":
dist_func = hamming_dist
elif dist_func == "confidence":
dist_func = confidence_dist
self.dist_func = dist_func
if(dist_func == 'hamming'):
self.dist_func = hamming_dist
elif(dist_func == 'confidence'):
self.dist_func = confidence_dist
if pred_res_parse is None:
if(dist_func == "hamming"):
if(dist_func == 'hamming'):
pred_res_parse = lambda x : x["cls"]
elif dist_func == "confidence":
pred_res_parse = lambda x : x[" "]
elif dist_func == 'confidence':
pred_res_parse = lambda x : x["prob"]
self.pred_res_parse = pred_res_parse
self.cache = cache
@@ -61,20 +63,24 @@ class AbducerBase(abc.ABC):
idxs = np.where(cost_list == min_address_num)[0]
return [candidates[idx] for idx in idxs][0]

def abduce(self, data, max_address_num = 3, require_more_address = 0, length = -1):
def abduce(self, data, max_address_num = -1, require_more_address = 0):
pred_res, ans = data
pred_res = [self.kb.pseudo_label_list[sym] for sym in pred_res]

if length == -1:
length = len(pred_res)
if max_address_num == -1:
max_address_num = len(pred_res)

if(self.cache and (tuple(pred_res), ans) in self.cache_min_address_num):
address_num = min(max_address_num, self.cache_min_address_num[(tuple(pred_res), ans)] + require_more_address)
if((tuple(pred_res), ans, address_num) in self.cache_candidates):
print('cached')
return self.cache_candidates[(tuple(pred_res), ans, address_num)]
# print('cached')
candidates = self.cache_candidates[(tuple(pred_res), ans, address_num)]
candidates = self.get_min_cost_candidate(pred_res, candidates)
return candidates
if(self.kb.base != {}):
all_candidates = self.kb.get_candidates(ans, length)
all_candidates = self.kb.get_candidates(ans, len(pred_res))
cost_list = self.dist_func(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_address_num, min_address_num + require_more_address)
@@ -107,7 +113,6 @@ class AbducerBase(abc.ABC):
def get_abduce_candidates(self, pred_res, key, max_address_num, require_more_address):
candidates = []

for address_num in range(len(pred_res) + 1):
if(address_num > max_address_num):
print('No candidates found')
@@ -132,22 +137,20 @@ class AbducerBase(abc.ABC):

return candidates, min_address_num, address_num
def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0):
def batch_abduce(self, Y, C, max_address_num = -1, require_more_address = 0):
return [
self.abduce((y, c), max_address_num, require_more_address)\
for y, c in zip(self.pred_res_parse(Y), C)
]

def __call__(self, Y, C, max_address_num = 3, require_more_address = 0):
def __call__(self, Y, C, max_address_num = -1, require_more_address = 0):
return self.batch_abduce(Y, C, max_address_num, require_more_address)



if __name__ == "__main__":
pseudo_label_list = list(range(10))
kb = add_KB(pseudo_label_list)
kb = add_KB()
abd = AbducerBase(kb)
res = abd.abduce(([1, 1, 1], 4), max_address_num = 2, require_more_address = 0)
print(res)
@@ -161,8 +164,7 @@ if __name__ == "__main__":
print(res)
print()
pseudo_label_list = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '+', '-', '*', '/']
kb = hwf_KB(pseudo_label_list)
kb = hwf_KB()
abd = AbducerBase(kb)
res = abd.abduce((['5', '+', '2'], 3), max_address_num = 2, require_more_address = 0)
print(res)


Loading…
Cancel
Save