You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :share_example.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. # ================================================================#
  12. from utils.plog import logger, INFO
  13. from utils.utils import copy_state_dict
  14. import torch.nn as nn
  15. import torch
  16. from models.nn import LeNet5, SymbolNet, SymbolNetAutoencoder
  17. from models.basic_model import BasicModel, BasicDataset
  18. from models.wabl_models import DecisionTree, WABLBasicModel
  19. from multiprocessing import Pool
  20. import os
  21. from abducer.abducer_base import AbducerBase
  22. from abducer.kb import add_KB, HWF_KB, HED_prolog_KB
  23. from datasets.mnist_add.get_mnist_add import get_mnist_add
  24. from datasets.hwf.get_hwf import get_hwf
  25. from datasets.hed.get_hed import get_hed, split_equation, get_pretrain_data
  26. import framework_hed
  27. def run_test():
  28. # kb = add_KB(True)
  29. # kb = HWF_KB(True)
  30. # abducer = AbducerBase(kb)
  31. kb = HED_prolog_KB()
  32. abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)
  33. recorder = logger()
  34. total_train_data = get_hed(train=True)
  35. train_data, val_data = split_equation(total_train_data, 3, 1)
  36. test_data = get_hed(train=False)
  37. # cls = LeNet5(num_classes=len(kb.pseudo_label_list), image_size=(train_data[0][0][0].shape[1:]))
  38. cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
  39. cls = SymbolNet(num_classes=len(kb.pseudo_label_list))
  40. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  41. if not os.path.exists("./weights/pretrain_weights.pth"):
  42. INFO("Pretrain Start")
  43. pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11'])
  44. pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
  45. criterion = nn.MSELoss()
  46. optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)
  47. pretrain_model = BasicModel(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
  48. framework_hed.pretrain(pretrain_model, pretrain_data)
  49. torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
  50. cls.load_state_dict(cls_autoencoder.base_model.state_dict())
  51. else:
  52. cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
  53. criterion = nn.CrossEntropyLoss()
  54. optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)
  55. # optimizer = torch.optim.Adam(cls.parameters(), lr=0.00001, betas=(0.9, 0.99))
  56. base_model = BasicModel(cls, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, batch_size=32, num_epochs=10, recorder=recorder)
  57. model = WABLBasicModel(base_model, kb.pseudo_label_list)
  58. # train_X, train_Z, train_Y = get_mnist_add(train = True, get_pseudo_label = True)
  59. # test_X, test_Z, test_Y = get_mnist_add(train = False, get_pseudo_label = True)
  60. # train_data = get_hwf(train = True, get_pseudo_label = True)
  61. # test_data = get_hwf(train = False, get_pseudo_label = True)
  62. framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, verbose=1)
  63. # recorder.print(res)
  64. recorder.dump()
  65. return True
  66. if __name__ == "__main__":
  67. run_test()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.