Browse Source

[MNT] Change variable name from address to revise/revision

pull/3/head
troyyyyy 3 years ago
parent
commit
f589ccef8e
3 changed files with 93 additions and 93 deletions
  1. +30
    -30
      abl/reasoning/kb.py
  2. +19
    -19
      abl/reasoning/readme.md
  3. +44
    -44
      abl/reasoning/reasoner.py

+ 30
- 30
abl/reasoning/kb.py View File

@@ -110,51 +110,51 @@ class KBBase(ABC):
return []
else:
cost_list = hamming_dist(pred_res, all_candidates)
min_address_num = np.min(cost_list)
address_num = min(max_revision_num, min_address_num + require_more_revision)
idxs = np.where(cost_list <= address_num)[0]
min_revision_num = np.min(cost_list)
revision_num = min(max_revision_num, min_revision_num + require_more_revision)
idxs = np.where(cost_list <= revision_num)[0]
candidates = [all_candidates[idx] for idx in idxs]
return candidates

def address_by_idx(self, pred_res, y, address_idx):
def revise_by_idx(self, pred_res, y, revision_idx):
candidates = []
abduce_c = product(self.pseudo_label_list, repeat=len(address_idx))
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
if check_equal(self.logic_forward(candidate), y, self.max_err):
candidates.append(candidate)
return candidates

def _address(self, address_num, pred_res, y):
def _revision(self, revision_num, pred_res, y):
new_candidates = []
address_idx_list = combinations(list(range(len(pred_res))), address_num)
revision_idx_list = combinations(list(range(len(pred_res))), revision_num)

for address_idx in address_idx_list:
candidates = self.address_by_idx(pred_res, y, address_idx)
for revision_idx in revision_idx_list:
candidates = self.revise_by_idx(pred_res, y, revision_idx)
new_candidates += candidates
return new_candidates

def _abduce_by_search(self, pred_res, y, max_revision_num, require_more_revision):
candidates = []
for address_num in range(len(pred_res) + 1):
if address_num == 0:
for revision_num in range(len(pred_res) + 1):
if revision_num == 0:
if check_equal(self.logic_forward(pred_res), y, self.max_err):
candidates.append(pred_res)
else:
new_candidates = self._address(address_num, pred_res, y)
new_candidates = self._revision(revision_num, pred_res, y)
candidates += new_candidates
if len(candidates) > 0:
min_address_num = address_num
min_revision_num = revision_num
break
if address_num >= max_revision_num:
if revision_num >= max_revision_num:
return []

for address_num in range(min_address_num + 1, min_address_num + require_more_revision + 1):
if address_num > max_revision_num:
for revision_num in range(min_revision_num + 1, min_revision_num + require_more_revision + 1):
if revision_num > max_revision_num:
return candidates
new_candidates = self._address(address_num, pred_res, y)
new_candidates = self._revision(revision_num, pred_res, y)
candidates += new_candidates
return candidates
@@ -191,35 +191,35 @@ class prolog_KB(KBBase):
return False
return result
def _address_pred_res(self, pred_res, address_idx):
def _revision_pred_res(self, pred_res, revision_idx):
import re
address_pred_res = pred_res.copy()
address_pred_res = flatten(address_pred_res)
revision_pred_res = pred_res.copy()
revision_pred_res = flatten(revision_pred_res)
for idx in address_idx:
address_pred_res[idx] = 'P' + str(idx)
address_pred_res = reform_idx(address_pred_res, pred_res)
for idx in revision_idx:
revision_pred_res[idx] = 'P' + str(idx)
revision_pred_res = reform_idx(revision_pred_res, pred_res)
# TODO:不知道有没有更简洁的方法
regex = r"'P\d+'"
return re.sub(regex, lambda x: x.group().replace("'", ""), str(address_pred_res))
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pred_res))
def get_query_string(self, pred_res, y, address_idx):
def get_query_string(self, pred_res, y, revision_idx):
query_string = "logic_forward("
query_string += self._address_pred_res(pred_res, address_idx)
query_string += self._revision_pred_res(pred_res, revision_idx)
key_is_none_flag = y is None or (type(y) == list and y[0] is None)
query_string += ",%s)." % y if not key_is_none_flag else ")."
return query_string
def address_by_idx(self, pred_res, y, address_idx):
def revise_by_idx(self, pred_res, y, revision_idx):
candidates = []
query_string = self.get_query_string(pred_res, y, address_idx)
query_string = self.get_query_string(pred_res, y, revision_idx)
save_pred_res = pred_res
pred_res = flatten(pred_res)
abduce_c = [list(z.values()) for z in self.prolog.query(query_string)]
for c in abduce_c:
candidate = pred_res.copy()
for i, idx in enumerate(address_idx):
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_idx(candidate, save_pred_res)
candidates.append(candidate)


+ 19
- 19
abl/reasoning/readme.md View File

@@ -47,15 +47,15 @@
KB 的反绎功能在`abduce_candidates`中自动实现. 在调用`abduce_candidates`中需要传入以下参数:

- `pred_res`: 机器学习的输出的伪标记
- `key`: 逻辑推理的正确结果
- `max_address_num`: 最多修改的伪标记的个数
- `require_more_address`: 在已经反绎出结果之后, 指明是否需要继续增加修改伪标记的个数来继续得到更多的反绎结果.
- `y`: 逻辑推理的正确结果
- `max_revision_num`: 最多修改的伪标记的个数
- `require_more_revision`: 在已经反绎出结果之后, 指明是否需要继续增加修改伪标记的个数来继续得到更多的反绎结果.

得到的输出为所有可能的反绎的结果.

> 例如: 上文定义的`kb1` (MNIST_add 的 KB) 调用`kb1.abduce_candidates`得到的结果如下表:
>
> |`pred_res`|`key`|`max_address_num`|`require_more_address`|输出|
> |`pred_res`|`y`|`max_revision_num`|`require_more_revision`|输出|
> |:---:|:---:|:---:|:---:|:----|
> |[1,1]|8|2|1|[[1,7],[7,1],[2,6],[6,2],[3,5],[5,3],[4,4]]|
> |[1,1]|8|2|0|[[1,7],[7,1]]|
@@ -70,19 +70,19 @@ KB 的反绎功能在`abduce_candidates`中自动实现. 在调用`abduce_candid

#### `_abduce_by_GKB`

搜索 GKB 中是否存在 key 为`key`, 且满足`pred_res`, `max_address_num`和`require_more_address`组成的限制条件的反绎结果.
搜索 GKB 中是否存在 key 为`y`, 且满足`pred_res`, `max_revision_num`和`require_more_revision`组成的限制条件的反绎结果.

> 比如, MNIST_add 中, 传入的`key`为 4, 此时在 GKB 中可以找到 [1,3], [3,1] 和 [2,2]. 如果此时传入的`pred_res`为 [2,8], `max_address_num`为 2, `require_more_address`为 0, 则输出的结果为 [2,2].
> 比如, MNIST_add 中, 传入的`y`为 4, 此时在 GKB 中可以找到 [1,3], [3,1] 和 [2,2]. 如果此时传入的`pred_res`为 [2,8], `max_revision_num`为 2, `require_more_revision`为 0, 则输出的结果为 [2,2].

#### `_abduce_by_search`

从 0 开始不断增加修改伪标记的个数, 通过枚举得到所有可能的修改后的伪标记, 直到达到`max_address_num`或找到符合`logic_forward`定义的逻辑的结果. 接着, 如果`require_more_address`不为0就继续增加修改伪标记的个数, 将符合的结果一起输出.
从 0 开始不断增加修改伪标记的个数, 通过枚举得到所有可能的修改后的伪标记, 直到达到`max_revision_num`或找到符合`logic_forward`定义的逻辑的结果. 接着, 如果`require_more_revision`不为0就继续增加修改伪标记的个数, 将符合的结果一起输出.

> 比如, MNIST_add 中, 传入的`key`为 4, `pred_res`为 [2,8], `max_address_num`为 2: 当修改 0 个伪标记时, 不能得到正确的结果. 当修改 1 个伪标记时, 可能的修改后逻辑输入为 [2,0],[2,1],[2,2],[2,3],[2,4],[2,5],[2,6],[2,7],[2,9], [0,8],[1,8],[3,8],[4,8],[5,8],[6,8],[7,8],[8,8],[9,8], 其中, [2,2] 符合逻辑的结果, 如果传入的`require_more_address`是 0, 则最终输出的结果就是 [2,2], 否则, 继续增加修改的伪标记的个数并检验.
> 比如, MNIST_add 中, 传入的`y`为 4, `pred_res`为 [2,8], `max_revision_num`为 2: 当修改 0 个伪标记时, 不能得到正确的结果. 当修改 1 个伪标记时, 可能的修改后逻辑输入为 [2,0],[2,1],[2,2],[2,3],[2,4],[2,5],[2,6],[2,7],[2,9], [0,8],[1,8],[3,8],[4,8],[5,8],[6,8],[7,8],[8,8],[9,8], 其中, [2,2] 符合逻辑的结果, 如果传入的`require_more_revision`是 0, 则最终输出的结果就是 [2,2], 否则, 继续增加修改的伪标记的个数并检验.

_注: 如果使用 prolog 作为 KB, `_abduce_by_search`不会手动地枚举所有的可能修改后的伪标记, prolog 程序运行时会有一些剪枝等操作可以加速反绎. 关于这一部分, 将在下一小节`prolog_KB`中详述._

_注: 当使用`zoopt`或其他方式已经获得了需要修改的 index 时, 不需要调用整个`_abduce_by_search`的流程, 只需要调用其中的`address_by_idx`即可. 关于这一部分, 将在`abducer_base.py`中详述._
_注: 当使用`zoopt`或其他方式已经获得了需要修改的 index 时, 不需要调用整个`_abduce_by_search`的流程, 只需要调用其中的`revise_by_idx`即可. 关于这一部分, 将在`abducer_base.py`中详述._

## `prolog_KB`

@@ -115,9 +115,9 @@ _需要注意: 传入的 prolog 程序中需要有 `logic_forward` 的实现,

### `max_err`

当逻辑推理部分的输出为数值时, 可以传入`max_err`, 使得调用`abduce_candidates`时, 只要满足与`key`的误差在`max_err`之间的结果都会被输出.
当逻辑推理部分的输出为数值时, 可以传入`max_err`, 使得调用`abduce_candidates`时, 只要满足与`y`的误差在`max_err`之间的结果都会被输出.

> 例如: 上文定义的`kb1` (MNIST_add 的 KB), 当`pred_res`为 [2,2], `key`为 7, `max_address_num`为 2, `require_more_address`为 0 时, 如果`max_err`为 0, 则输出的结果为 [[2,5],[5,2]]; 如果`max_err`为 1, 则输出的结果为 [[2,4],[2,5],[2,6],[4,2],[5,2],[6,2]].
> 例如: 上文定义的`kb1` (MNIST_add 的 KB), 当`pred_res`为 [2,2], `y`为 7, `max_revision_num`为 2, `require_more_revision`为 0 时, 如果`max_err`为 0, 则输出的结果为 [[2,5],[5,2]]; 如果`max_err`为 1, 则输出的结果为 [[2,4],[2,5],[2,6],[4,2],[5,2],[6,2]].

### `use_cache`

@@ -140,8 +140,8 @@ _需要注意: 传入的 prolog 程序中需要有 `logic_forward` 的实现,
`AbducerBase` 主要实现的函数是 `abduce`. 它的功能是有了数据之后得到**一个** **最有可能的**反绎的结果. 在调用`abduce`中需要传入以下参数:

- `data`: 三元素组成的 tuple, 三个元素分别为 `pred_res`, `pred_res_prob`, `key`. 其中, `pred_res`和`key`的定义同`kb.py`中的`abduce_candidates`, 而 `pred_res_prob` 是机器学习输出的每个伪标记的置信度列表.
- `max_address`: 最多修改的伪标记的数量, 可以以 float 或 int 的形式输入. 如果传入 float 最多修改的伪标记占所有伪标记的比重, 如果传入 int 为最多修改伪标记的个数 (此时同`kb.py`中`abduce_candidates`中`max_address_num`的定义)
- `require_more_address`: 定义同`kb.py`中的 `abduce_candidates`同一参数.
- `max_revision`: 最多修改的伪标记的数量, 可以以 float 或 int 的形式输入. 如果传入 float 最多修改的伪标记占所有伪标记的比重, 如果传入 int 为最多修改伪标记的个数 (此时同`kb.py`中`abduce_candidates`中`max_revision_num`的定义)
- `require_more_revision`: 定义同`kb.py`中的 `abduce_candidates`同一参数.

输出为一个反绎的结果.

@@ -151,12 +151,12 @@ _需要注意: 传入的 prolog 程序中需要有 `logic_forward` 的实现,

- 当`zoopt`为`False`时, 不使用零阶优化, 直接使用`kb.py`中的`abduce_candidates`找到所有可能的反绎结果, 然后使用`_get_one_candidate`找到最有可能的反绎结果.

- 当`zoopt`为`True`时, 使用零阶优化找到修改的伪标记的 index, 然后使用`kb.py`中的`address_by_idx`找到反绎结果, 最后使用`_get_one_candidate`找到最有可能的反绎结果.
> 比如, MNIST_add 中的`pred_res`为 [2,9], `key`为 18, 首先使用零阶优化会得到应该修改的伪标记的 index 为 0, 接着代入`pred_res`, `key` 和修改的 index([0]) 到`kb.py`的`address_by_idx`中, 在其中会先得到修改 index 处的伪标记后所有可能的逻辑输入为 [0,9],[1,9],[3,9],[4,9],[5,9],[6,9],[7,9],[8,9],[9,9], 其中 [9,9] 是符合逻辑的, 则输出 [9,9].
- 当`zoopt`为`True`时, 使用零阶优化找到修改的伪标记的 index, 然后使用`kb.py`中的`revise_by_idx`找到反绎结果, 最后使用`_get_one_candidate`找到最有可能的反绎结果.
> 比如, MNIST_add 中的`pred_res`为 [2,9], `key`为 18, 首先使用零阶优化会得到应该修改的伪标记的 index 为 0, 接着代入`pred_res`, `key` 和修改的 index([0]) 到`kb.py`的`revise_by_idx`中, 在其中会先得到修改 index 处的伪标记后所有可能的逻辑输入为 [0,9],[1,9],[3,9],[4,9],[5,9],[6,9],[7,9],[8,9],[9,9], 其中 [9,9] 是符合逻辑的, 则输出 [9,9].
>
> 再比如, HED 中`pred_res`为[1,0,1,'=',1], (`key`默认设置为`None`). 首先使用零阶优化会得到应该修改的伪标记的 index 为 1, 代入到`kb.py`的`address_by_idx`可以得到修改 index 处的伪标记后所有可能的逻辑输入为 [1,1,1,'=',1],[1,'+',1,'=',1],[1,'=',1,'=',1], 其中 [1,'+',1,'=',1] 是符合逻辑的, 则输出 [1,'+',1,'=',1].
> 再比如, HED 中`pred_res`为[1,0,1,'=',1], (`key`默认设置为`None`). 首先使用零阶优化会得到应该修改的伪标记的 index 为 1, 代入到`kb.py`的`revise_by_idx`可以得到修改 index 处的伪标记后所有可能的逻辑输入为 [1,1,1,'=',1],[1,'+',1,'=',1],[1,'=',1,'=',1], 其中 [1,'+',1,'=',1] 是符合逻辑的, 则输出 [1,'+',1,'=',1].
_注:零阶优化中同样支持`max_address`的限制, 但是暂时没有设置支持`require_more_address`._
_注:零阶优化中同样支持`max_revision`的限制, 但是暂时没有设置支持`require_more_revision`._

## 何为“最有可能的”

@@ -165,8 +165,8 @@ _需要注意: 传入的 prolog 程序中需要有 `logic_forward` 的实现,
- `hamming`: 用反绎后的结果与`pred_res`之间的汉明距离作为度量, 输出距离最小的反绎结果.
- `confidence`: 用反绎后的结果与`pred_res_prob`之间的距离作为度量, 输出距离最小的反绎结果.
> 比如, MNIST_add 中的`pred_res`为 [1,1], `key` 为 8, `max_address`为 1, 如果`pred_res_prob`为 [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], 则输出的结果为 [1,7], 如果`pred_res_prob`为 [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], 则输出的结果为 [7,1].
> 比如, MNIST_add 中的`pred_res`为 [1,1], `key` 为 8, `max_revision`为 1, 如果`pred_res_prob`为 [[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], 则输出的结果为 [1,7], 如果`pred_res_prob`为 [[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], 则输出的结果为 [7,1].

## 批量化反绎

可以使用`batch_abduce`同时传入一批数据进行反绎, 如上文定义的 `abd1`, 调用`abd1.batch_abduce({'cls':[[1,1], [1,2]], 'prob':multiple_prob}, [4,8], max_address=2, require_more_address=0)`时, 返回的结果为 [[1,3], [6,2]].
可以使用`batch_abduce`同时传入一批数据进行反绎, 如上文定义的 `abd1`, 调用`abd1.batch_abduce({'cls':[[1,1], [1,2]], 'prob':multiple_prob}, [4,8], max_revision=2, require_more_revision=0)`时, 返回的结果为 [[1,3], [6,2]].

+ 44
- 44
abl/reasoning/reasoner.py View File

@@ -66,17 +66,17 @@ class ReasonerBase(abc.ABC):
candidate = candidates[np.argmin(cost_list)]
return candidate
def _zoopt_address_score_single(self, sol_x, pred_res, pred_res_prob, y):
address_idx = np.where(sol_x != 0)[0]
candidates = self.address_by_idx(pred_res, y, address_idx)
def _zoopt_revision_score_single(self, sol_x, pred_res, pred_res_prob, y):
revision_idx = np.where(sol_x != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):
"""
Get the address score for a single solution.
Get the revision score for a single solution.

Parameters
----------
@@ -92,20 +92,20 @@ class ReasonerBase(abc.ABC):
Returns
-------
float
The address score for the given solution.
The revision score for the given solution.
"""
address_idx = np.where(sol.get_x() != 0)[0]
candidates = self.address_by_idx(pred_res, y, address_idx)
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pred_res, pred_res_prob, candidates))
else:
return len(pred_res)
def _constrain_address_num(self, solution, max_address_num):
def _constrain_revision_num(self, solution, max_revision_num):
x = solution.get_x()
return max_address_num - x.sum()
return max_revision_num - x.sum()

def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_address_num):
def zoopt_get_solution(self, pred_res, pred_res_prob, y, max_revision_num):
"""Get the optimal solution using the Zoopt library.

Parameters
@@ -116,8 +116,8 @@ class ReasonerBase(abc.ABC):
List of probabilities for predicted results.
y : str
Ground truth for the predicted results.
max_address_num : int or float
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
max_revision_num : int
Maximum number of revisiones to use.

Returns
-------
@@ -127,16 +127,16 @@ class ReasonerBase(abc.ABC):
length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
objective = Objective(
lambda sol: self.zoopt_address_score(pred_res, pred_res_prob, y, sol),
lambda sol: self.zoopt_revision_score(pred_res, pred_res_prob, y, sol),
dim=dimension,
constraint=lambda sol: self._constrain_address_num(sol, max_address_num),
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
solution = Opt.min(objective, parameter).get_x()
return solution
def address_by_idx(self, pred_res, y, address_idx):
"""Get the addresses corresponding to the given indices.
def revise_by_idx(self, pred_res, y, revision_idx):
"""Get the revisiones corresponding to the given indices.

Parameters
----------
@@ -144,15 +144,15 @@ class ReasonerBase(abc.ABC):
List of predicted results.
y : str
Ground truth for the predicted results.
address_idx : array-like
Indices of the addresses to retrieve.
revision_idx : array-like
Indices of the revisiones to retrieve.

Returns
-------
list
The addresses corresponding to the given indices.
The revisiones corresponding to the given indices.
"""
return self.kb.address_by_idx(pred_res, y, address_idx)
return self.kb.revise_by_idx(pred_res, y, revision_idx)

def abduce(self, data, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data.
@@ -162,34 +162,34 @@ class ReasonerBase(abc.ABC):
data : tuple
Tuple containing the predicted results, predicted result probabilities, and y.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1.
Maximum number of revisiones to use. If float, represents the fraction of total revisiones to use.
If -1, use all revisiones. Defaults to -1.
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0.
Number of additional revisiones to require. Defaults to 0.

Returns
-------
list
The abduced addresses.
The abduced revisiones.
"""
pred_res, pred_res_prob, y = data
assert(type(max_revision) in (int, float))
if max_revision == -1:
max_address_num = len(flatten(pred_res))
max_revision_num = len(flatten(pred_res))
elif type(max_revision) == float:
assert(max_revision >= 0 and max_revision <= 1)
max_address_num = round(len(flatten(pred_res)) * max_revision)
max_revision_num = round(len(flatten(pred_res)) * max_revision)
else:
assert(max_revision >= 0)
max_address_num = max_revision
max_revision_num = max_revision

if self.zoopt:
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_address_num)
address_idx = np.where(solution != 0)[0]
candidates = self.address_by_idx(pred_res, y, address_idx)
solution = self.zoopt_get_solution(pred_res, pred_res_prob, y, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(pred_res, y, revision_idx)
else:
candidates = self.kb.abduce_candidates(pred_res, y, max_address_num, require_more_revision)
candidates = self.kb.abduce_candidates(pred_res, y, max_revision_num, require_more_revision)

candidate = self._get_one_candidate(pred_res, pred_res_prob, candidates)
return candidate
@@ -204,15 +204,15 @@ class ReasonerBase(abc.ABC):
Y : list
List of ground truths.
max_revision : int or float, optional
Maximum number of addresses to use. If float, represents the fraction of total addresses to use.
If -1, use all addresses. Defaults to -1.
Maximum number of revisiones to use. If float, represents the fraction of total revisiones to use.
If -1, use all revisiones. Defaults to -1.
require_more_revision : int, optional
Number of additional addresses to require. Defaults to 0.
Number of additional revisiones to require. Defaults to 0.

Returns
-------
list
The abduced addresses.
The abduced revisiones.
"""
return [self.abduce((z, prob, y), max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]
@@ -440,20 +440,20 @@ if __name__ == '__main__':
def __init__(self, kb, dist_func='hamming'):
super().__init__(kb, dist_func, zoopt=True)
def _address_by_idxs(self, pred_res, y, all_address_flag, idxs):
def _revise_by_idxs(self, pred_res, y, all_revision_flag, idxs):
pred = []
k = []
address_flag = []
revision_flag = []
for idx in idxs:
pred.append(pred_res[idx])
k.append(y[idx])
address_flag += list(all_address_flag[idx])
address_idx = np.where(np.array(address_flag) != 0)[0]
candidate = self.address_by_idx(pred, k, address_idx)
revision_flag += list(all_revision_flag[idx])
revision_idx = np.where(np.array(revision_flag) != 0)[0]
candidate = self.revise_by_idx(pred, k, revision_idx)
return candidate
def zoopt_address_score(self, pred_res, pred_res_prob, y, sol):
all_address_flag = reform_idx(sol.get_x(), pred_res)
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []
while lefted_idxs:
@@ -464,7 +464,7 @@ if __name__ == '__main__':
for idx in range(-1, len(pred_res)):
if (not idx in idxs) and (idx >= 0):
idxs.append(idx)
candidate = self._address_by_idxs(pred_res, y, all_address_flag, idxs)
candidate = self._revise_by_idxs(pred_res, y, all_revision_flag, idxs)
if len(candidate) == 0:
if len(idxs) > 1:
idxs.pop()


Loading…
Cancel
Save