Browse Source

Update example.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
85ba22bc27
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      example.py

+ 3
- 3
example.py View File

@@ -26,7 +26,7 @@ from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data
import framework
import framework_hed


def run_test():
@@ -59,7 +59,7 @@ def run_test():
optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)

pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
framework.pretrain(pretrain_model, pretrain_data)
framework_hed.pretrain(pretrain_model, pretrain_data)
torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
cls.load_state_dict(cls_autoencoder.base_model.state_dict())
@@ -80,7 +80,7 @@ def run_test():
# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)

framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1)
framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1)
# recorder.print(res)

recorder.dump()


Loading…
Cancel
Save