| @@ -0,0 +1,85 @@ | |||
| # coding: utf-8 | |||
| #================================================================# | |||
| # Copyright (C) 2021 Freecss All rights reserved. | |||
| # | |||
| # File Name :share_example.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2021/06/07 | |||
| # Description : | |||
| # | |||
| #================================================================# | |||
| from utils.plog import logger | |||
| from models.wabl_models import DecisionTree, KNN | |||
| import pickle as pk | |||
| import numpy as np | |||
| import time | |||
| import framework | |||
| import utils.plog as plog | |||
| import torch.nn as nn | |||
| import torch | |||
| from models.lenet5 import LeNet5 | |||
| from models.basic_model import BasicModel | |||
| from models.wabl_models import MyModel | |||
| from multiprocessing import Pool | |||
| import os | |||
| from datasets.data_generator import generate_data_via_codes, code_generator | |||
| from collections import defaultdict | |||
| from abducer.abducer_base import AbducerBase | |||
| from abducer.kb import add_KB, hwf_KB | |||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
| from datasets.hwf.get_hwf import get_hwf | |||
| class Params: | |||
| imgH = 45 | |||
| imgW = 45 | |||
| keep_ratio = True | |||
| saveInterval = 10 | |||
| batchSize = 16 | |||
| workers = 16 | |||
| n_epoch = 10 | |||
| stop_loss = None | |||
| def run_test(): | |||
| result_dir = 'results' | |||
| recorder_file_path = f"{result_dir}/1116.pk"# | |||
| # words = code_generator(code_len, code_num, letter_num) | |||
| kb = add_KB() | |||
| abducer = AbducerBase(kb) | |||
| recorder = logger() | |||
| recorder.set_savefile("test.log") | |||
| train_X, train_Y, test_X, test_Y = get_mnist_add() | |||
| # train_X, train_Y, test_X, test_Y = get_hwf() | |||
| recorder = plog.ResultRecorder() | |||
| cls = LeNet5() | |||
| criterion = nn.CrossEntropyLoss(size_average=True) | |||
| optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) | |||
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |||
| sign_list = list(range(10)) | |||
| base_model = BasicModel(cls, criterion, optimizer, device, Params(), sign_list, recorder=recorder) | |||
| model = MyModel(base_model) | |||
| res = framework.train(model, abducer, train_X, train_Y, logic_forward = kb.logic_forward, sample_num = 10000, verbose = 1) | |||
| print(res) | |||
| recorder.dump(open(recorder_file_path, "wb")) | |||
| return True | |||
| if __name__ == "__main__": | |||
| os.system("mkdir results") | |||
| run_test() | |||