diff --git a/framework_hed.py b/framework_hed.py index 04b41b8..36159da 100644 --- a/framework_hed.py +++ b/framework_hed.py @@ -11,26 +11,18 @@ # ================================================================# import pickle as pk - -import numpy as np - -import random - -random_seed = random.randint(0, 10000) -print("Selected random seed is : ", random_seed) -np.random.seed(random_seed) -random.seed(random_seed) - -from models.nn import MLP -from models.basic_model import BasicModel, BasicDataset -import torch.nn as nn +import math import torch +import torch.nn as nn +import numpy as np from utils.plog import INFO, DEBUG, clocker from utils.utils import flatten, reform_idx, block_sample +from utils.utils import copy_state_dict -from sklearn.tree import DecisionTreeClassifier - +from sklearn.linear_model import LogisticRegression +from models.nn import MLP +from models.basic_model import BasicModel, BasicDataset def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag): result = {} @@ -65,59 +57,89 @@ def filter_data(X, abduced_Z): return finetune_X, finetune_Z -def pretrain(net, pretrain_data_loader, recorder): - INFO("Pretrain Start") + + +def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbose=-1): + train_X, train_Z, train_Y = train_data + test_X, test_Z, test_Y = test_data - criterion = nn.MSELoss() - optimizer = torch.optim.RMSprop( - net.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 - ) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - pretrain_model = BasicModel( - net, - criterion, - optimizer, - device, - save_interval=1, - save_dir=recorder.save_dir, - num_epochs=10, - recorder=recorder, - ) + # Set default parameters + if sample_num == -1: + sample_num = len(train_X) - pretrain_model.fit(pretrain_data_loader) + if verbose < 1: + verbose = epochs + char_acc_flag = 1 + if train_Z == None: + char_acc_flag = 0 + train_Z = [None] * len(train_X) -def get_char_acc(model, X, consistent_pred_res): - pred_res = flatten(model.predict(X)["cls"]) - assert len(pred_res) == len(flatten(consistent_pred_res)) - return sum( - [ - pred_res[idx] == flatten(consistent_pred_res)[idx] - for idx in range(len(pred_res)) - ] - ) / len(pred_res) + predict_func = clocker(model.predict) + train_func = clocker(model.train) + abduce_func = clocker(abducer.batch_abduce) + for epoch_idx in range(epochs): + X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx) + preds_res = predict_func(X) + # input() + abduced_Z = abduce_func(preds_res, Y) -def gen_mappings(chars, symbs): - n_char = len(chars) - n_symbs = len(symbs) - if n_char != n_symbs: - INFO("Characters and symbols size dosen't match.") - return - from itertools import permutations + if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1): + res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag) + INFO('epoch: ', epoch_idx + 1, ' ', res) + + finetune_X, finetune_Z = filter_data(X, abduced_Z) + if len(finetune_X) > 0: + # model.valid(finetune_X, finetune_Z) + train_func(finetune_X, finetune_Z) + else: + INFO("lack of data, all abduced failed", len(finetune_X)) - mappings = [] - perms = permutations(symbs) - for p in perms: - mappings.append(dict(zip(chars, list(p)))) - return mappings + return res -def map_res(original_pred_res, m): - pred_res = [[m[symbol] for symbol in formula] for formula in original_pred_res] + +def pretrain(pretrain_model, pretrain_data): + INFO("Pretrain Start") + pretrain_data_loader = torch.utils.data.DataLoader( + pretrain_data, + batch_size=64, + shuffle=True, + num_workers=2, + ) + pretrain_model.fit(pretrain_data_loader) + + +def get_char_acc(model, X, consistent_pred_res): + print('Abduced labels: ', flatten(consistent_pred_res)) + pred_res = flatten(model.predict(X)['cls']) + print('Current model\'s output:', pred_res) + assert len(pred_res) == len(flatten(consistent_pred_res)) + return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res) + +def gen_mappings(chars, symbs): + n_char = len(chars) + n_symbs = len(symbs) + if n_char != n_symbs: + print('Characters and symbols size dosen\'t match.') + return + from itertools import permutations + mappings = [] + # returned mappings + perms = permutations(symbs) + for p in perms: + mappings.append(dict(zip(chars, list(p)))) + return mappings + +def map_res(pred_res, m): + for i in range(len(pred_res)): + for j in range(len(pred_res[i])): + pred_res[i][j] = m[pred_res[i][j]] return pred_res +def map_res(original_pred_res, m): + return [[m[symbol] for symbol in formula] for formula in original_pred_res] def abduce_and_train(model, abducer, train_X_true, select_num): select_idx = np.random.randint(len(train_X_true), size=select_num) @@ -125,68 +147,59 @@ def abduce_and_train(model, abducer, train_X_true, select_num): for idx in select_idx: X.append(train_X_true[idx]) - pred_res = model.predict(X)["cls"] - - maps = gen_mappings(["+", "=", 0, 1], ["+", "=", 0, 1]) - + pred_res = model.predict(X)['cls'] + + maps = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1]) + consistent_idx = [] consistent_pred_res = [] - + import copy original_pred_res = copy.deepcopy(pred_res) mapping = None - + for m in maps: pred_res = map_res(original_pred_res, m) remapping = {} for key, value in m.items(): remapping[value] = key - - max_abduce_num = 10 - solution = abducer.zoopt_get_solution( - pred_res, [1] * len(pred_res), max_abduce_num - ) + + max_abduce_num = 20 + solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num) all_address_flag = reform_idx(solution, pred_res) consistent_idx_tmp = [] consistent_pred_res_tmp = [] - + for idx in range(len(pred_res)): - address_idx = [ - i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 - ] + address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True) if len(candidate) > 0: consistent_idx_tmp.append(idx) - consistent_pred_res_tmp.append( - [remapping[symbol] for symbol in candidate[0][0]] - ) - + consistent_pred_res_tmp.append([remapping[symbol] for symbol in candidate[0][0]]) + if len(consistent_idx_tmp) > len(consistent_idx): consistent_idx = consistent_idx_tmp consistent_pred_res = consistent_pred_res_tmp mapping = m - + if len(consistent_idx) == 0: return 0, 0, None - - INFO("Consistent predict results are:", map_res(consistent_pred_res, mapping)) - INFO("Train pool size is:", len(flatten(consistent_pred_res))) + + INFO("Consistent predict results are: ", map_res(consistent_pred_res, mapping)) + INFO('Train pool size is:', len(flatten(consistent_pred_res))) + INFO("Start to use abduced pseudo label to train model...") model.train([X[idx] for idx in consistent_idx], consistent_pred_res) consistent_acc = len(consistent_idx) / select_num - char_acc = get_char_acc( - model, [X[idx] for idx in consistent_idx], consistent_pred_res - ) - INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc)) + char_acc = get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res) + INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) return consistent_acc, char_acc, mapping -def get_rules_from_data( - model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim -): +def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim): rules = [] for _ in range(logic_output_dim): while True: @@ -194,8 +207,8 @@ def get_rules_from_data( X = [] for idx in select_idx: X.append(train_X_true[idx]) - pred_res = model.predict(X)["cls"] - pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res] + original_pred_res = model.predict(X)['cls'] + pred_res = map_res(original_pred_res, mapping) consistent_idx = [] consistent_pred_res = [] @@ -208,17 +221,16 @@ def get_rules_from_data( rule = abducer.abduce_rules(consistent_pred_res) if rule != None: break - rules.append(rule) + INFO('Learned rules from data:') - for rule in rules: - INFO(rule) + INFO(rules) return rules def get_mlp_vector(model, abducer, mapping, X, rules): - pred_res = model.predict([X])["cls"] - pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res] + original_pred_res = model.predict([X])['cls'] + pred_res = map_res(original_pred_res, mapping) vector = [] for rule in rules: if abducer.kb.consist_rule(pred_res, rule): @@ -241,26 +253,13 @@ def get_mlp_data(model, abducer, mapping, X_true, X_false, rules): return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64) -def validation( - model, - abducer, - mapping, - train_X_true, - train_X_false, - val_X_true, - val_X_false, - recorder, -): +def validation(model, abducer, mapping, train_X_true, train_X_false, val_X_true, val_X_false): INFO("Now checking if we can go to next course") samples_per_rule = 3 logic_output_dim = 50 - rules = get_rules_from_data( - model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim - ) + rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim) - mlp_train_vectors, mlp_train_labels = get_mlp_data( - model, abducer, mapping, train_X_true, train_X_false, rules - ) + mlp_train_vectors, mlp_train_labels = get_mlp_data(model, abducer, mapping, train_X_true, train_X_false, rules) idx = np.array(list(range(len(mlp_train_labels)))) np.random.shuffle(idx) @@ -276,28 +275,18 @@ def validation( criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - mlp_model = BasicModel( - mlp, - criterion, - optimizer, - device, - batch_size=128, - num_epochs=60, - recorder=recorder, - ) + + mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60) mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels) mlp_train_data_loader = torch.utils.data.DataLoader( mlp_train_data, batch_size=128, - shuffle=True, + shuffle=True ) loss = mlp_model.fit(mlp_train_data_loader) INFO("mlp training loss is %f" % loss) - - mlp_val_vectors, mlp_val_labels = get_mlp_data( - model, abducer, mapping, val_X_true, val_X_false, rules - ) + + mlp_val_vectors, mlp_val_labels = get_mlp_data(model, abducer, mapping, val_X_true, val_X_false, rules) # Get MLP validation result mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels) @@ -306,7 +295,6 @@ def validation( batch_size=64, shuffle=True, ) - accuracy = mlp_model.val(mlp_val_data_loader) if accuracy > best_accuracy: @@ -314,41 +302,33 @@ def validation( return best_accuracy -def train_with_rule( - model, - abducer, - train_data, - val_data, - select_num=10, - recorder=None -): +def train_with_rule(model, abducer, train_data, val_data, epochs=50, select_num=10, verbose=-1): train_X = train_data val_X = val_data min_len = 5 - max_len = 8 + max_len = 18 # Start training / for each length of equations for equation_len in range(min_len, max_len): - INFO("============== equation_len:%d ================" % (equation_len)) + INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1)) train_X_true = train_X[1][equation_len] train_X_false = train_X[0][equation_len] val_X_true = val_X[1][equation_len] val_X_false = val_X[0][equation_len] + train_X_true.extend(train_X[1][equation_len + 1]) train_X_false.extend(train_X[0][equation_len + 1]) val_X_true.extend(val_X[1][equation_len + 1]) val_X_false.extend(val_X[0][equation_len + 1]) - condition_cnt = 0 + condition_cnt = 0 while True: # Abduce and train NN - consistent_acc, char_acc, mapping = abduce_and_train( - model, abducer, train_X_true, select_num - ) + consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, train_X_true, select_num) if consistent_acc == 0: continue - + # Test if we can use mlp to evaluate if consistent_acc >= 0.9 and char_acc >= 0.9: condition_cnt += 1 @@ -357,37 +337,18 @@ def train_with_rule( # The condition has been satisfied continuously five times if condition_cnt >= 5: - best_accuracy = validation( - model, - abducer, - mapping, - train_X_true, - train_X_false, - val_X_true, - val_X_false, - recorder, - ) - - INFO("best_accuracy is %f" % (best_accuracy)) - + # Try to abduce rules in `validation` + best_accuracy = validation(model, abducer, mapping, train_X_true, train_X_false, val_X_true, val_X_false) + INFO('best_accuracy is %f' %(best_accuracy)) # decide next course or restart - if best_accuracy > 0.86: - torch.save( - model.cls_list[0].model.state_dict(), - "./weights/train_weights_%d.pth" % equation_len, - ) + if best_accuracy > 0.85: + torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) break else: if equation_len == min_len: - model.cls_list[0].model.load_state_dict( - torch.load("./weights/pretrain_weights.pth") - ) + model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth")) else: - model.cls_list[0].model.load_state_dict( - torch.load( - "./weights/train_weights_%d.pth" % (equation_len - 1) - ) - ) + model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1))) condition_cnt = 0 return model