From 953a88d15a90adca9d970c08f04ccf9b92d198f9 Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Fri, 17 Mar 2023 14:25:35 +0800 Subject: [PATCH] Examples for HWF --- examples/datasets/hwf/get_hwf.py | 8 +++---- examples/example.py | 41 +++++++++++++------------------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/examples/datasets/hwf/get_hwf.py b/examples/datasets/hwf/get_hwf.py index b478a15..87da5cf 100644 --- a/examples/datasets/hwf/get_hwf.py +++ b/examples/datasets/hwf/get_hwf.py @@ -7,7 +7,7 @@ img_transform = transforms.Compose([ transforms.Normalize((0.5,), (1,)) ]) -def get_data(file, get_pseudo_label, precision_num = 2): +def get_data(file, get_pseudo_label): X = [] if get_pseudo_label: Z = [] @@ -27,20 +27,20 @@ def get_data(file, get_pseudo_label, precision_num = 2): X.append(imgs) if get_pseudo_label: Z.append(imgs_pseudo_label) - Y.append(round(data[idx]['res'], precision_num)) + Y.append(data[idx]['res']) if get_pseudo_label: return X, Z, Y else: return X, None, Y -def get_hwf(train = True, get_pseudo_label = False, precision_num = 2): +def get_hwf(train = True, get_pseudo_label = False): if(train): file = './datasets/hwf/data/expr_train.json' else: file = './datasets/hwf/data/expr_test.json' - return get_data(file, get_pseudo_label, precision_num) + return get_data(file, get_pseudo_label) if __name__ == "__main__": train_X, train_Y = get_hwf(train = True) diff --git a/examples/example.py b/examples/example.py index 0f75895..df101e2 100644 --- a/examples/example.py +++ b/examples/example.py @@ -33,48 +33,41 @@ from abl import framework_hed def run_test(): # kb = add_KB() - # kb = HWF_KB() - # abducer = AbducerBase(kb, 'confidence') + kb = HWF_KB(GKB_flag=True) + abducer = AbducerBase(kb, 'confidence') - kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') - abducer = HED_Abducer(kb) + # kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') + # abducer = HED_Abducer(kb) recorder = logger() - total_train_data = get_hed(train=True) - train_data, val_data = split_equation(total_train_data, 3, 1) - test_data = get_hed(train=False) + # total_train_data = get_hed(train=True) + # train_data, val_data = split_equation(total_train_data, 3, 1) + # test_data = get_hed(train=False) - # train_data = get_mnist_add(train = True, get_pseudo_label = True) - # test_data = get_mnist_add(train = False, get_pseudo_label = True) + # train_data = get_mnist_add(train=True, get_pseudo_label=True) + # test_data = get_mnist_add(train=False, get_pseudo_label=True) - # train_data = get_hwf(train = True, get_pseudo_label = True) - # test_data = get_hwf(train = False, get_pseudo_label = True) + train_data = get_hwf(train=True, get_pseudo_label=True) + test_data = get_hwf(train=False, get_pseudo_label=True) # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) - cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) + cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - framework_hed.hed_pretrain(kb, cls, recorder) + # framework_hed.hed_pretrain(kb, cls, recorder) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) # optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) - - base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder) + base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder) model = WABLBasicModel(base_model, kb.pseudo_label_list) - - # train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) - # test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True) - - # train_data = get_hwf(train = True, get_pseudo_label = True) - # test_data = get_hwf(train = False, get_pseudo_label = True) - model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) - framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) + # model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) + # framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) - # framework_hed.train(model, abducer, train_data, test_data, sample_num=10000, verbose=1) + framework_hed.train(model, abducer, train_data, test_data, sample_num=-1, verbose=1) recorder.dump() return True