|
|
|
@@ -137,6 +137,7 @@ class ClsKB(KBBase): |
|
|
|
|
|
|
|
def abduce_candidates(self, pred_res, key, max_address_num=-1, require_more_address=0, multiple_predictions=False): |
|
|
|
if self.GKB_flag: |
|
|
|
# TODO: 这里有可能是multiple_predictions吗 |
|
|
|
return self._abduce_by_GKB(pred_res, key, max_address_num, require_more_address) |
|
|
|
else: |
|
|
|
return self._abduce_by_search(pred_res, key, max_address_num, require_more_address, multiple_predictions) |
|
|
|
@@ -200,7 +201,7 @@ class add_KB(ClsKB): |
|
|
|
def logic_forward(self, nums): |
|
|
|
return sum(nums) |
|
|
|
|
|
|
|
|
|
|
|
# TODO:这是个回归任务(对于y而言),在logic_forward加round变成离散的分类任务固然可行,但最好还是用RegKB吧,作为例子示范。还需要对下面的ClsKB进行修改(见TODO) |
|
|
|
class HWF_KB(ClsKB): |
|
|
|
def __init__( |
|
|
|
self, GKB_flag=False, pseudo_label_list=['1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', 'times', 'div'], len_list=[1, 3, 5, 7] |
|
|
|
@@ -334,7 +335,13 @@ class HED_prolog_KB(prolog_KB): |
|
|
|
|
|
|
|
# def consist_rules(self, pred_res, rules): |
|
|
|
|
|
|
|
|
|
|
|
# TODO:这里需要修改一下这个类,原本的RegKB是对GKB而言的,现在需要和ClsKB一样同时支持GKB和非GKB。需要补充非GKB部分(可能继承_abduce_by_search就行),以及修改GKB部分_abduce_by_GKB的逻辑(原本逻辑是找与key最近的y的abduce结果,现在改成与key在一定误差范围内的y的abduce结果) |
|
|
|
# TODO:我理解的RegKB是这样的: |
|
|
|
# TODO:1. 对GKB而言,即_abduce_by_GKB,给定key和length,还需要一个self.max_err,返回所有与key绝对值小于max_err的abduction结果 |
|
|
|
# TODO:比如GKB里的y有[1.3, 1.49, 1.50, 1.52, 1.6],若key=1.5,max_err=1e-5,则返回[y=1.50]的abduction结果;若key=1.5,max_err=0.05,则返回所有[y=1.49, 1.50, 1.52]的abduction结果 |
|
|
|
# TODO:因此在二分查找bisect_left后,需要分别往前和往后遍历,从GKB里找符合误差的y |
|
|
|
# TODO:self.max_err默认值取很小就行,比如HWF这类任务;但有些任务(比如法院刑期预测)的max_err需要大些,因此可以由用户自定义 |
|
|
|
# TODO:2. 对非GKB而言,估计直接用_abduce_by_search就行,check_equal那限定为数字且控制回归误差max_err |
|
|
|
class RegKB(KBBase): |
|
|
|
def __init__(self, GKB_flag=False, X=None, Y=None): |
|
|
|
super().__init__() |
|
|
|
@@ -355,7 +362,7 @@ class RegKB(KBBase): |
|
|
|
def logic_forward(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def abduce_candidates(self, key, length=None): |
|
|
|
def _abduce_by_GKB(self, key, length=None): |
|
|
|
if key is None: |
|
|
|
return self.get_all_candidates() |
|
|
|
|
|
|
|
|