diff --git a/examples/hed/framework_hed.py b/examples/hed/framework_hed.py index 6e42d1b..fbfecf9 100644 --- a/examples/hed/framework_hed.py +++ b/examples/hed/framework_hed.py @@ -81,7 +81,7 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): for idx in range(len(pred_res)): address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.address_by_idx([pred_res[idx]], None, address_idx) + candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) if len(candidate) > 0: consistent_idx_tmp.append(idx) consistent_pred_res_tmp.append(candidate[0][0]) diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index 2681074..76f1ded 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -71,7 +71,7 @@ " def __init__(self, kb, dist_func='hamming'):\n", " super().__init__(kb, dist_func, zoopt=True)\n", " \n", - " def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", + " def _revise_by_idxs(self, pred_res, key, all_address_flag, idxs):\n", " pred = []\n", " k = []\n", " address_flag = []\n", @@ -80,10 +80,10 @@ " k.append(key[idx])\n", " address_flag += list(all_address_flag[idx])\n", " address_idx = np.where(np.array(address_flag) != 0)[0] \n", - " candidate = self.address_by_idx(pred, k, address_idx)\n", + " candidate = self.revise_by_idx(pred, k, address_idx)\n", " return candidate\n", " \n", - " def zoopt_address_score(self, pred_res, pred_res_prob, key, sol): \n", + " def zoopt_revision_score(self, pred_res, pred_res_prob, key, sol): \n", " all_address_flag = reform_idx(sol.get_x(), pred_res)\n", " lefted_idxs = [i for i in range(len(pred_res))]\n", " candidate_size = [] \n", @@ -95,7 +95,7 @@ " for idx in range(-1, len(pred_res)):\n", " if (not idx in idxs) and (idx >= 0):\n", " idxs.append(idx)\n", - " candidate = self._address_by_idxs(pred_res, key, all_address_flag, idxs)\n", + " candidate = self._revise_by_idxs(pred_res, key, all_address_flag, idxs)\n", " if len(candidate) == 0:\n", " if len(idxs) > 1:\n", " idxs.pop()\n",