
[ENH] reform examples

Gao Enhao committed 2 years ago · branch: pull/3/head · commit 3b3f886695
17 changed files with 36040 additions and 0 deletions
 1. examples/hed/datasets/BK.pl (+83, −0)
 2. examples/hed/datasets/README.md (+4, −0)
 3. examples/hed/datasets/get_hed.py (+130, −0)
 4. examples/hed/datasets/learn_add.pl (+84, −0)
 5. examples/hed/hed_example.ipynb (+199, −0)
 6. examples/hed/hed_knn_example.py (+69, −0)
 7. examples/hed/weights/all_weights_here.txt (+0, −0)
 8. examples/hwf/datasets/README.md (+4, −0)
 9. examples/hwf/datasets/get_hwf.py (+50, −0)
10. examples/hwf/hwf_example.ipynb (+184, −0)
11. examples/hwf/weights/all_weights_here.txt (+0, −0)
12. examples/mnist_add/datasets/add.pl (+2, −0)
13. examples/mnist_add/datasets/get_mnist_add.py (+41, −0)
14. examples/mnist_add/datasets/test_data.txt (+5000, −0)
15. examples/mnist_add/datasets/train_data.txt (+30000, −0)
16. examples/mnist_add/mnist_add_example.ipynb (+190, −0)
17. examples/mnist_add/weights/all_weights_here.txt (+0, −0)

examples/hed/datasets/BK.pl (+83, −0)

@@ -0,0 +1,83 @@
:- use_module(library(apply)).
:- use_module(library(lists)).
% :- use_module(library(tabling)).
% :- table valid_rules/2, op_rule/2.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% DCG parser for equations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% Symbols to be mapped
digit(1).
digit(0).

% Digit sequences (the empty list [] is not a valid digit sequence)
digits([D]) --> [D], { digit(D) }.
digits([D | T]) --> [D], !, digits(T), { digit(D) }.
digits(X) :-
    phrase(digits(X), X).
% More integrity constraints 1:
% These two goals forbid the leading digit to be 0.
% To enable them and prune the search space, replace the final '.' above
% with ',' and uncomment:
%     length(X, L),
%     (L > 1 -> X \= [0 | _]; true).

% Equation definition
eq_arg([D]) --> [D], { \+ D == '+', \+ D == '=' }.
eq_arg([D | T]) --> [D], !, eq_arg(T), { \+ D == '+', \+ D == '=' }.
equation(eq(X, Y, Z)) -->
    eq_arg(X), [+], eq_arg(Y), [=], eq_arg(Z).
% More integrity constraints 2:
% This goal restricts the argument lengths to sane values.
% To enable it and prune the search space, replace the final '.' above
% with ',' and uncomment:
%     { length(X, LX), length(Y, LY), length(Z, LZ),
%       LZ =< max(LX, LY) + 1, LZ >= max(LX, LY) }.

parse_eq(List_of_Terms, Eq) :-
    phrase(equation(Eq), List_of_Terms).

%%%%%%%%%%%%%%%%%%%%%%
%% Bit-wise operation
%%%%%%%%%%%%%%%%%%%%%%

% Abductive calculation: given pseudo-labels, abduces the pseudo-labels
% as well as the operation rules
calc(Rules, Pseudo) :-
    calc([], Rules, Pseudo).
calc(Rules0, Rules1, Pseudo) :-
    parse_eq(Pseudo, eq(X, Y, Z)),
    bitwise_calc(Rules0, Rules1, X, Y, Z).

% Bit-wise calculation that handles carrying
bitwise_calc(Rules, Rules1, X, Y, Z) :-
    reverse(X, X1), reverse(Y, Y1), reverse(Z, Z1),
    bitwise_calc_r(Rules, Rules1, X1, Y1, Z1),
    maplist(digits, [X, Y, Z]).

bitwise_calc_r(Rs, Rs, [], Y, Y).
bitwise_calc_r(Rs, Rs, X, [], X).
bitwise_calc_r(Rules, Rules1, [D1 | X], [D2 | Y], [D3 | Z]) :-
    abduce_op_rule(my_op([D1], [D2], Sum), Rules, Rules2),
    ((Sum = [D3], Carry = []) ; (Sum = [C, D3], Carry = [C])),
    bitwise_calc_r(Rules2, Rules3, X, Carry, X_carried),
    bitwise_calc_r(Rules3, Rules1, X_carried, Y, Z).

%%%%%%%%%%%%%%%%%%%%%%%%%
% Abduce operation rules
%%%%%%%%%%%%%%%%%%%%%%%%%

% Reuse an existing rule
abduce_op_rule(R, Rules, Rules) :-
    member(R, Rules).
% Add a new rule
abduce_op_rule(R, Rules, [R | Rules]) :-
    op_rule(R),
    valid_rules(Rules, R).

% Integrity constraints
valid_rules([], _).
valid_rules([my_op([X1], [Y1], _) | Rs], my_op([X], [Y], Z)) :-
    op_rule(my_op([X], [Y], Z)),
    [X, Y] \= [X1, Y1],
    [X, Y] \= [Y1, X1],
    valid_rules(Rs, my_op([X], [Y], Z)).
valid_rules([my_op([Y], [X], Z) | Rs], my_op([X], [Y], Z)) :-
    X \= Y,
    valid_rules(Rs, my_op([X], [Y], Z)).

op_rule(my_op([X], [Y], [Z])) :- digit(X), digit(Y), digit(Z).
op_rule(my_op([X], [Y], [Z1, Z2])) :- digit(X), digit(Y), digits([Z1, Z2]).
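To make `calc/2` concrete, here is a minimal sketch of querying BK.pl from Python. It is not part of this commit, and it assumes SWI-Prolog plus the pyswip bindings are installed and that the script runs in the same directory as BK.pl:

```python
# Minimal sketch (assumes SWI-Prolog + pyswip, run next to BK.pl):
# abduce operation rules consistent with the pseudo-labelled binary
# equation 1 + 1 = 10. Rules is bound to a list of my_op/3 facts.
from pyswip import Prolog

prolog = Prolog()
prolog.consult("BK.pl")

goal = "calc(Rules, [1, '+', 1, '=', 1, 0])"
for solution in prolog.query(goal):
    print(solution["Rules"])
    break  # many rule sets are consistent; show only the first
```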

examples/hed/datasets/README.md (+4, −0)

@@ -0,0 +1,4 @@
Download the Handwritten Equation Decipherment (HED) dataset from [NJU Box](https://box.nju.edu.cn/f/391c2d48c32b436cb833/) into this folder and unzip it:
```
unzip HED.zip
```

examples/hed/datasets/get_hed.py (+130, −0)

@@ -0,0 +1,130 @@
import os
import cv2
import torch
import torchvision
import pickle
import numpy as np
import random
from collections import defaultdict
from torch.utils.data import Dataset
from torchvision.transforms import transforms


def get_data(img_dataset, train):
    transform = transforms.Compose([transforms.ToTensor()])  # note: defined but not applied here
    X = []
    Y = []
    if train:
        positive = img_dataset["train:positive"]
        negative = img_dataset["train:negative"]
    else:
        positive = img_dataset["test:positive"]
        negative = img_dataset["test:negative"]

    for equation in positive:
        equation = equation.astype(np.float32)
        img_list = np.vsplit(equation, equation.shape[0])
        X.append(img_list)
        Y.append(1)

    for equation in negative:
        equation = equation.astype(np.float32)
        img_list = np.vsplit(equation, equation.shape[0])
        X.append(img_list)
        Y.append(0)

    return X, None, Y


def get_pretrain_data(labels, image_size=(28, 28, 1)):
    transform = transforms.Compose([transforms.ToTensor()])
    X = []
    for label in labels:
        label_path = os.path.join("./datasets/hed/mnist_images", label)
        img_path_list = os.listdir(label_path)
        for img_path in img_path_list:
            img = cv2.imread(
                os.path.join(label_path, img_path), cv2.IMREAD_GRAYSCALE
            )
            img = cv2.resize(img, (image_size[1], image_size[0]))
            X.append(np.array(img, dtype=np.float32))

    X = [((img[:, :, np.newaxis] - 127) / 128.0) for img in X]
    Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]

    X = [transform(img) for img in X]
    return X, Y


# def get_pretrain_data(train_data, image_size=(28, 28, 1)):
#     X = []
#     for label in [0, 1]:
#         for _, equation_list in train_data[label].items():
#             for equation in equation_list:
#                 X = X + equation
#
#     X = np.array(X)
#     index = np.array(list(range(len(X))))
#     np.random.shuffle(index)
#     X = X[index]
#
#     X = [img for img in X]
#     Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
#
#     return X, Y


def divide_equations_by_len(equations, labels):
    equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
    for i, equation in enumerate(equations):
        equations_by_len[labels[i]][len(equation)].append(equation)
    return equations_by_len


def split_equation(equations_by_len, prop_train, prop_val):
    """
    Split the equations of each length into training and validation data
    according to the given proportions.
    """
    train_equations_by_len = {1: dict(), 0: dict()}
    val_equations_by_len = {1: dict(), 0: dict()}

    for label in range(2):
        for equation_len, equations in equations_by_len[label].items():
            random.shuffle(equations)
            train_equations_by_len[label][equation_len] = equations[
                : len(equations) // (prop_train + prop_val) * prop_train
            ]
            val_equations_by_len[label][equation_len] = equations[
                len(equations) // (prop_train + prop_val) * prop_train :
            ]

    return train_equations_by_len, val_equations_by_len


def get_hed(dataset="mnist", train=True):
    if dataset == "mnist":
        with open(
            "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    elif dataset == "random":
        with open(
            "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    else:
        raise ValueError("Undefined dataset")

    X, _, Y = get_data(img_dataset, train)
    equations_by_len = divide_equations_by_len(X, Y)

    return equations_by_len


if __name__ == "__main__":
    get_hed()
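As a quick orientation, the nested structure that `get_hed` returns can be inspected as follows. This is a hedged usage sketch, assuming the HED pickle from the README above has been downloaded into `./datasets/hed/`:

```python
# Sketch: inspect the {label: {equation_length: [equations]}} structure
# returned by get_hed(); each equation is a list of per-symbol image arrays.
total_train_data = get_hed(dataset="mnist", train=True)
train_data, val_data = split_equation(total_train_data, prop_train=3, prop_val=1)

for label in (1, 0):  # 1 = positive equations, 0 = negative equations
    for eq_len, eqs in sorted(train_data[label].items()):
        print(f"label={label} len={eq_len}: {len(eqs)} equations")
```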

examples/hed/datasets/learn_add.pl (+84, −0)

@@ -0,0 +1,84 @@
:- ensure_loaded(['BK.pl']).
:- thread_setconcurrency(_, 8).
:- use_module(library(thread)).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% For propositionalisation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
eval_inst_feature(Ex, Feature) :-
    eval_eq(Ex, Feature).
%% Evaluate an instance given a feature
eval_eq(Ex, Feature) :-
    parse_eq(Ex, eq(X, Y, Z)),
    bitwise_calc(Feature, _, X, Y, Z), !.

%%%%%%%%%%%%%%
%% Abduction
%%%%%%%%%%%%%%
% Perform abduction on examples that have been interpreted as pseudo-labels
abduce(Exs, Delta_C) :-
    abduce(Exs, [], Delta_C).
abduce([], Delta_C, Delta_C).
abduce([E | Exs], Delta_C0, Delta_C1) :-
    calc(Delta_C0, Delta_C2, E),
    abduce(Exs, Delta_C2, Delta_C1).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce pseudo-labels only
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_insts(Exs) :-
    abduce(Exs, _), !.
% (Experimental) Replace the body above with this call to use parallel abduction:
% abduce_consistent_exs_concurrent(Exs), !.

logic_forward(Exs, X) :- abduce_consistent_insts(Exs) -> X = true ; X = false.
logic_forward(Exs) :- abduce_consistent_insts(Exs).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce Delta_C given pseudo-labels
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
consistent_inst_feature(Exs, Delta_C) :-
    abduce(Exs, Delta_C), !.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% (Experimental) Parallel abduction
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_exs_concurrent(Exs) :-
    % Split the current data batch into ground examples and variable
    % examples (the latter need to be revised)
    split_exs(Exs, Ground_Exs, Var_Exs),
    % Find the simplest Delta_C for the ground examples
    abduce(Ground_Exs, Ground_Delta_C), !,
    % Extend the ground Delta_C into all possible variations
    extend_op_rule(Ground_Delta_C, Possible_Deltas),
    % Concurrently abduce the variable examples
    maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List),
    maplist(=.., Goals, Call_List),
    % writeln(Goals),
    first_solution(Var_Exs, Goals, [local(inf)]).

split_exs([], [], []).
split_exs([E | Exs], [E | G_Exs], V_Exs) :-
    ground(E), !,
    split_exs(Exs, G_Exs, V_Exs).
split_exs([E | Exs], G_Exs, [E | V_Exs]) :-
    split_exs(Exs, G_Exs, V_Exs).

:- table extend_op_rule/2.
extend_op_rule(Rules, Rules) :-
    length(Rules, 4).
extend_op_rule(Rules, Ext) :-
    op_rule(R),
    valid_rules(Rules, R),
    extend_op_rule([R | Rules], Ext).

% Abduction without learning a new Delta_C (the rules have already been extended)
abduce2([], _, _).
abduce2([E | Exs], Ground_Exs, Delta_C) :-
    % abduce by matching a ground example
    member(E, Ground_Exs),
    abduce2(Exs, Ground_Exs, Delta_C).
abduce2([E | Exs], Ground_Exs, Delta_C) :-
    eval_inst_feature(E, Delta_C),
    abduce2(Exs, Ground_Exs, Delta_C).
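A hedged sketch of how `logic_forward/1` can be exercised directly from Python, again via pyswip (an assumption of the sketch, not something this commit depends on):

```python
# Sketch: check whether a batch of pseudo-labelled equations is jointly
# consistent under some set of abduced operation rules.
# Assumes SWI-Prolog + pyswip, run from examples/hed/datasets/ so that
# learn_add.pl can find BK.pl.
from pyswip import Prolog

prolog = Prolog()
prolog.consult("learn_add.pl")

batch = "[[1, '+', 1, '=', 1, 0], [1, '+', 0, '=', 1]]"
consistent = bool(list(prolog.query(f"logic_forward({batch})")))
print(consistent)  # True: both equations fit binary addition
```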

examples/hed/hed_example.ipynb (+199, −0)

@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../\")\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.abducer.abducer_base import HED_Abducer\n",
"from abl.abducer.kb import HED_prolog_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.hed.get_hed import get_hed, split_equation\n",
"from abl import framework_hed"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"recorder = logger()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ERROR: /home/gaoeh/ABL-Package/examples/datasets/hed/learn_add.pl:67:9: Syntax error: Operator expected\n"
]
}
],
"source": [
"# Initialize knowledge base and abducer\n",
"kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
"abducer = HED_Abducer(kb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Machine Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize necessary component for machine learning part\n",
"cls = SymbolNet(\n",
" num_classes=len(kb.pseudo_label_list),\n",
" image_size=(28, 28, 1),\n",
")\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pretrain NN classifier\n",
"framework_hed.hed_pretrain(kb, cls, recorder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=recorder.save_dir,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" recorder=recorder,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"total_train_data = get_hed(train=True)\n",
"train_data, val_data = split_equation(total_train_data, 3, 1)\n",
"test_data = get_hed(train=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and save"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)\n",
"framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)\n",
"\n",
"recorder.dump()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ABL",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
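Because the HED task must learn what each handwritten symbol means, `train_with_rule` returns a `mapping` from network class indices to pseudo-labels alongside the trained model. A purely illustrative example of how such a mapping would be applied; the concrete index assignments here are hypothetical:

```python
# Hypothetical mapping from NN class indices to pseudo-labels; the real one
# is produced by framework_hed.train_with_rule above.
mapping = {0: 1, 1: 0, 2: '+', 3: '='}
class_indices = [0, 2, 0, 3, 0, 1]   # raw network predictions
symbols = [mapping[c] for c in class_indices]
print(symbols)                        # [1, '+', 1, '=', 1, 0]
```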

examples/hed/hed_knn_example.py (+69, −0)

@@ -0,0 +1,69 @@
# coding: utf-8
# ================================================================ #
# Copyright (C) 2021 Freecss. All rights reserved.
#
# File Name    : share_example.py
# Author       : freecss
# Email        : karlfreecss@gmail.com
# Created Date : 2021/06/07
# Description  :
#
# ================================================================ #

import sys
sys.path.append("../")

from abl.utils.plog import logger, INFO
from abl.utils.utils import reduce_dimension
import torch.nn as nn
import torch

from abl.models.nn import LeNet5, SymbolNet
from abl.models.basic_model import BasicModel, BasicDataset
from abl.models.wabl_models import DecisionTree, WABLBasicModel
from sklearn.neighbors import KNeighborsClassifier

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB, HWF_KB, prolog_KB
from datasets.mnist_add.get_mnist_add import get_mnist_add
from datasets.hwf.get_hwf import get_hwf
from datasets.hed.get_hed import get_hed, split_equation
from abl import framework_hed_knn


def run_test():
    # kb = add_KB(True)
    # kb = HWF_KB(True)
    # abducer = AbducerBase(kb)

    kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl')
    abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True)

    recorder = logger()

    total_train_data = get_hed(train=True)
    train_data, val_data = split_equation(total_train_data, 3, 1)
    test_data = get_hed(train=False)

    # ========================= KNN model ============================ #
    reduce_dimension(train_data)
    reduce_dimension(val_data)
    reduce_dimension(test_data)
    base_model = KNeighborsClassifier(n_neighbors=3)
    pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model)
    model = WABLBasicModel(base_model, kb.pseudo_label_list)
    model, mapping = framework_hed_knn.train_with_rule(
        model, abducer, train_data, val_data,
        (pretrain_data_X, pretrain_data_Y),
        select_num=10, min_len=5, max_len=8,
    )
    framework_hed_knn.hed_test(
        model, abducer, mapping, train_data, test_data, min_len=5, max_len=8
    )
    # ============================ End =============================== #

    recorder.dump()
    return True


if __name__ == "__main__":
run_test()
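A KNN classifier needs flat feature vectors rather than image tensors, which is why the script calls `reduce_dimension` on each split before fitting. The real implementation lives in `abl.utils.utils`; the flattening variant below is only a hypothetical stand-in to show the idea:

```python
# Hypothetical stand-in for reduce_dimension: flatten each (H, W) symbol
# image into a 1-D vector so KNeighborsClassifier can consume it.
import numpy as np

def flatten_images(images):
    return [np.asarray(img, dtype=np.float32).reshape(-1) for img in images]
```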

examples/hed/weights/all_weights_here.txt (+0, −0)


examples/hwf/datasets/README.md (+4, −0)

@@ -0,0 +1,4 @@
Download the Handwritten Formula Recognition (HWF) dataset from [Google Drive](https://drive.google.com/file/d/1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy/view?usp=sharing) into this folder and unzip it:
```
unzip HWF.zip
```

examples/hwf/datasets/get_hwf.py (+50, −0)

@@ -0,0 +1,50 @@
import json
from PIL import Image
from torchvision.transforms import transforms

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1,))
])


def get_data(file, get_pseudo_label):
    X = []
    if get_pseudo_label:
        Z = []
    Y = []
    img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/'
    with open(file) as f:
        data = json.load(f)
    for idx in range(len(data)):
        imgs = []
        imgs_pseudo_label = []
        for img_path in data[idx]['img_paths']:
            img = Image.open(img_dir + img_path).convert('L')
            img = img_transform(img)
            imgs.append(img)
            if get_pseudo_label:
                imgs_pseudo_label.append(img_path.split('/')[0])
        X.append(imgs)
        if get_pseudo_label:
            Z.append(imgs_pseudo_label)
        Y.append(data[idx]['res'])
    if get_pseudo_label:
        return X, Z, Y
    else:
        return X, None, Y


def get_hwf(train=True, get_pseudo_label=False):
    if train:
        file = './datasets/hwf/data/expr_train.json'
    else:
        file = './datasets/hwf/data/expr_test.json'
    return get_data(file, get_pseudo_label)


if __name__ == "__main__":
    train_X, _, train_Y = get_hwf(train=True)
    test_X, _, test_Y = get_hwf(train=False)
    print(len(train_X), len(test_X))
    print(len(train_X[0]), train_X[0][0].shape, train_Y[0])
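For reference, the JSON record layout `get_data` expects can be inferred from the code above: per-symbol image paths whose leading directory doubles as the symbol's pseudo-label, plus the evaluated result of the formula. A hedged example record (the file names are hypothetical):

```python
# Inferred record layout of expr_train.json / expr_test.json:
example_record = {
    "img_paths": ["5/a.jpg", "+/b.jpg", "3/c.jpg"],  # hypothetical file names
    "res": 8,                                        # 5 + 3
}
print(example_record["img_paths"][0].split("/")[0])  # -> "5", the pseudo-label
```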

examples/hwf/hwf_example.ipynb (+184, −0)

@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../\")\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.abducer.abducer_base import AbducerBase\n",
"from abl.abducer.kb import HWF_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"\n",
"from models.nn import SymbolNet\n",
"from datasets.hwf.get_hwf import get_hwf\n",
"from abl import framework_hed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"recorder = logger()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"kb = HWF_KB(GKB_flag=True)\n",
"abducer = AbducerBase(kb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Machine Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize necessary component for machine learning part\n",
"cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=recorder.save_dir,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" recorder=recorder,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize WABL model\n",
"# The main function of the WABL model is to serialize data and \n",
"# provide a unified interface for different machine learning models\n",
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_data = get_hwf(train=True, get_pseudo_label=True)\n",
"test_data = get_hwf(train=False, get_pseudo_label=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and save"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train model\n",
"framework_hed.train(\n",
" model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n",
")\n",
"\n",
"# Save results\n",
"recorder.dump()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ABL",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
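The heart of `HWF_KB` is a consistency check between a predicted symbol sequence and the known result of the formula. A minimal sketch of such a check, illustrative only; the real logic lives in `abl.abducer.kb.HWF_KB`:

```python
# Illustrative consistency check: does the symbol sequence evaluate to
# the labelled result?
def formula_consistent(symbols, result, eps=1e-6):
    return abs(eval("".join(symbols)) - result) < eps

print(formula_consistent(["5", "+", "3"], 8))   # True
print(formula_consistent(["5", "*", "3"], 8))   # False
```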

examples/hwf/weights/all_weights_here.txt (+0, −0)


examples/mnist_add/datasets/add.pl (+2, −0)

@@ -0,0 +1,2 @@
pseudo_label(N) :- between(0, 9, N).
logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2.
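These two clauses are all the domain knowledge the MNIST-addition task needs. A hedged sketch of querying them directly from Python via pyswip (an assumption of the sketch, not part of this commit):

```python
# Sketch (assumes SWI-Prolog + pyswip, run next to add.pl): enumerate all
# pseudo-label pairs that logic_forward/2 deems consistent with the sum 7.
from pyswip import Prolog

prolog = Prolog()
prolog.consult("add.pl")

pairs = [(s["Z1"], s["Z2"]) for s in prolog.query("logic_forward([Z1, Z2], 7)")]
print(pairs)  # [(0, 7), (1, 6), ..., (7, 0)]
```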

examples/mnist_add/datasets/get_mnist_add.py (+41, −0)

@@ -0,0 +1,41 @@
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.transforms import transforms

def get_data(file, img_dataset, get_pseudo_label):
    X = []
    if get_pseudo_label:
        Z = []
    Y = []
    with open(file) as f:
        for line in f:
            line = line.strip().split(' ')
            X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]])
            if get_pseudo_label:
                Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]])
            Y.append(int(line[2]))
    if get_pseudo_label:
        return X, Z, Y
    else:
        return X, None, Y


def get_mnist_add(train=True, get_pseudo_label=False):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    img_dataset = torchvision.datasets.MNIST(
        root='./datasets/mnist_add/', train=train, download=True, transform=transform
    )
    if train:
        file = './datasets/mnist_add/train_data.txt'
    else:
        file = './datasets/mnist_add/test_data.txt'
    return get_data(file, img_dataset, get_pseudo_label)


if __name__ == "__main__":
    train_X, _, train_Y = get_mnist_add(train=True)
    test_X, _, test_Y = get_mnist_add(train=False)
    print(len(train_X), len(test_X))
    print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0])
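For reference, each line of `train_data.txt` / `test_data.txt` carries three space-separated integers, as inferred from `get_data` above: two indices into the MNIST image set and the sum of their digit labels. A hedged parse demo (the indices shown are hypothetical):

```python
# Inferred line format: "<mnist_index_1> <mnist_index_2> <sum>"
line = "132 5981 7"  # hypothetical sample line
i, j, s = map(int, line.strip().split(" "))
# get_data pairs the two images and keeps the sum as the label:
# X.append([img_dataset[i][0], img_dataset[j][0]]); Y.append(s)
```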

examples/mnist_add/datasets/test_data.txt (+5000, −0)
File diff suppressed because it is too large.


examples/mnist_add/datasets/train_data.txt (+30000, −0)
File diff suppressed because it is too large.


examples/mnist_add/mnist_add_example.ipynb (+190, −0)

@@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../\")\n",
"\n",
"import torch.nn as nn\n",
"import torch\n",
"\n",
"from abl.abducer.abducer_base import AbducerBase\n",
"from abl.abducer.kb import add_KB\n",
"\n",
"from abl.utils.plog import logger\n",
"from abl.models.basic_model import BasicModel\n",
"from abl.models.wabl_models import WABLBasicModel\n",
"\n",
"from models.nn import LeNet5\n",
"from datasets.mnist_add.get_mnist_add import get_mnist_add\n",
"from abl import framework_hed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize logger\n",
"recorder = logger()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logic Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize knowledge base and abducer\n",
"kb = add_KB(GKB_flag=True)\n",
"abducer = AbducerBase(kb, dist_func=\"confidence\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Machine Learning Part"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize necessary component for machine learning part\n",
"cls = LeNet5(num_classes=len(kb.pseudo_label_list))\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize BasicModel\n",
"# The function of BasicModel is to wrap NN models into the form of an sklearn estimator\n",
"base_model = BasicModel(\n",
" cls,\n",
" criterion,\n",
" optimizer,\n",
" device,\n",
" save_interval=1,\n",
" save_dir=recorder.save_dir,\n",
" batch_size=32,\n",
" num_epochs=1,\n",
" recorder=recorder,\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use WABL model to join two parts"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize WABL model\n",
"# The main function of the WABL model is to serialize data and \n",
"# provide a unified interface for different machine learning models\n",
"model = WABLBasicModel(base_model, kb.pseudo_label_list)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get training and testing data\n",
"train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n",
"test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train and save"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train model\n",
"framework_hed.train(\n",
" model,\n",
" abducer,\n",
" (train_X, train_Z, train_Y),\n",
" (test_X, test_Z, test_Y),\n",
" loop_num=15,\n",
" sample_num=5000,\n",
" verbose=1,\n",
")\n",
"\n",
"# Save results\n",
"recorder.dump()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ABL",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
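The abducer above is configured with `dist_func="confidence"`: among all candidate label pairs consistent with the observed sum, it prefers the pair the network considers most probable. A minimal sketch of that selection rule, using a hypothetical helper rather than the package's actual implementation:

```python
# Illustrative confidence-based selection over abduction candidates.
import itertools

def most_confident_pair(probs1, probs2, target_sum):
    """Pick the digit pair with maximal joint probability among all
    pairs whose digits sum to the target."""
    candidates = [(a, b) for a, b in itertools.product(range(10), repeat=2)
                  if a + b == target_sum]
    return max(candidates, key=lambda ab: probs1[ab[0]] * probs2[ab[1]])
```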

examples/mnist_add/weights/all_weights_here.txt (+0, −0)

