Browse Source

Save option for mnist_add

pull/3/head
troyyyyy 3 years ago
parent
commit
e0d4abbd3b
1 changed files with 23 additions and 15 deletions
  1. +23
    -15
      examples/example.py

+ 23
- 15
examples/example.py View File

@@ -32,31 +32,37 @@ from abl import framework_hed

def run_test():

# kb = add_KB(True)
kb = add_KB()
# kb = HWF_KB(True)
# abducer = AbducerBase(kb)
abducer = AbducerBase(kb, 'confidence')

kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
# kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
# abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)

recorder = logger()

total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)
# total_train_data = get_hed(train=True)
# train_data, val_data = split_equation(total_train_data, 3, 1)
# test_data = get_hed(train=False)
# cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
train_data = get_mnist_add(train = True, get_pseudo_label = True)
test_data = get_mnist_add(train = False, get_pseudo_label = True)

# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)
cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
# cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
framework_hed.hed_pretrain(kb, cls, recorder)
# framework_hed.hed_pretrain(kb, cls, recorder)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
# optimizer = torch.optim.Adam(cls.parameters(), lr=0.00001, betas=(0.9, 0.99))
# optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder)
base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder)
model = WABLBasicModel(base_model, kb.pseudo_label_list)
# train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True)
@@ -65,8 +71,10 @@ def run_test():
# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)

model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
# model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
# framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
framework_hed.train(model, abducer, train_data, test_data, sample_num=10000, verbose=1)

recorder.dump()
return True


Loading…
Cancel
Save