You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. #================================================================#
  12. from utils.plog import logger
  13. import framework
  14. import torch.nn as nn
  15. import torch
  16. from models.lenet5 import LeNet5, SymbolNet
  17. from models.basic_model import BasicModel
  18. from models.wabl_models import WABLBasicModel
  19. from multiprocessing import Pool
  20. import os
  21. from abducer.abducer_base import AbducerBase
  22. from abducer.kb import add_KB, hwf_KB
  23. from datasets.mnist_add.get_mnist_add import get_mnist_add
  24. from datasets.hwf.get_hwf import get_hwf
  25. def run_test():
  26. # kb = add_KB(True)
  27. kb = hwf_KB(True)
  28. abducer = AbducerBase(kb)
  29. recorder = logger()
  30. # train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True)
  31. # test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True)
  32. train_data = get_hwf(train = True, get_pseudo_label = True)
  33. test_data = get_hwf(train = False, get_pseudo_label = True)
  34. # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
  35. cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
  36. criterion = nn.CrossEntropyLoss()
  37. optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
  38. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  39. base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=1, recorder=recorder)
  40. model = WABLBasicModel(base_model, kb.pseudo_label_list)
  41. res = framework.train(model, abducer, train_data, test_data, sample_num = 10000, verbose = 1)
  42. recorder.print(res)
  43. recorder.dump()
  44. return True
  45. if __name__ == "__main__":
  46. run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.