Browse Source

Update example.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
4170a1fe7d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      example.py

+ 4
- 4
example.py View File

@@ -24,7 +24,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation from datasets.hed.get_hed import get_hed, split_equation
import framework
import framework_hed




def run_test(): def run_test():
@@ -46,7 +46,7 @@ def run_test():
cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
framework.hed_pretrain(kb, cls, recorder)
framework_hed.hed_pretrain(kb, cls, recorder)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
@@ -62,8 +62,8 @@ def run_test():
# train_data = get_hwf(train = True, get_pseudo_label = True) # train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True) # test_data = get_hwf(train = False, get_pseudo_label = True)


model, mapping = framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
framework.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)


recorder.dump() recorder.dump()
return True return True


Loading…
Cancel
Save