From e0d4abbd3bbb09207f9e76a7f08e417148f7460a Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Thu, 16 Mar 2023 09:37:41 +0800 Subject: [PATCH] Save option for mnist_add --- examples/example.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/example.py b/examples/example.py index c74f1b3..39e8c3f 100644 --- a/examples/example.py +++ b/examples/example.py @@ -32,31 +32,37 @@ from abl import framework_hed def run_test(): - # kb = add_KB(True) + kb = add_KB() # kb = HWF_KB(True) - # abducer = AbducerBase(kb) + abducer = AbducerBase(kb, 'confidence') - kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') - abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) + # kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') + # abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) recorder = logger() - total_train_data = get_hed(train=True) - train_data, val_data = split_equation(total_train_data, 3, 1) - test_data = get_hed(train=False) + # total_train_data = get_hed(train=True) + # train_data, val_data = split_equation(total_train_data, 3, 1) + # test_data = get_hed(train=False) - # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) - cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) + train_data = get_mnist_add(train = True, get_pseudo_label = True) + test_data = get_mnist_add(train = False, get_pseudo_label = True) + + # train_data = get_hwf(train = True, get_pseudo_label = True) + # test_data = get_hwf(train = False, get_pseudo_label = True) + + cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) + # cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - framework_hed.hed_pretrain(kb, cls, recorder) + # framework_hed.hed_pretrain(kb, cls, recorder) criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) - # optimizer = torch.optim.Adam(cls.parameters(), lr=0.00001, betas=(0.9, 0.99)) + # optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) + optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) - base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder) + base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder) model = WABLBasicModel(base_model, kb.pseudo_label_list) # train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True) @@ -65,8 +71,10 @@ def run_test(): # train_data = get_hwf(train = True, get_pseudo_label = True) # test_data = get_hwf(train = False, get_pseudo_label = True) - model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) - framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) + # model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) + # framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) + + framework_hed.train(model, abducer, train_data, test_data, sample_num=10000, verbose=1) recorder.dump() return True