From 85ba22bc27eca089cc308542612264ca5a4a47df Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 20 Feb 2023 14:37:20 +0800 Subject: [PATCH] Update example.py --- example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example.py b/example.py index 29c7629..a81b51d 100644 --- a/example.py +++ b/example.py @@ -26,7 +26,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data -import framework +import framework_hed def run_test(): @@ -59,7 +59,7 @@ def run_test(): optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6) pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder) - framework.pretrain(pretrain_model, pretrain_data) + framework_hed.pretrain(pretrain_model, pretrain_data) torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth") cls.load_state_dict(cls_autoencoder.base_model.state_dict()) @@ -80,7 +80,7 @@ def run_test(): # train_data = get_hwf(train = True, get_pseudo_label = True) # test_data = get_hwf(train = False, get_pseudo_label = True) - framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1) + framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1) # recorder.print(res) recorder.dump()