You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 2.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. # ================================================================#
  12. import sys
  13. sys.path.append("../")
  14. from abl.utils.plog import logger, INFO
  15. import torch.nn as nn
  16. import torch
  17. from abl.models.nn import LeNet5, SymbolNet
  18. from abl.models.basic_model import BasicModel, BasicDataset
  19. from abl.models.wabl_models import DecisionTree, WABLBasicModel
  20. from multiprocessing import Pool
  21. from abl.abducer.abducer_base import AbducerBase, HED_Abducer
  22. from abl.abducer.kb import add_KB, HWF_KB, prolog_KB, HED_prolog_KB
  23. from datasets.mnist_add.get_mnist_add import get_mnist_add
  24. from datasets.hwf.get_hwf import get_hwf
  25. from datasets.hed.get_hed import get_hed, split_equation
  26. from abl import framework_hed
  27. def run_test():
  28. # kb = add_KB()
  29. kb = HWF_KB(GKB_flag=True)
  30. abducer = AbducerBase(kb, 'confidence')
  31. # kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
  32. # abducer = HED_Abducer(kb)
  33. recorder = logger()
  34. # total_train_data = get_hed(train=True)
  35. # train_data, val_data = split_equation(total_train_data, 3, 1)
  36. # test_data = get_hed(train=False)
  37. # train_data = get_mnist_add(train=True, get_pseudo_label=True)
  38. # test_data = get_mnist_add(train=False, get_pseudo_label=True)
  39. train_data = get_hwf(train=True, get_pseudo_label=True)
  40. test_data = get_hwf(train=False, get_pseudo_label=True)
  41. # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
  42. cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
  43. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  44. # framework_hed.hed_pretrain(kb, cls, recorder)
  45. criterion = nn.CrossEntropyLoss()
  46. optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
  47. # optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
  48. base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=1, recorder=recorder)
  49. model = WABLBasicModel(base_model, kb.pseudo_label_list)
  50. # model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)
  51. # framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)
  52. framework_hed.train(model, abducer, train_data, test_data, sample_num=-1, verbose=1)
  53. recorder.dump()
  54. return True
  55. if __name__ == "__main__":
  56. run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.