| @@ -24,7 +24,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB | |||||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | from datasets.mnist_add.get_mnist_add import get_mnist_add | ||||
| from datasets.hwf.get_hwf import get_hwf | from datasets.hwf.get_hwf import get_hwf | ||||
| from datasets.hed.get_hed import get_hed, split_equation | from datasets.hed.get_hed import get_hed, split_equation | ||||
| import framework | |||||
| import framework_hed | |||||
| def run_test(): | def run_test(): | ||||
| @@ -46,7 +46,7 @@ def run_test(): | |||||
| cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) | cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) | ||||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
| framework.hed_pretrain(kb, cls, recorder) | |||||
| framework_hed.hed_pretrain(kb, cls, recorder) | |||||
| criterion = nn.CrossEntropyLoss() | criterion = nn.CrossEntropyLoss() | ||||
| optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) | optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) | ||||
| @@ -62,8 +62,8 @@ def run_test(): | |||||
| # train_data = get_hwf(train = True, get_pseudo_label = True) | # train_data = get_hwf(train = True, get_pseudo_label = True) | ||||
| # test_data = get_hwf(train = False, get_pseudo_label = True) | # test_data = get_hwf(train = False, get_pseudo_label = True) | ||||
| model, mapping = framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) | |||||
| framework.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) | |||||
| model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) | |||||
| framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) | |||||
| recorder.dump() | recorder.dump() | ||||
| return True | return True | ||||