|
|
|
@@ -26,7 +26,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB |
|
|
|
from datasets.mnist_add.get_mnist_add import get_mnist_add |
|
|
|
from datasets.hwf.get_hwf import get_hwf |
|
|
|
from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data |
|
|
|
import framework |
|
|
|
import framework_hed |
|
|
|
|
|
|
|
|
|
|
|
def run_test(): |
|
|
|
@@ -59,7 +59,7 @@ def run_test(): |
|
|
|
optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6) |
|
|
|
|
|
|
|
pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder) |
|
|
|
framework.pretrain(pretrain_model, pretrain_data) |
|
|
|
framework_hed.pretrain(pretrain_model, pretrain_data) |
|
|
|
torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth") |
|
|
|
cls.load_state_dict(cls_autoencoder.base_model.state_dict()) |
|
|
|
|
|
|
|
@@ -80,7 +80,7 @@ def run_test(): |
|
|
|
# train_data = get_hwf(train = True, get_pseudo_label = True) |
|
|
|
# test_data = get_hwf(train = False, get_pseudo_label = True) |
|
|
|
|
|
|
|
framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1) |
|
|
|
framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1) |
|
|
|
# recorder.print(res) |
|
|
|
|
|
|
|
recorder.dump() |
|
|
|
|