From 7a424c8eac0c0c2d9d63cc2b9ec713d1c280231d Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 24 Feb 2023 17:03:51 +0800 Subject: [PATCH] Update framework_hed.py --- framework_hed.py | 176 ++++++++++++++++------------------------------- 1 file changed, 59 insertions(+), 117 deletions(-) diff --git a/framework_hed.py b/framework_hed.py index 3476462..b7439c3 100644 --- a/framework_hed.py +++ b/framework_hed.py @@ -119,7 +119,7 @@ def hed_pretrain(kb, cls, recorder): cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) -def get_char_acc(model, X, consistent_pred_res, mapping): +def _get_char_acc(model, X, consistent_pred_res, mapping): original_pred_res = model.predict(X)['cls'] pred_res = flatten(mapping_res(original_pred_res, mapping)) INFO('Current model\'s output: ', pred_res) @@ -177,22 +177,29 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num): model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) consistent_acc = len(consistent_idx) / select_num - char_acc = get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) + char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) return consistent_acc, char_acc, mapping +def _remove_duplicate_rule(rule_dict): + add_nums_dict = {} + for r in list(rule_dict): + add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' + if add_nums in add_nums_dict: + old_r = add_nums_dict[add_nums] + if rule_dict[r] >= rule_dict[old_r]: + rule_dict.pop(old_r) + add_nums_dict[add_nums] = r + else: + rule_dict.pop(r) + else: + add_nums_dict[add_nums] = r + return list(rule_dict) -def output_rules(rules): - all_rule_dict = {} - for rule in rules: - for r in rule: - all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 - rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items()}# if cnt >= 5} - return rule_dict -def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim): +def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num): rules = [] - for _ in range(logic_output_dim): + for _ in range(samples_num): while True: select_idx = np.random.randint(len(train_X_true), size=samples_per_rule) X = [] @@ -213,83 +220,32 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, if rule != None: break rules.append(rule) - return rules - - -def get_mlp_vector(model, abducer, mapping, X, rules): - original_pred_res = model.predict([X])['cls'] - pred_res = flatten(mapping_res(original_pred_res, mapping)) - vector = [] - for rule in rules: - if abducer.kb.consist_rule(pred_res, rule): - vector.append(1) - else: - vector.append(0) - return vector - - -def get_mlp_data(model, abducer, mapping, X_true, X_false, rules): - mlp_vectors = [] - mlp_labels = [] - for X in X_true: - mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules)) - mlp_labels.append(1) - for X in X_false: - mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules)) - mlp_labels.append(0) - return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64) - -def get_all_mlp_data(model, abducer, mapping, X_true, X_false, rules, min_len, max_len): - for equation_len in range(min_len, max_len + 1): - mlp_vectors, mlp_labels = get_mlp_data(model, abducer, mapping, X_true[equation_len], X_false[equation_len], rules) - if equation_len == min_len: - all_mlp_vectors = mlp_vectors - all_mlp_labels = mlp_labels - else: - all_mlp_vectors = np.concatenate((all_mlp_vectors, mlp_vectors)) - all_mlp_labels = np.concatenate((all_mlp_labels, mlp_labels)) - return all_mlp_vectors, all_mlp_labels - - -def validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false): - mlp_train_vectors, mlp_train_labels = get_mlp_data(model, abducer, mapping, train_X_true, train_X_false, rules) - mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels) - mlp_val_vectors, mlp_val_labels = get_mlp_data(model, abducer, mapping, val_X_true, val_X_false, rules) - mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels) + all_rule_dict = {} + for rule in rules: + for r in rule: + all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 + rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} + rules = _remove_duplicate_rule(rule_dict) - best_accuracy = 0 - # Try three times to find the best mlp - for _ in range(3): - INFO("Training mlp...") - mlp = MLP(input_dim=logic_output_dim) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100) - mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True) - - loss = mlp_model.fit(mlp_train_data_loader) - INFO("mlp training final loss is %f" % loss) - - mlp_val_data_loader = torch.utils.data.DataLoader(mlp_val_data, batch_size=64, shuffle=True) - accuracy = mlp_model.val(mlp_val_data_loader) - - if accuracy > best_accuracy: - best_accuracy = accuracy - return best_accuracy - - + return rules +def _get_consist_rule_acc(model, abducer, mapping, rules, X): + cnt = 0 + for x in X: + original_pred_res = model.predict([x])['cls'] + pred_res = flatten(mapping_res(original_pred_res, mapping)) + if abducer.kb.consist_rule(pred_res, rules): + cnt += 1 + return cnt / len(X) def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8): train_X = train_data val_X = val_data - logic_output_dim = 50 + samples_num = 50 samples_per_rule = 3 # Start training / for each length of equations @@ -324,12 +280,15 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len # The condition has been satisfied continuously five times if condition_cnt >= 5: INFO("Now checking if we can go to next course") - rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim) - INFO('Learned rules from data:', output_rules(rules)) - best_accuracy = validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false) - INFO('best_accuracy is %f\n' %(best_accuracy)) + rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num) + INFO('Learned rules from data:', rules) + + true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true) + false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false) + + INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc)) # decide next course or restart - if best_accuracy > 0.88: + if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1: torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) break else: @@ -347,50 +306,33 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len= test_X = test_data # Calcualte how many equations should be selected in each length - # for each length, there are select_equation_cnt[equation_len] rules + # for each length, there are equation_samples_num[equation_len] rules print("Now begin to train final mlp model") - select_equation_cnt = [] + equation_samples_num = [] len_cnt = max_len - min_len + 1 - logic_output_dim = 50 - select_equation_cnt += [0] * min_len - if logic_output_dim % len_cnt == 0: - select_equation_cnt += [logic_output_dim // len_cnt] * len_cnt + samples_num = 50 + equation_samples_num += [0] * min_len + if samples_num % len_cnt == 0: + equation_samples_num += [samples_num // len_cnt] * len_cnt else: - select_equation_cnt += [logic_output_dim // len_cnt] * len_cnt - select_equation_cnt[-1] += logic_output_dim % len_cnt - assert sum(select_equation_cnt) == logic_output_dim + equation_samples_num += [samples_num // len_cnt] * len_cnt + equation_samples_num[-1] += samples_num % len_cnt + assert sum(equation_samples_num) == samples_num # Abduce rules rules = [] samples_per_rule = 3 for equation_len in range(min_len, max_len + 1): - equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, select_equation_cnt[equation_len]) + equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len]) rules.extend(equation_rules) - INFO('Learned rules from data:', output_rules(rules)) - - mlp_train_vectors, mlp_train_labels = get_all_mlp_data(model, abducer, mapping, train_X[1], train_X[0], rules, min_len, max_len) - mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels) - - # Try three times to find the best mlp - for _ in range(3): - mlp = MLP(input_dim=logic_output_dim) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100) - mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True) - - loss = mlp_model.fit(mlp_train_data_loader) - INFO("mlp training final loss is %f" % loss) + rules = list(set(rules)) + INFO('Learned rules from data:', rules) + - for equation_len in range(5, 27): - mlp_test_vectors, mlp_test_labels = get_mlp_data(model, abducer, mapping, test_X[1][equation_len], test_X[0][equation_len], rules) - mlp_test_data = BasicDataset(mlp_test_vectors, mlp_test_labels) - mlp_test_data_loader = torch.utils.data.DataLoader(mlp_test_data, batch_size=64, shuffle=True) - accuracy = mlp_model.val(mlp_test_data_loader) - INFO("The accuracy of testing length %d equations is: %f" % (equation_len, accuracy)) - INFO("\n") + for equation_len in range(5, 27): + true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len]) + false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len]) + INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc)) if __name__ == "__main__": pass