|
|
|
@@ -33,48 +33,41 @@ from abl import framework_hed |
|
|
|
def run_test(): |
|
|
|
|
|
|
|
# kb = add_KB() |
|
|
|
# kb = HWF_KB() |
|
|
|
# abducer = AbducerBase(kb, 'confidence') |
|
|
|
kb = HWF_KB(GKB_flag=True) |
|
|
|
abducer = AbducerBase(kb, 'confidence') |
|
|
|
|
|
|
|
kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') |
|
|
|
abducer = HED_Abducer(kb) |
|
|
|
# kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') |
|
|
|
# abducer = HED_Abducer(kb) |
|
|
|
|
|
|
|
recorder = logger() |
|
|
|
|
|
|
|
total_train_data = get_hed(train=True) |
|
|
|
train_data, val_data = split_equation(total_train_data, 3, 1) |
|
|
|
test_data = get_hed(train=False) |
|
|
|
# total_train_data = get_hed(train=True) |
|
|
|
# train_data, val_data = split_equation(total_train_data, 3, 1) |
|
|
|
# test_data = get_hed(train=False) |
|
|
|
|
|
|
|
# train_data = get_mnist_add(train = True, get_pseudo_label = True) |
|
|
|
# test_data = get_mnist_add(train = False, get_pseudo_label = True) |
|
|
|
# train_data = get_mnist_add(train=True, get_pseudo_label=True) |
|
|
|
# test_data = get_mnist_add(train=False, get_pseudo_label=True) |
|
|
|
|
|
|
|
# train_data = get_hwf(train = True, get_pseudo_label = True) |
|
|
|
# test_data = get_hwf(train = False, get_pseudo_label = True) |
|
|
|
train_data = get_hwf(train=True, get_pseudo_label=True) |
|
|
|
test_data = get_hwf(train=False, get_pseudo_label=True) |
|
|
|
|
|
|
|
# cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) |
|
|
|
cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) |
|
|
|
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
framework_hed.hed_pretrain(kb, cls, recorder) |
|
|
|
# framework_hed.hed_pretrain(kb, cls, recorder) |
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss() |
|
|
|
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) |
|
|
|
# optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) |
|
|
|
|
|
|
|
|
|
|
|
base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder) |
|
|
|
base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder) |
|
|
|
model = WABLBasicModel(base_model, kb.pseudo_label_list) |
|
|
|
|
|
|
|
# train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) |
|
|
|
# test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True) |
|
|
|
|
|
|
|
# train_data = get_hwf(train = True, get_pseudo_label = True) |
|
|
|
# test_data = get_hwf(train = False, get_pseudo_label = True) |
|
|
|
|
|
|
|
model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) |
|
|
|
framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) |
|
|
|
# model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) |
|
|
|
# framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) |
|
|
|
|
|
|
|
# framework_hed.train(model, abducer, train_data, test_data, sample_num=10000, verbose=1) |
|
|
|
framework_hed.train(model, abducer, train_data, test_data, sample_num=-1, verbose=1) |
|
|
|
|
|
|
|
recorder.dump() |
|
|
|
return True |
|
|
|
|