# coding: utf-8
# ================================================================#
#   Copyright (C) 2021 Freecss All rights reserved.
#
#   File Name     :share_example.py
#   Author        :freecss
#   Email         :karlfreecss@gmail.com
#   Created Date  :2021/06/07
#   Description   :
#
# ================================================================#

from utils.plog import logger, INFO
import framework_hed

import torch.nn as nn
import torch

from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder
from models.basic_model import BasicModel, BasicDataset
from models.wabl_models import WABLBasicModel

from multiprocessing import Pool
import os
from abducer.abducer_base import AbducerBase
from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, get_pretrain_data, split_equation

# Cache file for the SymbolNet weights produced by autoencoder pretraining.
# If it exists, pretraining is skipped and these weights are loaded directly.
PRETRAIN_WEIGHTS_PATH = "./weights/pretrain_weights.pth"


def run_test():
    """Train a SymbolNet classifier on the HED task with rule-based abduction.

    Steps:
      1. Build an ``HED_prolog_KB`` and an ``AbducerBase`` over it
         (``zoopt=True``, ``multiple_predictions=True``).
      2. Load the HED training set and split it into train/validation parts
         with ``split_equation``; also load the HED test set.
      3. If ``PRETRAIN_WEIGHTS_PATH`` does not exist, pretrain a
         ``SymbolNetAutoencoder`` via ``framework_hed.pretrain`` and save the
         state dict of its ``base_model`` there.
      4. Load those weights into a ``SymbolNet``, wrap it in ``BasicModel``
         (CrossEntropyLoss, RMSprop, batch size 32, 10 epochs) and
         ``WABLBasicModel``, and run ``framework_hed.train_with_rule``.
      5. Log the training result and dump the recorder.

    Other knowledge bases (``add_KB``, ``HWF_KB``) and datasets
    (``get_mnist_add``, ``get_hwf``) are imported and could be swapped in here.

    Returns:
        True once training has finished and the recorder has been dumped.
    """
    kb = HED_prolog_KB()
    abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
    recorder = logger()

    total_train_data = get_hed(train=True)
    # NOTE(review): the (3, 1) arguments are interpreted by split_equation;
    # presumably a 3:1 train/validation split — confirm.
    train_data, val_data = split_equation(total_train_data, 3, 1)
    # NOTE(review): test_data is loaded but not used anywhere below.
    test_data = get_hed(train=False)

    num_classes = len(kb.pseudo_label_list)
    cls = SymbolNet(num_classes=num_classes)

    if not os.path.exists(PRETRAIN_WEIGHTS_PATH):
        # The autoencoder is only needed to produce the cached weights, so it
        # is built only when that cache is missing.
        cls_autoencoder = SymbolNetAutoencoder(num_classes=num_classes)
        pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
        pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
        pretrain_data_loader = torch.utils.data.DataLoader(
            pretrain_data,
            batch_size=64,
            shuffle=True,
        )
        framework_hed.pretrain(cls_autoencoder, pretrain_data_loader, recorder)
        # torch.save does not create missing parent directories; without this
        # a fresh checkout fails after spending the whole pretraining run.
        weights_dir = os.path.dirname(PRETRAIN_WEIGHTS_PATH)
        if weights_dir:
            os.makedirs(weights_dir, exist_ok=True)
        torch.save(cls_autoencoder.base_model.state_dict(), PRETRAIN_WEIGHTS_PATH)

    # Load tensors onto the CPU first so a checkpoint written from a GPU run
    # also loads on a CPU-only host; load_state_dict copies them into cls's
    # existing parameters.
    cls.load_state_dict(torch.load(PRETRAIN_WEIGHTS_PATH, map_location="cpu"))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(
        cls.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicModel(
        cls,
        criterion,
        optimizer,
        device,
        save_interval=1,
        save_dir=recorder.save_dir,
        batch_size=32,
        num_epochs=10,
        recorder=recorder,
    )
    model = WABLBasicModel(base_model, kb.pseudo_label_list)

    res = framework_hed.train_with_rule(
        model, abducer, train_data, val_data, recorder=recorder
    )
    INFO(res)

    recorder.dump()
    return True


if __name__ == "__main__":
    run_test()