From e389b6427e8722f9bc7d09ffc2e742ce8b9d4b9d Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Tue, 6 Jun 2023 19:47:47 +0800 Subject: [PATCH] [ENH] run hed example successfully after change ABLModel output --- abl/learning/abl_model.py | 6 +- examples/hed/framework_hed.py | 264 ++++++++++++++++++++++++--------- examples/hed/hed_example.ipynb | 12 +- 3 files changed, 203 insertions(+), 79 deletions(-) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index a1fc6a4..8f04a60 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -87,8 +87,7 @@ class ABLModel: The accuracy score for the given data. """ data_X, _ = self.merge_data(X) - _data_Y, _ = self.merge_data(Y) - data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + data_Y, _ = self.merge_data(Y) score = self.classifier_list[0].score(X=data_X, y=data_Y) return score @@ -104,8 +103,7 @@ class ABLModel: The true labels for the given data. """ data_X, _ = self.merge_data(X) - _data_Y, _ = self.merge_data(Y) - data_Y = list(map(lambda y: self.mapping[y], _data_Y)) + data_Y, _ = self.merge_data(Y) self.classifier_list[0].fit(X=data_X, y=data_Y) @staticmethod diff --git a/examples/hed/framework_hed.py b/examples/hed/framework_hed.py index fbfecf9..089b957 100644 --- a/examples/hed/framework_hed.py +++ b/examples/hed/framework_hed.py @@ -21,6 +21,7 @@ from abl.learning.basic_nn import BasicNN, BasicDataset from utils import gen_mappings, mapping_res, remapping_res from models.nn import SymbolNetAutoencoder +from torch.utils.data import RandomSampler from datasets.get_hed import get_pretrain_data @@ -29,85 +30,170 @@ def hed_pretrain(kb, cls, recorder): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if not os.path.exists("./weights/pretrain_weights.pth"): INFO("Pretrain Start") - pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11']) + pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"]) pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y) - pretrain_data_loader = torch.utils.data.DataLoader(pretrain_data, batch_size=64, shuffle=True) - + pretrain_data_loader = torch.utils.data.DataLoader( + pretrain_data, batch_size=64, shuffle=True + ) + criterion = nn.MSELoss() - optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6) + optimizer = torch.optim.RMSprop( + cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 + ) - pretrain_model = BasicNN(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder) + pretrain_model = BasicNN( + cls_autoencoder, + criterion, + optimizer, + device, + save_interval=1, + save_dir=recorder.save_dir, + num_epochs=10, + recorder=recorder, + ) pretrain_model.fit(pretrain_data_loader) - torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth") + torch.save( + cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth" + ) cls.load_state_dict(cls_autoencoder.base_model.state_dict()) - + else: cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) def _get_char_acc(model, X, consistent_pred_res, mapping): - original_pred_res = model.predict(X)['cls'] + original_pred_res = model.predict(X)["label"] pred_res = flatten(mapping_res(original_pred_res, mapping)) - INFO('Current model\'s output: ', pred_res) - INFO('Abduced labels: ', flatten(consistent_pred_res)) + INFO("Current model's output: ", pred_res) + INFO("Abduced labels: ", flatten(consistent_pred_res)) assert len(pred_res) == len(flatten(consistent_pred_res)) - return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res) + return sum( + [ + pred_res[idx] == flatten(consistent_pred_res)[idx] + for idx in range(len(pred_res)) + ] + ) / len(pred_res) def abduce_and_train(model, abducer, mapping, train_X_true, select_num): - select_idx = np.random.randint(len(train_X_true), size=select_num) - X = [] - for idx in select_idx: - X.append(train_X_true[idx]) + select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False) + X = [train_X_true[idx] for idx in select_idx] + + # original_pred_res = model.predict(X)['label'] + pred_label = model.predict(X)["label"] - original_pred_res = model.predict(X)['cls'] - if mapping == None: - mappings = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1]) + mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) else: mappings = [mapping] - + consistent_idx = [] consistent_pred_res = [] - + for m in mappings: - pred_res = mapping_res(original_pred_res, m) - max_abduce_num = 20 - solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) - all_address_flag = reform_idx(solution, pred_res) + pred_pseudo_label = mapping_res(pred_label, m) + max_revision_num = 20 + solution = abducer.zoopt_get_solution( + pred_label, + pred_pseudo_label, + [None] * len(pred_label), + [None] * len(pred_label), + max_revision_num, + ) + all_address_flag = reform_idx(solution, pred_label) consistent_idx_tmp = [] consistent_pred_res_tmp = [] - - for idx in range(len(pred_res)): - address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] - candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) + + for idx in range(len(pred_label)): + address_idx = [ + i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 + ] + candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx) if len(candidate) > 0: consistent_idx_tmp.append(idx) consistent_pred_res_tmp.append(candidate[0][0]) - + if len(consistent_idx_tmp) > len(consistent_idx): consistent_idx = consistent_idx_tmp consistent_pred_res = consistent_pred_res_tmp if len(mappings) > 1: mapping = m - + if len(consistent_idx) == 0: return 0, 0, None - - INFO('Train pool size is:', len(flatten(consistent_pred_res))) + + INFO("Train pool size is:", len(flatten(consistent_pred_res))) INFO("Start to use abduced pseudo label to train model...") - model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) + model.train( + [X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping) + ) consistent_acc = len(consistent_idx) / select_num - char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) - INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) + char_acc = _get_char_acc( + model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping + ) + INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc)) return consistent_acc, char_acc, mapping + +# def abduce_and_train(model, abducer, mapping, train_X_true, select_num): +# select_idx = np.random.randint(len(train_X_true), size=select_num) +# X = [] +# for idx in select_idx: +# X.append(train_X_true[idx]) + +# original_pred_res = model.predict(X)['label'] + +# if mapping == None: +# mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1]) +# else: +# mappings = [mapping] + +# consistent_idx = [] +# consistent_pred_res = [] + +# for m in mappings: +# pred_res = mapping_res(original_pred_res, m) +# max_abduce_num = 20 +# solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) +# all_address_flag = reform_idx(solution, pred_res) + +# consistent_idx_tmp = [] +# consistent_pred_res_tmp = [] + +# for idx in range(len(pred_res)): +# address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] +# candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) +# if len(candidate) > 0: +# consistent_idx_tmp.append(idx) +# consistent_pred_res_tmp.append(candidate[0][0]) + +# if len(consistent_idx_tmp) > len(consistent_idx): +# consistent_idx = consistent_idx_tmp +# consistent_pred_res = consistent_pred_res_tmp +# if len(mappings) > 1: +# mapping = m + +# if len(consistent_idx) == 0: +# return 0, 0, None + +# INFO('Train pool size is:', len(flatten(consistent_pred_res))) +# INFO("Start to use abduced pseudo label to train model...") +# model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) + +# consistent_acc = len(consistent_idx) / select_num +# char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) +# INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) +# return consistent_acc, char_acc, mapping + + def _remove_duplicate_rule(rule_dict): add_nums_dict = {} for r in list(rule_dict): - add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' + add_nums = str(r.split("]")[0].split("[")[1]) + str( + r.split("]")[1].split("[")[1] + ) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' if add_nums in add_nums_dict: old_r = add_nums_dict[add_nums] if rule_dict[r] >= rule_dict[old_r]: @@ -120,7 +206,9 @@ def _remove_duplicate_rule(rule_dict): return list(rule_dict) -def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num): +def get_rules_from_data( + model, abducer, mapping, train_X_true, samples_per_rule, samples_num +): rules = [] for _ in range(samples_num): while True: @@ -128,7 +216,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, X = [] for idx in select_idx: X.append(train_X_true[idx]) - original_pred_res = model.predict(X)['cls'] + original_pred_res = model.predict(X)["label"] pred_res = mapping_res(original_pred_res, mapping) consistent_idx = [] @@ -143,42 +231,47 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, if rule != None: break rules.append(rule) - + all_rule_dict = {} for rule in rules: for r in rule: all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} rules = _remove_duplicate_rule(rule_dict) - + return rules def _get_consist_rule_acc(model, abducer, mapping, rules, X): cnt = 0 for x in X: - original_pred_res = model.predict([x])['cls'] + original_pred_res = model.predict([x])["label"] pred_res = flatten(mapping_res(original_pred_res, mapping)) if abducer.kb.consist_rule(pred_res, rules): cnt += 1 return cnt / len(X) -def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8): +def train_with_rule( + model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8 +): train_X = train_data val_X = val_data - + samples_num = 50 samples_per_rule = 3 # Start training / for each length of equations for equation_len in range(min_len, max_len): - INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1)) + INFO( + "============== equation_len: %d-%d ================" + % (equation_len, equation_len + 1) + ) train_X_true = train_X[1][equation_len] train_X_false = train_X[0][equation_len] val_X_true = val_X[1][equation_len] val_X_false = val_X[0][equation_len] - + train_X_true.extend(train_X[1][equation_len + 1]) train_X_false.extend(train_X[0][equation_len + 1]) val_X_true.extend(val_X[1][equation_len + 1]) @@ -188,12 +281,14 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len while True: if equation_len == min_len: mapping = None - + # Abduce and train NN - consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, mapping, train_X_true, select_num) + consistent_acc, char_acc, mapping = abduce_and_train( + model, abducer, mapping, train_X_true, select_num + ) if consistent_acc == 0: continue - + # Test if we can use mlp to evaluate if consistent_acc >= 0.9 and char_acc >= 0.9: condition_cnt += 1 @@ -203,32 +298,49 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len # The condition has been satisfied continuously five times if condition_cnt >= 5: INFO("Now checking if we can go to next course") - rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num) - INFO('Learned rules from data:', rules) - - true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true) - false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false) - - INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc)) + rules = get_rules_from_data( + model, abducer, mapping, train_X_true, samples_per_rule, samples_num + ) + INFO("Learned rules from data:", rules) + + true_consist_rule_acc = _get_consist_rule_acc( + model, abducer, mapping, rules, val_X_true + ) + false_consist_rule_acc = _get_consist_rule_acc( + model, abducer, mapping, rules, val_X_false + ) + + INFO( + "consist_rule_acc is %f, %f\n" + % (true_consist_rule_acc, false_consist_rule_acc) + ) # decide next course or restart if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: - torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) + torch.save( + model.classifier_list[0].model.state_dict(), + "./weights/weights_%d.pth" % equation_len, + ) break else: if equation_len == min_len: - INFO('Final mapping is: ', mapping) - model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth")) + INFO("Final mapping is: ", mapping) + model.classifier_list[0].model.load_state_dict( + torch.load("./weights/pretrain_weights.pth") + ) else: - model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1))) + model.classifier_list[0].model.load_state_dict( + torch.load("./weights/weights_%d.pth" % (equation_len - 1)) + ) condition_cnt = 0 - INFO('Reload Model and retrain') - + INFO("Reload Model and retrain") + return model, mapping + def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8): train_X = train_data test_X = test_data - + # Calcualte how many equations should be selected in each length # for each length, there are equation_samples_num[equation_len] rules print("Now begin to train final mlp model") @@ -247,16 +359,30 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len= rules = [] samples_per_rule = 3 for equation_len in range(min_len, max_len + 1): - equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len]) + equation_rules = get_rules_from_data( + model, + abducer, + mapping, + train_X[1][equation_len], + samples_per_rule, + equation_samples_num[equation_len], + ) rules.extend(equation_rules) rules = list(set(rules)) - INFO('Learned rules from data:', rules) - - + INFO("Learned rules from data:", rules) + for equation_len in range(5, 27): - true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len]) - false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len]) - INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc)) + true_consist_rule_acc = _get_consist_rule_acc( + model, abducer, mapping, rules, test_X[1][equation_len] + ) + false_consist_rule_acc = _get_consist_rule_acc( + model, abducer, mapping, rules, test_X[0][equation_len] + ) + INFO( + "consist_rule_acc of testing length %d equations are %f, %f" + % (equation_len, true_consist_rule_acc, false_consist_rule_acc) + ) + if __name__ == "__main__": pass diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index 76f1ded..d562852 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -83,8 +83,8 @@ " candidate = self.revise_by_idx(pred, k, address_idx)\n", " return candidate\n", " \n", - " def zoopt_revision_score(self, pred_res, pred_res_prob, key, sol): \n", - " all_address_flag = reform_idx(sol.get_x(), pred_res)\n", + " def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, key, sol): \n", + " all_address_flag = reform_idx(sol.get_x(), pseudo_label)\n", " lefted_idxs = [i for i in range(len(pred_res))]\n", " candidate_size = [] \n", " while lefted_idxs:\n", @@ -95,7 +95,7 @@ " for idx in range(-1, len(pred_res)):\n", " if (not idx in idxs) and (idx >= 0):\n", " idxs.append(idx)\n", - " candidate = self._revise_by_idxs(pred_res, key, all_address_flag, idxs)\n", + " candidate = self._revise_by_idxs(pseudo_label, key, all_address_flag, idxs)\n", " if len(candidate) == 0:\n", " if len(idxs) > 1:\n", " idxs.pop()\n", @@ -106,7 +106,7 @@ " removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n", " if found:\n", " candidate_size.append(len(removed) + 1)\n", - " lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] \n", + " lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n", " candidate_size.sort()\n", " score = 0\n", " import math\n", @@ -189,7 +189,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = ABLModel(base_model, kb.pseudo_label_list)" + "model = ABLModel(base_model)" ] }, { @@ -221,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [