Browse Source

Update framework_hed.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
3678d2d0ad
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 128 additions and 167 deletions
  1. +128
    -167
      framework_hed.py

+ 128
- 167
framework_hed.py View File

@@ -11,26 +11,18 @@
# ================================================================#

import pickle as pk

import numpy as np

import random

# Draw a fresh seed on every run (random.randint includes both endpoints) and
# print it so that a particular run can be identified from its output.
random_seed = random.randint(0, 10000)
print("Selected random seed is : ", random_seed)
# Apply the same seed to NumPy's global generator and to Python's `random`
# module. NOTE(review): no torch seeding is visible in these lines -- confirm
# whether it is handled elsewhere.
np.random.seed(random_seed)
random.seed(random_seed)

from models.nn import MLP
from models.basic_model import BasicModel, BasicDataset
import torch.nn as nn
import math
import torch
import torch.nn as nn
import numpy as np

from utils.plog import INFO, DEBUG, clocker
from utils.utils import flatten, reform_idx, block_sample
from utils.utils import copy_state_dict

from sklearn.tree import DecisionTreeClassifier

from sklearn.linear_model import LogisticRegression
from models.nn import MLP
from models.basic_model import BasicModel, BasicDataset

def result_statistics(pred_Z, Z, Y, logic_forward, char_acc_flag):
result = {}
@@ -65,59 +57,89 @@ def filter_data(X, abduced_Z):
return finetune_X, finetune_Z


def pretrain(net, pretrain_data_loader, recorder):
INFO("Pretrain Start")
def train(model, abducer, train_data, test_data, epochs=50, sample_num=-1, verbose=-1):
train_X, train_Z, train_Y = train_data
test_X, test_Z, test_Y = test_data

criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(
net.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pretrain_model = BasicModel(
net,
criterion,
optimizer,
device,
save_interval=1,
save_dir=recorder.save_dir,
num_epochs=10,
recorder=recorder,
)
# Set default parameters
if sample_num == -1:
sample_num = len(train_X)

pretrain_model.fit(pretrain_data_loader)
if verbose < 1:
verbose = epochs

char_acc_flag = 1
if train_Z == None:
char_acc_flag = 0
train_Z = [None] * len(train_X)

def get_char_acc(model, X, consistent_pred_res):
pred_res = flatten(model.predict(X)["cls"])
assert len(pred_res) == len(flatten(consistent_pred_res))
return sum(
[
pred_res[idx] == flatten(consistent_pred_res)[idx]
for idx in range(len(pred_res))
]
) / len(pred_res)
predict_func = clocker(model.predict)
train_func = clocker(model.train)
abduce_func = clocker(abducer.batch_abduce)

for epoch_idx in range(epochs):
X, Z, Y = block_sample(train_X, train_Z, train_Y, sample_num, epoch_idx)
preds_res = predict_func(X)
# input()
abduced_Z = abduce_func(preds_res, Y)

def gen_mappings(chars, symbs):
n_char = len(chars)
n_symbs = len(symbs)
if n_char != n_symbs:
INFO("Characters and symbols size dosen't match.")
return
from itertools import permutations
if ((epoch_idx + 1) % verbose == 0) or (epoch_idx == epochs - 1):
res = result_statistics(preds_res['cls'], Z, Y, abducer.kb.logic_forward, char_acc_flag)
INFO('epoch: ', epoch_idx + 1, ' ', res)

finetune_X, finetune_Z = filter_data(X, abduced_Z)
if len(finetune_X) > 0:
# model.valid(finetune_X, finetune_Z)
train_func(finetune_X, finetune_Z)
else:
INFO("lack of data, all abduced failed", len(finetune_X))

mappings = []
perms = permutations(symbs)
for p in perms:
mappings.append(dict(zip(chars, list(p))))
return mappings
return res


def map_res(original_pred_res, m):
pred_res = [[m[symbol] for symbol in formula] for formula in original_pred_res]

def pretrain(pretrain_model, pretrain_data):
    """Fit ``pretrain_model`` on ``pretrain_data`` served in shuffled batches.

    Args:
        pretrain_model: Object exposing ``fit(data_loader)``.
        pretrain_data: Dataset accepted by ``torch.utils.data.DataLoader``.

    The loader uses batches of 64 samples, reshuffles the data, and loads
    with two worker processes. Nothing is returned.
    """
    INFO("Pretrain Start")
    loader_options = dict(batch_size=64, shuffle=True, num_workers=2)
    loader = torch.utils.data.DataLoader(pretrain_data, **loader_options)
    pretrain_model.fit(loader)


def get_char_acc(model, X, consistent_pred_res):
    """Return the fraction of predicted symbols that equal the abduced labels.

    Args:
        model: Object whose ``predict(X)`` returns a mapping with a ``'cls'``
            entry holding the per-example predictions.
        X: Inputs passed straight to ``model.predict``.
        consistent_pred_res: Abduced labels, nested the same way as the
            predictions (both are passed through ``flatten`` before being
            compared position by position).

    Returns:
        Number of matching positions divided by the number of predictions.

    Raises:
        AssertionError: If the flattened predictions and labels differ in
            length.
        ZeroDivisionError: If there are no predictions at all.
    """
    # Flatten the labels once; doing it per compared element made the
    # comparison quadratic in the number of symbols.
    abduced_labels = flatten(consistent_pred_res)
    print('Abduced labels: ', abduced_labels)
    pred_res = flatten(model.predict(X)['cls'])
    print('Current model\'s output:', pred_res)
    assert len(pred_res) == len(abduced_labels)
    matches = sum(
        predicted == abduced
        for predicted, abduced in zip(pred_res, abduced_labels)
    )
    return matches / len(pred_res)

def gen_mappings(chars, symbs):
    """Enumerate every one-to-one assignment of ``symbs`` to ``chars``.

    Args:
        chars: Sequence of hashable items used as dictionary keys.
        symbs: Sequence of items used as dictionary values; must have the
            same length as ``chars``.

    Returns:
        A list with one dict per permutation of ``symbs`` (in the order
        ``itertools.permutations`` yields them); each dict maps ``chars[i]``
        to the i-th element of that permutation. If the two sequences differ
        in length, a message is printed and ``None`` is returned, so callers
        must check the result before iterating over it.
    """
    from itertools import permutations

    if len(chars) != len(symbs):
        print("Characters and symbols size dosen't match.")
        return None
    # zip() accepts the permutation tuple directly; no list() copy is needed.
    return [dict(zip(chars, perm)) for perm in permutations(symbs)]

def map_res(pred_res, m):
    """Translate every symbol of every formula in ``pred_res`` through ``m``.

    Args:
        pred_res: List of formulas, each a list of symbols that are keys
            of ``m``.
        m: Dict mapping each symbol to its replacement.

    Returns:
        A new list of new lists holding the translated symbols.

    Raises:
        KeyError: If a symbol is not a key of ``m``.
    """
    # Build fresh lists instead of overwriting the argument: callers remap
    # the same predictions under several candidate mappings and keep using
    # the untranslated labels afterwards, so the input must stay untouched.
    return [[m[symbol] for symbol in formula] for formula in pred_res]

def map_res(original_pred_res, m):
    """Return a translated copy of ``original_pred_res``.

    Each formula (an iterable of symbols) becomes a new list in which every
    symbol has been looked up in the dict ``m``; the input is not modified.
    """
    translated = []
    for formula in original_pred_res:
        translated.append([m[symbol] for symbol in formula])
    return translated

def abduce_and_train(model, abducer, train_X_true, select_num):
select_idx = np.random.randint(len(train_X_true), size=select_num)
@@ -125,68 +147,59 @@ def abduce_and_train(model, abducer, train_X_true, select_num):
for idx in select_idx:
X.append(train_X_true[idx])

pred_res = model.predict(X)["cls"]
maps = gen_mappings(["+", "=", 0, 1], ["+", "=", 0, 1])
pred_res = model.predict(X)['cls']
maps = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1])
consistent_idx = []
consistent_pred_res = []
import copy

original_pred_res = copy.deepcopy(pred_res)
mapping = None
for m in maps:
pred_res = map_res(original_pred_res, m)
remapping = {}
for key, value in m.items():
remapping[value] = key

max_abduce_num = 10
solution = abducer.zoopt_get_solution(
pred_res, [1] * len(pred_res), max_abduce_num
)
max_abduce_num = 20
solution = abducer.zoopt_get_solution(pred_res, [1] * len(pred_res), max_abduce_num)
all_address_flag = reform_idx(solution, pred_res)

consistent_idx_tmp = []
consistent_pred_res_tmp = []
for idx in range(len(pred_res)):
address_idx = [
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
]
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
candidate = abducer.kb.address_by_idx([pred_res[idx]], 1, address_idx, True)
if len(candidate) > 0:
consistent_idx_tmp.append(idx)
consistent_pred_res_tmp.append(
[remapping[symbol] for symbol in candidate[0][0]]
)

consistent_pred_res_tmp.append([remapping[symbol] for symbol in candidate[0][0]])
if len(consistent_idx_tmp) > len(consistent_idx):
consistent_idx = consistent_idx_tmp
consistent_pred_res = consistent_pred_res_tmp
mapping = m
if len(consistent_idx) == 0:
return 0, 0, None

INFO("Consistent predict results are:", map_res(consistent_pred_res, mapping))
INFO("Train pool size is:", len(flatten(consistent_pred_res)))
INFO("Consistent predict results are: ", map_res(consistent_pred_res, mapping))
INFO('Train pool size is:', len(flatten(consistent_pred_res)))
INFO("Start to use abduced pseudo label to train model...")
model.train([X[idx] for idx in consistent_idx], consistent_pred_res)

consistent_acc = len(consistent_idx) / select_num
char_acc = get_char_acc(
model, [X[idx] for idx in consistent_idx], consistent_pred_res
)
INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
char_acc = get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res)
INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
return consistent_acc, char_acc, mapping


def get_rules_from_data(
model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim
):
def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim):
rules = []
for _ in range(logic_output_dim):
while True:
@@ -194,8 +207,8 @@ def get_rules_from_data(
X = []
for idx in select_idx:
X.append(train_X_true[idx])
pred_res = model.predict(X)["cls"]
pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res]
original_pred_res = model.predict(X)['cls']
pred_res = map_res(original_pred_res, mapping)

consistent_idx = []
consistent_pred_res = []
@@ -208,17 +221,16 @@ def get_rules_from_data(
rule = abducer.abduce_rules(consistent_pred_res)
if rule != None:
break

rules.append(rule)
INFO('Learned rules from data:')
for rule in rules:
INFO(rule)
INFO(rules)
return rules


def get_mlp_vector(model, abducer, mapping, X, rules):
pred_res = model.predict([X])["cls"]
pred_res = [[mapping[symbol] for symbol in formula] for formula in pred_res]
original_pred_res = model.predict([X])['cls']
pred_res = map_res(original_pred_res, mapping)
vector = []
for rule in rules:
if abducer.kb.consist_rule(pred_res, rule):
@@ -241,26 +253,13 @@ def get_mlp_data(model, abducer, mapping, X_true, X_false, rules):
return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64)


def validation(
model,
abducer,
mapping,
train_X_true,
train_X_false,
val_X_true,
val_X_false,
recorder,
):
def validation(model, abducer, mapping, train_X_true, train_X_false, val_X_true, val_X_false):
INFO("Now checking if we can go to next course")
samples_per_rule = 3
logic_output_dim = 50
rules = get_rules_from_data(
model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim
)
rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim)

mlp_train_vectors, mlp_train_labels = get_mlp_data(
model, abducer, mapping, train_X_true, train_X_false, rules
)
mlp_train_vectors, mlp_train_labels = get_mlp_data(model, abducer, mapping, train_X_true, train_X_false, rules)

idx = np.array(list(range(len(mlp_train_labels))))
np.random.shuffle(idx)
@@ -276,28 +275,18 @@ def validation(
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mlp_model = BasicModel(
mlp,
criterion,
optimizer,
device,
batch_size=128,
num_epochs=60,
recorder=recorder,
)
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=60)
mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)
mlp_train_data_loader = torch.utils.data.DataLoader(
mlp_train_data,
batch_size=128,
shuffle=True,
shuffle=True
)
loss = mlp_model.fit(mlp_train_data_loader)
INFO("mlp training loss is %f" % loss)

mlp_val_vectors, mlp_val_labels = get_mlp_data(
model, abducer, mapping, val_X_true, val_X_false, rules
)
mlp_val_vectors, mlp_val_labels = get_mlp_data(model, abducer, mapping, val_X_true, val_X_false, rules)

# Get MLP validation result
mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels)
@@ -306,7 +295,6 @@ def validation(
batch_size=64,
shuffle=True,
)

accuracy = mlp_model.val(mlp_val_data_loader)

if accuracy > best_accuracy:
@@ -314,41 +302,33 @@ def validation(
return best_accuracy


def train_with_rule(
model,
abducer,
train_data,
val_data,
select_num=10,
recorder=None
):
def train_with_rule(model, abducer, train_data, val_data, epochs=50, select_num=10, verbose=-1):
train_X = train_data
val_X = val_data

min_len = 5
max_len = 8
max_len = 18

# Start training / for each length of equations
for equation_len in range(min_len, max_len):
INFO("============== equation_len:%d ================" % (equation_len))
INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1))
train_X_true = train_X[1][equation_len]
train_X_false = train_X[0][equation_len]
val_X_true = val_X[1][equation_len]
val_X_false = val_X[0][equation_len]
train_X_true.extend(train_X[1][equation_len + 1])
train_X_false.extend(train_X[0][equation_len + 1])
val_X_true.extend(val_X[1][equation_len + 1])
val_X_false.extend(val_X[0][equation_len + 1])
condition_cnt = 0

condition_cnt = 0
while True:
# Abduce and train NN
consistent_acc, char_acc, mapping = abduce_and_train(
model, abducer, train_X_true, select_num
)
consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, train_X_true, select_num)
if consistent_acc == 0:
continue
# Test if we can use mlp to evaluate
if consistent_acc >= 0.9 and char_acc >= 0.9:
condition_cnt += 1
@@ -357,37 +337,18 @@ def train_with_rule(

# The condition has been satisfied continuously five times
if condition_cnt >= 5:
best_accuracy = validation(
model,
abducer,
mapping,
train_X_true,
train_X_false,
val_X_true,
val_X_false,
recorder,
)

INFO("best_accuracy is %f" % (best_accuracy))

# Try to abduce rules in `validation`
best_accuracy = validation(model, abducer, mapping, train_X_true, train_X_false, val_X_true, val_X_false)
INFO('best_accuracy is %f' %(best_accuracy))
# decide next course or restart
if best_accuracy > 0.86:
torch.save(
model.cls_list[0].model.state_dict(),
"./weights/train_weights_%d.pth" % equation_len,
)
if best_accuracy > 0.85:
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
break
else:
if equation_len == min_len:
model.cls_list[0].model.load_state_dict(
torch.load("./weights/pretrain_weights.pth")
)
model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
else:
model.cls_list[0].model.load_state_dict(
torch.load(
"./weights/train_weights_%d.pth" % (equation_len - 1)
)
)
model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1)))
condition_cnt = 0

return model


Loading…
Cancel
Save