You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 2.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. from utils.plog import logger
  13. import numpy as np
  14. import time
  15. import framework
  16. import utils.plog as plog
  17. import torch.nn as nn
  18. import torch
  19. from models.lenet5 import LeNet5
  20. from models.basic_model import BasicModel
  21. from models.wabl_models import MyModel
  22. from multiprocessing import Pool
  23. import os
  24. from abducer.abducer_base import AbducerBase
  25. from abducer.kb import add_KB, hwf_KB
  26. from datasets.mnist_add.get_mnist_add import get_mnist_add
  27. from datasets.hwf.get_hwf import get_hwf
  class Params:
      """Static training settings, passed as ``Params()`` to ``BasicModel`` in run_test.

      Only class-level attributes are defined; instances carry no state of their own.
      How each value is consumed is decided by ``BasicModel`` (not visible in this file).
      """

      # Image height and width settings (both 45).
      imgH = imgW = 45
      # Presumably: preserve aspect ratio when resizing -- TODO confirm in BasicModel.
      keep_ratio = True
      # Save cadence; its unit (epochs vs. iterations) is defined by BasicModel.
      saveInterval = 10
      # Mini-batch size, and (presumably) the number of data-loading workers.
      batchSize = 16
      workers = 16
      # Number of training epochs.
      n_epoch = 10
      # None presumably disables loss-based early stopping -- verify in BasicModel.
      stop_loss = None
  def run_test(dataset="hwf"):
      """Train an abductive-learning model end to end and save the result recorder.

      Builds a knowledge base and an ``AbducerBase``, loads the chosen dataset,
      wraps a ``LeNet5`` classifier in ``BasicModel``/``MyModel``, runs
      ``framework.train`` on the training split, prints its result, and dumps
      the ``plog.ResultRecorder`` to ``results/1116.pk``.

      Args:
          dataset: ``"hwf"`` (default, the original behaviour) selects
              ``hwf_KB``/``get_hwf``. ``"mnist_add"`` selects
              ``add_KB``/``get_mnist_add``, the pair that was previously
              left commented out.

      Returns:
          True once training finishes and the recorder has been written.

      Raises:
          ValueError: if ``dataset`` is not one of the supported names.
      """
      # Each entry pairs a knowledge-base factory with its matching data loader.
      setups = {
          "hwf": (hwf_KB, get_hwf),
          "mnist_add": (add_KB, get_mnist_add),
      }
      try:
          make_kb, load_data = setups[dataset]
      except KeyError:
          raise ValueError(
              f"unknown dataset {dataset!r}; expected one of {sorted(setups)}"
          ) from None

      result_dir = "results"
      # Create the output directory here so run_test also works when called
      # directly rather than through the __main__ guard.
      os.makedirs(result_dir, exist_ok=True)
      recorder_file_path = os.path.join(result_dir, "1116.pk")

      kb = make_kb()
      abducer = AbducerBase(kb)

      # Text log configured through plog's logger. It gets its own name so the
      # ResultRecorder below no longer shadows it.
      log = logger()
      log.set_savefile("test.log")

      train_X, train_Y, test_X, test_Y = load_data()
      recorder = plog.ResultRecorder()

      # image_size drops the leading axis of the first sample's shape
      # (presumably channels-first images -- TODO confirm against the loaders).
      cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=train_X[0][0].shape[1:])
      # reduction="mean" is the non-deprecated equivalent of size_average=True.
      criterion = nn.CrossEntropyLoss(reduction="mean")
      optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      base_model = BasicModel(cls, criterion, optimizer, device, Params(), recorder=recorder)
      model = MyModel(base_model, kb.pseudo_label_list)

      res = framework.train(model, abducer, train_X, train_Y, sample_num=10000, verbose=1)
      print(res)

      # Close the output file deterministically instead of leaking the handle.
      with open(recorder_file_path, "wb") as f:
          recorder.dump(f)
      return True
  if __name__ == "__main__":
      # Ensure the "results" directory that run_test writes into exists.
      # exist_ok avoids the error that `mkdir` reported on re-runs, and no shell is spawned.
      os.makedirs("results", exist_ok=True)
      run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.