Browse Source

Examples for HWF

pull/3/head
troyyyyy 3 years ago
parent
commit
953a88d15a
2 changed files with 21 additions and 28 deletions
  1. +4
    -4
      examples/datasets/hwf/get_hwf.py
  2. +17
    -24
      examples/example.py

+ 4
- 4
examples/datasets/hwf/get_hwf.py View File

@@ -7,7 +7,7 @@ img_transform = transforms.Compose([
transforms.Normalize((0.5,), (1,))
])

def get_data(file, get_pseudo_label, precision_num = 2):
def get_data(file, get_pseudo_label):
X = []
if get_pseudo_label:
Z = []
@@ -27,20 +27,20 @@ def get_data(file, get_pseudo_label, precision_num = 2):
X.append(imgs)
if get_pseudo_label:
Z.append(imgs_pseudo_label)
Y.append(round(data[idx]['res'], precision_num))
Y.append(data[idx]['res'])
if get_pseudo_label:
return X, Z, Y
else:
return X, None, Y

def get_hwf(train = True, get_pseudo_label = False, precision_num = 2):
def get_hwf(train = True, get_pseudo_label = False):
if(train):
file = './datasets/hwf/data/expr_train.json'
else:
file = './datasets/hwf/data/expr_test.json'
return get_data(file, get_pseudo_label, precision_num)
return get_data(file, get_pseudo_label)

if __name__ == "__main__":
train_X, train_Y = get_hwf(train = True)


+ 17
- 24
examples/example.py View File

@@ -33,48 +33,41 @@ from abl import framework_hed
def run_test():

# kb = add_KB()
# kb = HWF_KB()
# abducer = AbducerBase(kb, 'confidence')
kb = HWF_KB(GKB_flag=True)
abducer = AbducerBase(kb, 'confidence')

kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
abducer = HED_Abducer(kb)
# kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
# abducer = HED_Abducer(kb)

recorder = logger()

total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)
# total_train_data = get_hed(train=True)
# train_data, val_data = split_equation(total_train_data, 3, 1)
# test_data = get_hed(train=False)
# train_data = get_mnist_add(train = True, get_pseudo_label = True)
# test_data = get_mnist_add(train = False, get_pseudo_label = True)
# train_data = get_mnist_add(train=True, get_pseudo_label=True)
# test_data = get_mnist_add(train=False, get_pseudo_label=True)

# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)
train_data = get_hwf(train=True, get_pseudo_label=True)
test_data = get_hwf(train=False, get_pseudo_label=True)
# cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
framework_hed.hed_pretrain(kb, cls, recorder)
# framework_hed.hed_pretrain(kb, cls, recorder)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
# optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder)
base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder)
model = WABLBasicModel(base_model, kb.pseudo_label_list)
# train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True)
# test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True)

# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)

model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
# model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
# framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
# framework_hed.train(model, abducer, train_data, test_data, sample_num=10000, verbose=1)
framework_hed.train(model, abducer, train_data, test_data, sample_num=-1, verbose=1)

recorder.dump()
return True


Loading…
Cancel
Save