| @@ -11,28 +11,25 @@ | |||
| # ================================================================# | |||
| from utils.plog import logger, INFO | |||
| from utils.utils import copy_state_dict | |||
| import torch.nn as nn | |||
| import torch | |||
| from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder | |||
| from models.nn import LeNet5, SymbolNet | |||
| from models.basic_model import BasicModel, BasicDataset | |||
| from models.wabl_models import DecisionTree, WABLBasicModel | |||
| from multiprocessing import Pool | |||
| import os | |||
| from abducer.abducer_base import AbducerBase | |||
| from abducer.kb import add_KB, HWF_KB, HED_prolog_KB | |||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
| from datasets.hwf.get_hwf import get_hwf | |||
| from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data | |||
| import framework_hed | |||
| from datasets.hed.get_hed import get_hed, split_equation | |||
| import framework | |||
| def run_test(): | |||
| # kb = add_KB(True) | |||
| # kb = HWF_KB(True) | |||
| # abducer = AbducerBase(kb) | |||
| @@ -46,25 +43,10 @@ def run_test(): | |||
| test_data = get_hed(train=False) | |||
| # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:])) | |||
| cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list)) | |||
| cls = SymbolNet(num_classes=len(kb.pseudo_label_list)) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| if not os.path.exists("./weights/pretrain_weights.pth"): | |||
| INFO("Pretrain Start") | |||
| pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11']) | |||
| pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y) | |||
| criterion = nn.MSELoss() | |||
| optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6) | |||
| pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder) | |||
| framework_hed.pretrain(pretrain_model, pretrain_data) | |||
| torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth") | |||
| cls.load_state_dict(cls_autoencoder.base_model.state_dict()) | |||
| else: | |||
| cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) | |||
| framework.hed_pretrain(kb, cls, recorder) | |||
| criterion = nn.CrossEntropyLoss() | |||
| optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6) | |||
| @@ -80,8 +62,8 @@ def run_test(): | |||
| # train_data = get_hwf(train = True, get_pseudo_label = True) | |||
| # test_data = get_hwf(train = False, get_pseudo_label = True) | |||
| framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1) | |||
| # recorder.print(res) | |||
| model, mapping = framework.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8) | |||
| framework.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8) | |||
| recorder.dump() | |||
| return True | |||