Browse Source

Update example.py

pull/3/head
troyyyyy GitHub 3 years ago
parent
commit
63da2da61d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 24 deletions
  1. +6
    -24
      example.py

+ 6
- 24
example.py View File

@@ -11,28 +11,25 @@
# ================================================================#

from utils.plog import logger, INFO
from utils.utils import copy_state_dict
import torch.nn as nn
import torch

from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder
from models.nn import LeNet5, SymbolNet
from models.basic_model import BasicModel, BasicDataset
from models.wabl_models import DecisionTree, WABLBasicModel

from multiprocessing import Pool
import os
from abducer.abducer_base import AbducerBase
from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data
import framework_hed
from datasets.hed.get_hed import get_hed, split_equation
import framework


def run_test():

# kb = add_KB(True)

# kb = HWF_KB(True)
# abducer = AbducerBase(kb)

@@ -46,25 +43,10 @@ def run_test():
test_data = get_hed(train=False)
# cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if not os.path.exists("./weights/pretrain_weights.pth"):
INFO("Pretrain Start")
pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11'])
pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
criterion = nn.MSELoss()
optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)

pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
framework_hed.pretrain(pretrain_model, pretrain_data)
torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
cls.load_state_dict(cls_autoencoder.base_model.state_dict())
else:
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
framework.hed_pretrain(kb, cls, recorder)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
@@ -80,8 +62,8 @@ def run_test():
# train_data = get_hwf(train = True, get_pseudo_label = True)
# test_data = get_hwf(train = False, get_pseudo_label = True)

framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1)
# recorder.print(res)
model, mapping = framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
framework.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)

recorder.dump()
return True


Loading…
Cancel
Save