| @@ -0,0 +1,85 @@ | |||||
| # coding: utf-8 | |||||
| #================================================================# | |||||
| # Copyright (C) 2021 Freecss All rights reserved. | |||||
| # | |||||
| # File Name :share_example.py | |||||
| # Author :freecss | |||||
| # Email :karlfreecss@gmail.com | |||||
| # Created Date :2021/06/07 | |||||
| # Description : | |||||
| # | |||||
| #================================================================# | |||||
| from utils.plog import logger | |||||
| from models.wabl_models import DecisionTree, KNN | |||||
| import pickle as pk | |||||
| import numpy as np | |||||
| import time | |||||
| import framework | |||||
| import utils.plog as plog | |||||
| import torch.nn as nn | |||||
| import torch | |||||
| from models.lenet5 import LeNet5 | |||||
| from models.basic_model import BasicModel | |||||
| from models.wabl_models import MyModel | |||||
| from multiprocessing import Pool | |||||
| import os | |||||
| from datasets.data_generator import generate_data_via_codes, code_generator | |||||
| from collections import defaultdict | |||||
| from abducer.abducer_base import AbducerBase | |||||
| from abducer.kb import add_KB, hwf_KB | |||||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||||
| from datasets.hwf.get_hwf import get_hwf | |||||
| class Params: | |||||
| imgH = 45 | |||||
| imgW = 45 | |||||
| keep_ratio = True | |||||
| saveInterval = 10 | |||||
| batchSize = 16 | |||||
| workers = 16 | |||||
| n_epoch = 10 | |||||
| stop_loss = None | |||||
| def run_test(): | |||||
| result_dir = 'results' | |||||
| recorder_file_path = f"{result_dir}/1116.pk"# | |||||
| # words = code_generator(code_len, code_num, letter_num) | |||||
| kb = add_KB() | |||||
| abducer = AbducerBase(kb) | |||||
| recorder = logger() | |||||
| recorder.set_savefile("test.log") | |||||
| train_X, train_Y, test_X, test_Y = get_mnist_add() | |||||
| # train_X, train_Y, test_X, test_Y = get_hwf() | |||||
| recorder = plog.ResultRecorder() | |||||
| cls = LeNet5() | |||||
| criterion = nn.CrossEntropyLoss(size_average=True) | |||||
| optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||||
| sign_list = list(range(10)) | |||||
| base_model = BasicModel(cls, criterion, optimizer, device, Params(), sign_list, recorder=recorder) | |||||
| model = MyModel(base_model) | |||||
| res = framework.train(model, abducer, train_X, train_Y, logic_forward = kb.logic_forward, sample_num = 10000, verbose = 1) | |||||
| print(res) | |||||
| recorder.dump(open(recorder_file_path, "wb")) | |||||
| return True | |||||
| if __name__ == "__main__": | |||||
| os.system("mkdir results") | |||||
| run_test() | |||||