Browse Source

Update framework_hed.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
7a424c8eac
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 59 additions and 117 deletions
  1. +59
    -117
      framework_hed.py

+ 59
- 117
framework_hed.py View File

@@ -119,7 +119,7 @@ def hed_pretrain(kb, cls, recorder):
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))


def get_char_acc(model, X, consistent_pred_res, mapping):
def _get_char_acc(model, X, consistent_pred_res, mapping):
original_pred_res = model.predict(X)['cls']
pred_res = flatten(mapping_res(original_pred_res, mapping))
INFO('Current model\'s output: ', pred_res)
@@ -177,22 +177,29 @@ def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))

consistent_acc = len(consistent_idx) / select_num
char_acc = get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
return consistent_acc, char_acc, mapping

def _remove_duplicate_rule(rule_dict):
add_nums_dict = {}
for r in list(rule_dict):
add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
if add_nums in add_nums_dict:
old_r = add_nums_dict[add_nums]
if rule_dict[r] >= rule_dict[old_r]:
rule_dict.pop(old_r)
add_nums_dict[add_nums] = r
else:
rule_dict.pop(r)
else:
add_nums_dict[add_nums] = r
return list(rule_dict)

def output_rules(rules):
all_rule_dict = {}
for rule in rules:
for r in rule:
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items()}# if cnt >= 5}
return rule_dict

def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim):
def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num):
rules = []
for _ in range(logic_output_dim):
for _ in range(samples_num):
while True:
select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
X = []
@@ -213,83 +220,32 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,
if rule != None:
break
rules.append(rule)
return rules


def get_mlp_vector(model, abducer, mapping, X, rules):
original_pred_res = model.predict([X])['cls']
pred_res = flatten(mapping_res(original_pred_res, mapping))
vector = []
for rule in rules:
if abducer.kb.consist_rule(pred_res, rule):
vector.append(1)
else:
vector.append(0)
return vector


def get_mlp_data(model, abducer, mapping, X_true, X_false, rules):
mlp_vectors = []
mlp_labels = []
for X in X_true:
mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
mlp_labels.append(1)
for X in X_false:
mlp_vectors.append(get_mlp_vector(model, abducer, mapping, X, rules))
mlp_labels.append(0)
return np.array(mlp_vectors, dtype=np.float32), np.array(mlp_labels, dtype=np.int64)

def get_all_mlp_data(model, abducer, mapping, X_true, X_false, rules, min_len, max_len):
for equation_len in range(min_len, max_len + 1):
mlp_vectors, mlp_labels = get_mlp_data(model, abducer, mapping, X_true[equation_len], X_false[equation_len], rules)
if equation_len == min_len:
all_mlp_vectors = mlp_vectors
all_mlp_labels = mlp_labels
else:
all_mlp_vectors = np.concatenate((all_mlp_vectors, mlp_vectors))
all_mlp_labels = np.concatenate((all_mlp_labels, mlp_labels))
return all_mlp_vectors, all_mlp_labels


def validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false):
mlp_train_vectors, mlp_train_labels = get_mlp_data(model, abducer, mapping, train_X_true, train_X_false, rules)
mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)
mlp_val_vectors, mlp_val_labels = get_mlp_data(model, abducer, mapping, val_X_true, val_X_false, rules)
mlp_val_data = BasicDataset(mlp_val_vectors, mlp_val_labels)
all_rule_dict = {}
for rule in rules:
for r in rule:
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
rules = _remove_duplicate_rule(rule_dict)
best_accuracy = 0
# Try three times to find the best mlp
for _ in range(3):
INFO("Training mlp...")
mlp = MLP(input_dim=logic_output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)
mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
loss = mlp_model.fit(mlp_train_data_loader)
INFO("mlp training final loss is %f" % loss)
mlp_val_data_loader = torch.utils.data.DataLoader(mlp_val_data, batch_size=64, shuffle=True)
accuracy = mlp_model.val(mlp_val_data_loader)

if accuracy > best_accuracy:
best_accuracy = accuracy
return best_accuracy


return rules


def _get_consist_rule_acc(model, abducer, mapping, rules, X):
cnt = 0
for x in X:
original_pred_res = model.predict([x])['cls']
pred_res = flatten(mapping_res(original_pred_res, mapping))
if abducer.kb.consist_rule(pred_res, rules):
cnt += 1
return cnt / len(X)


def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8):
train_X = train_data
val_X = val_data
logic_output_dim = 50
samples_num = 50
samples_per_rule = 3

# Start training / for each length of equations
@@ -324,12 +280,15 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len
# The condition has been satisfied continuously five times
if condition_cnt >= 5:
INFO("Now checking if we can go to next course")
rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, logic_output_dim)
INFO('Learned rules from data:', output_rules(rules))
best_accuracy = validation(model, abducer, mapping, logic_output_dim, rules, train_X_true, train_X_false, val_X_true, val_X_false)
INFO('best_accuracy is %f\n' %(best_accuracy))
rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num)
INFO('Learned rules from data:', rules)
true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true)
false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false)
INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc))
# decide next course or restart
if best_accuracy > 0.88:
if true_consist_rule_acc > 0.9 and false_consist_rule_acc < 0.1:
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
break
else:
@@ -347,50 +306,33 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=
test_X = test_data
# Calcualte how many equations should be selected in each length
# for each length, there are select_equation_cnt[equation_len] rules
# for each length, there are equation_samples_num[equation_len] rules
print("Now begin to train final mlp model")
select_equation_cnt = []
equation_samples_num = []
len_cnt = max_len - min_len + 1
logic_output_dim = 50
select_equation_cnt += [0] * min_len
if logic_output_dim % len_cnt == 0:
select_equation_cnt += [logic_output_dim // len_cnt] * len_cnt
samples_num = 50
equation_samples_num += [0] * min_len
if samples_num % len_cnt == 0:
equation_samples_num += [samples_num // len_cnt] * len_cnt
else:
select_equation_cnt += [logic_output_dim // len_cnt] * len_cnt
select_equation_cnt[-1] += logic_output_dim % len_cnt
assert sum(select_equation_cnt) == logic_output_dim
equation_samples_num += [samples_num // len_cnt] * len_cnt
equation_samples_num[-1] += samples_num % len_cnt
assert sum(equation_samples_num) == samples_num

# Abduce rules
rules = []
samples_per_rule = 3
for equation_len in range(min_len, max_len + 1):
equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, select_equation_cnt[equation_len])
equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len])
rules.extend(equation_rules)
INFO('Learned rules from data:', output_rules(rules))
mlp_train_vectors, mlp_train_labels = get_all_mlp_data(model, abducer, mapping, train_X[1], train_X[0], rules, min_len, max_len)
mlp_train_data = BasicDataset(mlp_train_vectors, mlp_train_labels)
# Try three times to find the best mlp
for _ in range(3):
mlp = MLP(input_dim=logic_output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.999))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlp_model = BasicModel(mlp, criterion, optimizer, device, batch_size=128, num_epochs=100)
mlp_train_data_loader = torch.utils.data.DataLoader(mlp_train_data, batch_size=128, shuffle=True)
loss = mlp_model.fit(mlp_train_data_loader)
INFO("mlp training final loss is %f" % loss)
rules = list(set(rules))
INFO('Learned rules from data:', rules)
for equation_len in range(5, 27):
mlp_test_vectors, mlp_test_labels = get_mlp_data(model, abducer, mapping, test_X[1][equation_len], test_X[0][equation_len], rules)
mlp_test_data = BasicDataset(mlp_test_vectors, mlp_test_labels)
mlp_test_data_loader = torch.utils.data.DataLoader(mlp_test_data, batch_size=64, shuffle=True)
accuracy = mlp_model.val(mlp_test_data_loader)
INFO("The accuracy of testing length %d equations is: %f" % (equation_len, accuracy))
INFO("\n")
for equation_len in range(5, 27):
true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len])
false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len])
INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc))

if __name__ == "__main__":
pass

Loading…
Cancel
Save