From 4170a1fe7d7ff97c572b00a190fd422c388bfd55 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Wed, 22 Feb 2023 10:09:16 +0800 Subject: [PATCH] Update example.py --- example.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example.py b/example.py index 72c0bca..30c50fc 100644 --- a/example.py +++ b/example.py @@ -24,7 +24,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.hwf.get_hwf import get_hwf from datasets.hed.get_hed import get_hed, split_equation -import framework +import framework_hed def run_test(): @@ -46,7 +46,7 @@ def run_test(): cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - framework.hed_pretrain(kb, cls, recorder) + framework_hed.hed_pretrain(kb, cls, recorder) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) @@ -62,8 +62,8 @@ def run_test(): # train_data = get_hwf(train = True, get_pseudo_label = True) # test_data = get_hwf(train = False, get_pseudo_label = True) - model, mapping = framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) - framework.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) + model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) + framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) recorder.dump() return True