@@ -0,0 +1,83 @@
:- use_module(library(apply)).
:- use_module(library(lists)).
% :- use_module(library(tabling)).
% :- table valid_rules/2, op_rule/2.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% DCG parser for equations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% Symbols to be mapped
digit(1).
digit(0).

% Digits
digits([D]) --> [D], { digit(D) }. % the empty list [] is not a digit
digits([D | T]) --> [D], !, digits(T), { digit(D) }.

% digits/1 checks that a list consists solely of digits
digits(X) :-
    phrase(digits(X), X).
% More integrity constraints 1:
% These two goals forbid the first digit to be 0.
% You may uncomment them to prune the search space:
%    length(X, L),
%    (L > 1 -> X \= [0 | _]; true).

% Equation definition
eq_arg([D]) --> [D], { \+ D == '+', \+ D == '=' }.
eq_arg([D | T]) --> [D], !, eq_arg(T), { \+ D == '+', \+ D == '=' }.

equation(eq(X, Y, Z)) -->
    eq_arg(X), [+], eq_arg(Y), [=], eq_arg(Z).
% More integrity constraints 2:
% This clause restricts the lengths of the arguments to sane values.
% You may uncomment it to prune the search space:
%    { length(X, LX), length(Y, LY), length(Z, LZ),
%      LZ =< max(LX, LY) + 1, LZ >= max(LX, LY) }.

parse_eq(List_of_Terms, Eq) :-
    phrase(equation(Eq), List_of_Terms).
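% Example query (a sketch, assuming the binary symbol set above):
%   ?- parse_eq([1,0,+,1,=,1,1], Eq).
%   Eq = eq([1,0], [1], [1,1]).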
%%%%%%%%%%%%%%%%%%%%%%
%% Bit-wise operation
%%%%%%%%%%%%%%%%%%%%%%

% Abductive calculation on given pseudo-labels; abduces pseudo-labels
% as well as operation rules
calc(Rules, Pseudo) :-
    calc([], Rules, Pseudo).
calc(Rules0, Rules1, Pseudo) :-
    parse_eq(Pseudo, eq(X,Y,Z)),
    bitwise_calc(Rules0, Rules1, X, Y, Z).

% Bit-wise calculation that handles carrying
bitwise_calc(Rules, Rules1, X, Y, Z) :-
    reverse(X, X1), reverse(Y, Y1), reverse(Z, Z1),
    bitwise_calc_r(Rules, Rules1, X1, Y1, Z1),
    maplist(digits, [X,Y,Z]).

bitwise_calc_r(Rs, Rs, [], Y, Y).
bitwise_calc_r(Rs, Rs, X, [], X).
bitwise_calc_r(Rules, Rules1, [D1 | X], [D2 | Y], [D3 | Z]) :-
    abduce_op_rule(my_op([D1],[D2],Sum), Rules, Rules2),
    ((Sum = [D3], Carry = []); (Sum = [C, D3], Carry = [C])),
    % First add the carry to the remaining digits of X, then add Y
    bitwise_calc_r(Rules2, Rules3, X, Carry, X_carried),
    bitwise_calc_r(Rules3, Rules1, X_carried, Y, Z).
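% Example abduction (illustrative; shows the first answer only):
%   ?- calc(Rules, [1,+,1,=,1,0]).
%   Rules = [my_op([1],[1],[1,0])].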
%%%%%%%%%%%%%%%%%%%%%%%%%
% Abduce operation rules
%%%%%%%%%%%%%%%%%%%%%%%%%

% Reuse an existing rule
abduce_op_rule(R, Rules, Rules) :-
    member(R, Rules).
% Add a new rule
abduce_op_rule(R, Rules, [R|Rules]) :-
    op_rule(R),
    valid_rules(Rules, R).

% Integrity constraints: each digit pair gets at most one result,
% and the operation must be commutative
valid_rules([], _).
valid_rules([my_op([X1],[Y1],_)|Rs], my_op([X],[Y],Z)) :-
    op_rule(my_op([X],[Y],Z)),
    [X,Y] \= [X1,Y1],
    [X,Y] \= [Y1,X1],
    valid_rules(Rs, my_op([X],[Y],Z)).
valid_rules([my_op([Y],[X],Z)|Rs], my_op([X],[Y],Z)) :-
    X \= Y,
    valid_rules(Rs, my_op([X],[Y],Z)).

op_rule(my_op([X],[Y],[Z])) :- digit(X), digit(Y), digit(Z).
op_rule(my_op([X],[Y],[Z1,Z2])) :- digit(X), digit(Y), digits([Z1,Z2]).
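% op_rule/1 enumerates the candidate one- and two-digit results, e.g.
%   ?- op_rule(my_op([1],[1],S)).
% yields S = [1] ; S = [0] ; S = [1,1] ; ... (six candidates in total).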
@@ -0,0 +1,4 @@
Download the Handwritten Equation Decipherment dataset from [NJU Box](https://box.nju.edu.cn/f/391c2d48c32b436cb833/) to this folder and unzip it:
```
unzip HED.zip
```
| @@ -0,0 +1,130 @@ | |||
| import os | |||
| import cv2 | |||
| import torch | |||
| import torchvision | |||
| import pickle | |||
| import numpy as np | |||
| import random | |||
| from collections import defaultdict | |||
| from torch.utils.data import Dataset | |||
| from torchvision.transforms import transforms | |||
| def get_data(img_dataset, train): | |||
| transform = transforms.Compose([transforms.ToTensor()]) | |||
| X = [] | |||
| Y = [] | |||
| if train: | |||
| positive = img_dataset["train:positive"] | |||
| negative = img_dataset["train:negative"] | |||
| else: | |||
| positive = img_dataset["test:positive"] | |||
| negative = img_dataset["test:negative"] | |||
| for equation in positive: | |||
| equation = equation.astype(np.float32) | |||
| img_list = np.vsplit(equation, equation.shape[0]) | |||
| X.append(img_list) | |||
| Y.append(1) | |||
| for equation in negative: | |||
| equation = equation.astype(np.float32) | |||
| img_list = np.vsplit(equation, equation.shape[0]) | |||
| X.append(img_list) | |||
| Y.append(0) | |||
| return X, None, Y | |||
| def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| transform = transforms.Compose([transforms.ToTensor()]) | |||
| X = [] | |||
| for label in labels: | |||
| label_path = os.path.join( | |||
| "./datasets/hed/mnist_images", label | |||
| ) | |||
| img_path_list = os.listdir(label_path) | |||
| for img_path in img_path_list: | |||
| img = cv2.imread( | |||
| os.path.join(label_path, img_path), cv2.IMREAD_GRAYSCALE | |||
| ) | |||
| img = cv2.resize(img, (image_size[1], image_size[0])) | |||
| X.append(np.array(img, dtype=np.float32)) | |||
| X = [((img[:, :, np.newaxis] - 127) / 128.0) for img in X] | |||
| Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] | |||
| X = [transform(img) for img in X] | |||
| return X, Y | |||
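# Illustrative usage (assumes ./datasets/hed/mnist_images/ contains one
# sub-folder of symbol images per label; the folder names below are examples):
#   X, Y = get_pretrain_data(["0", "1", "10", "11"])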
# An earlier variant, kept for reference, which built the pretraining set
# from the equation images themselves:
# def get_pretrain_data(train_data, image_size=(28, 28, 1)):
#     X = []
#     for label in [0, 1]:
#         for _, equation_list in train_data[label].items():
#             for equation in equation_list:
#                 X = X + equation
#     X = np.array(X)
#     index = np.array(list(range(len(X))))
#     np.random.shuffle(index)
#     X = X[index]
#     X = [img for img in X]
#     Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X]
#     return X, Y

def divide_equations_by_len(equations, labels):
    # Group equations by label (1 = positive, 0 = negative) and by length
    equations_by_len = {1: defaultdict(list), 0: defaultdict(list)}
    for i, equation in enumerate(equations):
        equations_by_len[labels[i]][len(equation)].append(equation)
    return equations_by_len


def split_equation(equations_by_len, prop_train, prop_val):
    """
    Split the equations of each length into training and validation data
    according to the given proportions.
    """
    train_equations_by_len = {1: dict(), 0: dict()}
    val_equations_by_len = {1: dict(), 0: dict()}
    for label in range(2):
        for equation_len, equations in equations_by_len[label].items():
            random.shuffle(equations)
            train_equations_by_len[label][equation_len] = equations[
                : len(equations) // (prop_train + prop_val) * prop_train
            ]
            val_equations_by_len[label][equation_len] = equations[
                len(equations) // (prop_train + prop_val) * prop_train :
            ]
    return train_equations_by_len, val_equations_by_len
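# For example, split_equation(equations_by_len, 3, 1) keeps roughly 3/4 of the
# equations of each length for training and the remaining 1/4 for validation.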

def get_hed(dataset="mnist", train=True):
    if dataset == "mnist":
        with open(
            "./datasets/hed/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    elif dataset == "random":
        with open(
            "./datasets/hed/random_equation_data_train_len_26_test_len_26_sys_2_.pk",
            "rb",
        ) as f:
            img_dataset = pickle.load(f)
    else:
        raise ValueError("Undefined dataset")
    X, _, Y = get_data(img_dataset, train)
    # Returns a dict: label (1/0) -> equation length -> list of equations
    equations_by_len = divide_equations_by_len(X, Y)
    return equations_by_len


if __name__ == "__main__":
    get_hed()
@@ -0,0 +1,84 @@
:- ensure_loaded(['BK.pl']).
:- thread_setconcurrency(_, 8).
:- use_module(library(thread)).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% For propositionalisation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
eval_inst_feature(Ex, Feature):-
    eval_eq(Ex, Feature).

%% Evaluate an instance given a feature
eval_eq(Ex, Feature):-
    parse_eq(Ex, eq(X,Y,Z)),
    bitwise_calc(Feature,_,X,Y,Z), !.

%%%%%%%%%%%%%%
%% Abduction
%%%%%%%%%%%%%%
% Perform abduction on examples that have been interpreted as pseudo-labels
abduce(Exs, Delta_C) :-
    abduce(Exs, [], Delta_C).
abduce([], Delta_C, Delta_C).
abduce([E|Exs], Delta_C0, Delta_C1) :-
    calc(Delta_C0, Delta_C2, E),
    abduce(Exs, Delta_C2, Delta_C1).

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce pseudo-labels only
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_insts(Exs):-
    abduce(Exs, _), !.
    % (Experimental) Uncomment to use parallel abduction
    % abduce_consistent_exs_concurrent(Exs), !.

logic_forward(Exs, X) :- abduce_consistent_insts(Exs) -> X = true ; X = false.
logic_forward(Exs) :- abduce_consistent_insts(Exs).
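% Example (illustrative; Exs is a list of token lists):
%   ?- logic_forward([[1,+,1,=,1,0]], X).
%   X = true.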

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Abduce Delta_C given pseudo-labels
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
consistent_inst_feature(Exs, Delta_C):-
    abduce(Exs, Delta_C), !.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% (Experimental) Parallel abduction
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
abduce_consistent_exs_concurrent(Exs) :-
    % Split the current data batch into ground examples and variable
    % examples (which need to be revised)
    split_exs(Exs, Ground_Exs, Var_Exs),
    % Find the simplest Delta_C for the ground examples
    abduce(Ground_Exs, Ground_Delta_C), !,
    % Extend the ground Delta_C into all possible variations
    extend_op_rule(Ground_Delta_C, Possible_Deltas),
    % Concurrently abduce the variable examples
    maplist(append([abduce2, Var_Exs, Ground_Exs]), [[Possible_Deltas]], Call_List),
    maplist(=.., Goals, Call_List),
    % writeln(Goals),
    first_solution(Var_Exs, Goals, [local(inf)]).

split_exs([],[],[]).
split_exs([E | Exs], [E | G_Exs], V_Exs):-
    ground(E), !,
    split_exs(Exs, G_Exs, V_Exs).
split_exs([E | Exs], G_Exs, [E | V_Exs]):-
    split_exs(Exs, G_Exs, V_Exs).

% Extend a partial rule set into a complete set of four operation rules
:- table extend_op_rule/2.
extend_op_rule(Rules, Rules) :-
    length(Rules, 4).
extend_op_rule(Rules, Ext) :-
    op_rule(R),
    valid_rules(Rules, R),
    extend_op_rule([R|Rules], Ext).

% Abduction without learning a new Delta_C (because it has already been extended)
abduce2([], _, _).
abduce2([E|Exs], Ground_Exs, Delta_C) :-
    % abduce by matching a ground example
    member(E, Ground_Exs),
    abduce2(Exs, Ground_Exs, Delta_C).
abduce2([E|Exs], Ground_Exs, Delta_C) :-
    eval_inst_feature(E, Delta_C),
    abduce2(Exs, Ground_Exs, Delta_C).
@@ -0,0 +1,199 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "from abl.abducer.abducer_base import HED_Abducer\n",
    "from abl.abducer.kb import HED_prolog_KB\n",
    "\n",
    "from abl.utils.plog import logger\n",
    "from abl.models.basic_model import BasicModel\n",
    "from abl.models.wabl_models import WABLBasicModel\n",
    "\n",
    "from models.nn import SymbolNet\n",
    "from datasets.hed.get_hed import get_hed, split_equation\n",
    "from abl import framework_hed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize logger\n",
    "recorder = logger()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logic Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize knowledge base and abducer\n",
    "kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/hed/learn_add.pl')\n",
    "abducer = HED_Abducer(kb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Machine Learning Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the necessary components for the machine learning part\n",
    "cls = SymbolNet(\n",
    "    num_classes=len(kb.pseudo_label_list),\n",
    "    image_size=(28, 28, 1),\n",
    ")\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrain the NN classifier\n",
    "framework_hed.hed_pretrain(kb, cls, recorder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize BasicModel\n",
    "# BasicModel wraps an NN model into the form of an sklearn estimator\n",
    "base_model = BasicModel(\n",
    "    cls,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    save_interval=1,\n",
    "    save_dir=recorder.save_dir,\n",
    "    batch_size=32,\n",
    "    num_epochs=1,\n",
    "    recorder=recorder,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the WABL model to join the two parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_train_data = get_hed(train=True)\n",
    "train_data, val_data = split_equation(total_train_data, 3, 1)\n",
    "test_data = get_hed(train=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, mapping = framework_hed.train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8)\n",
    "framework_hed.hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8)\n",
    "\n",
    "recorder.dump()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ABL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
| @@ -0,0 +1,69 @@ | |||
| # coding: utf-8 | |||
| # ================================================================# | |||
| # Copyright (C) 2021 Freecss All rights reserved. | |||
| # | |||
| # File Name :share_example.py | |||
| # Author :freecss | |||
| # Email :karlfreecss@gmail.com | |||
| # Created Date :2021/06/07 | |||
| # Description : | |||
| # | |||
| # ================================================================# | |||
| import sys | |||
| sys.path.append("../") | |||
| from abl.utils.plog import logger, INFO | |||
| from abl.utils.utils import reduce_dimension | |||
| import torch.nn as nn | |||
| import torch | |||
| from abl.models.nn import LeNet5, SymbolNet | |||
| from abl.models.basic_model import BasicModel, BasicDataset | |||
| from abl.models.wabl_models import DecisionTree, WABLBasicModel | |||
| from sklearn.neighbors import KNeighborsClassifier | |||
| from abl.abducer.abducer_base import AbducerBase | |||
| from abl.abducer.kb import add_KB, HWF_KB, prolog_KB | |||
| from datasets.mnist_add.get_mnist_add import get_mnist_add | |||
| from datasets.hwf.get_hwf import get_hwf | |||
| from datasets.hed.get_hed import get_hed, split_equation | |||
| from abl import framework_hed_knn | |||
| def run_test(): | |||
| # kb = add_KB(True) | |||
| # kb = HWF_KB(True) | |||
| # abducer = AbducerBase(kb) | |||
| kb = prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='../examples/datasets/hed/learn_add.pl') | |||
| abducer = AbducerBase(kb, zoopt=True, multiple_predictions=True) | |||
| recorder = logger() | |||
| total_train_data = get_hed(train=True) | |||
| train_data, val_data = split_equation(total_train_data, 3, 1) | |||
| test_data = get_hed(train=False) | |||
| # ========================= KNN model ============================ # | |||
| reduce_dimension(train_data) | |||
| reduce_dimension(val_data) | |||
| reduce_dimension(test_data) | |||
| base_model = KNeighborsClassifier(n_neighbors=3) | |||
| pretrain_data_X, pretrain_data_Y = framework_hed_knn.hed_pretrain(base_model) | |||
| model = WABLBasicModel(base_model, kb.pseudo_label_list) | |||
| model, mapping = framework_hed_knn.train_with_rule( | |||
| model, abducer, train_data, val_data, (pretrain_data_X, pretrain_data_Y), select_num=10, min_len=5, max_len=8 | |||
| ) | |||
| framework_hed_knn.hed_test( | |||
| model, abducer, mapping, train_data, test_data, min_len=5, max_len=8 | |||
| ) | |||
| # ============================ End =============================== # | |||
| recorder.dump() | |||
| return True | |||
| if __name__ == "__main__": | |||
| run_test() | |||
@@ -0,0 +1,4 @@
Download the Handwritten Formula Recognition dataset from [Google Drive](https://drive.google.com/file/d/1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy/view?usp=sharing) to this folder and unzip it:
```
unzip HWF.zip
```
@@ -0,0 +1,50 @@
import json
from PIL import Image
from torchvision.transforms import transforms

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1,))
])
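# Note: Normalize((0.5,), (1,)) shifts pixel values from [0, 1] to [-0.5, 0.5].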

def get_data(file, get_pseudo_label):
    X = []
    if get_pseudo_label:
        Z = []
    Y = []
    img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/'
    with open(file) as f:
        data = json.load(f)
    for idx in range(len(data)):
        imgs = []
        imgs_pseudo_label = []
        for img_path in data[idx]['img_paths']:
            img = Image.open(img_dir + img_path).convert('L')
            img = img_transform(img)
            imgs.append(img)
            if get_pseudo_label:
                # The directory name of each image is its symbol label
                imgs_pseudo_label.append(img_path.split('/')[0])
        X.append(imgs)
        if get_pseudo_label:
            Z.append(imgs_pseudo_label)
        Y.append(data[idx]['res'])
    if get_pseudo_label:
        return X, Z, Y
    else:
        return X, None, Y


def get_hwf(train=True, get_pseudo_label=False):
    if train:
        file = './datasets/hwf/data/expr_train.json'
    else:
        file = './datasets/hwf/data/expr_test.json'
    return get_data(file, get_pseudo_label)


if __name__ == "__main__":
    # get_data returns (X, Z, Y); Z is None when pseudo-labels are not requested
    train_X, _, train_Y = get_hwf(train=True)
    test_X, _, test_Y = get_hwf(train=False)
    print(len(train_X), len(test_X))
    print(len(train_X[0]), train_X[0][0].shape, train_Y[0])
@@ -0,0 +1,184 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "from abl.abducer.abducer_base import AbducerBase\n",
    "from abl.abducer.kb import HWF_KB\n",
    "\n",
    "from abl.utils.plog import logger\n",
    "from abl.models.basic_model import BasicModel\n",
    "from abl.models.wabl_models import WABLBasicModel\n",
    "\n",
    "from models.nn import SymbolNet\n",
    "from datasets.hwf.get_hwf import get_hwf\n",
    "from abl import framework_hed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize logger\n",
    "recorder = logger()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logic Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize knowledge base and abducer\n",
    "kb = HWF_KB(GKB_flag=True)\n",
    "abducer = AbducerBase(kb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Machine Learning Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the necessary components for the machine learning part\n",
    "cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize BasicModel\n",
    "# BasicModel wraps an NN model into the form of an sklearn estimator\n",
    "base_model = BasicModel(\n",
    "    cls,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    save_interval=1,\n",
    "    save_dir=recorder.save_dir,\n",
    "    batch_size=32,\n",
    "    num_epochs=1,\n",
    "    recorder=recorder,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the WABL model to join the two parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the WABL model\n",
    "# The main function of the WABL model is to serialize data and\n",
    "# provide a unified interface for different machine learning models\n",
    "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get training and testing data\n",
    "train_data = get_hwf(train=True, get_pseudo_label=True)\n",
    "test_data = get_hwf(train=False, get_pseudo_label=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train model\n",
    "framework_hed.train(\n",
    "    model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1\n",
    ")\n",
    "\n",
    "# Save results\n",
    "recorder.dump()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ABL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -0,0 +1,2 @@
pseudo_label(N) :- between(0, 9, N).
logic_forward([Z1, Z2], Res) :- pseudo_label(Z1), pseudo_label(Z2), Res is Z1+Z2.
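% Example: ?- logic_forward([3, 5], Res). gives Res = 8.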
| @@ -0,0 +1,41 @@ | |||
| import torch | |||
| import torchvision | |||
| from torch.utils.data import Dataset | |||
| from torchvision.transforms import transforms | |||
| def get_data(file, img_dataset, get_pseudo_label): | |||
| X = [] | |||
| if get_pseudo_label: | |||
| Z = [] | |||
| Y = [] | |||
| with open(file) as f: | |||
| for line in f: | |||
| line = line.strip().split(' ') | |||
| X.append([img_dataset[int(line[0])][0], img_dataset[int(line[1])][0]]) | |||
| if get_pseudo_label: | |||
| Z.append([img_dataset[int(line[0])][1], img_dataset[int(line[1])][1]]) | |||
| Y.append(int(line[2])) | |||
| if get_pseudo_label: | |||
| return X, Z, Y | |||
| else: | |||
| return X, None, Y | |||
| def get_mnist_add(train = True, get_pseudo_label = False): | |||
| transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, ))]) | |||
| img_dataset = torchvision.datasets.MNIST(root='./datasets/mnist_add/', train=train, download=True, transform=transform) | |||
| if train: | |||
| file = './datasets/mnist_add/train_data.txt' | |||
| else: | |||
| file = './datasets/mnist_add/test_data.txt' | |||
| return get_data(file, img_dataset, get_pseudo_label) | |||
| if __name__ == "__main__": | |||
| train_X, train_Y = get_mnist_add(train = True) | |||
| test_X, test_Y = get_mnist_add(train = False) | |||
| print(len(train_X), len(test_X)) | |||
| print(train_X[0][0].shape, train_X[0][1].shape, train_Y[0]) | |||
@@ -0,0 +1,190 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "from abl.abducer.abducer_base import AbducerBase\n",
    "from abl.abducer.kb import add_KB\n",
    "\n",
    "from abl.utils.plog import logger\n",
    "from abl.models.basic_model import BasicModel\n",
    "from abl.models.wabl_models import WABLBasicModel\n",
    "\n",
    "from models.nn import LeNet5\n",
    "from datasets.mnist_add.get_mnist_add import get_mnist_add\n",
    "from abl import framework_hed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize logger\n",
    "recorder = logger()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Logic Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize knowledge base and abducer\n",
    "kb = add_KB(GKB_flag=True)\n",
    "abducer = AbducerBase(kb, dist_func=\"confidence\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Machine Learning Part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the necessary components for the machine learning part\n",
    "cls = LeNet5(num_classes=len(kb.pseudo_label_list))\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize BasicModel\n",
    "# BasicModel wraps an NN model into the form of an sklearn estimator\n",
    "base_model = BasicModel(\n",
    "    cls,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    save_interval=1,\n",
    "    save_dir=recorder.save_dir,\n",
    "    batch_size=32,\n",
    "    num_epochs=1,\n",
    "    recorder=recorder,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use the WABL model to join the two parts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the WABL model\n",
    "# The main function of the WABL model is to serialize data and\n",
    "# provide a unified interface for different machine learning models\n",
    "model = WABLBasicModel(base_model, kb.pseudo_label_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get training and testing data\n",
    "train_X, train_Z, train_Y = get_mnist_add(train=True, get_pseudo_label=True)\n",
    "test_X, test_Z, test_Y = get_mnist_add(train=False, get_pseudo_label=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train model\n",
    "framework_hed.train(\n",
    "    model,\n",
    "    abducer,\n",
    "    (train_X, train_Z, train_Y),\n",
    "    (test_X, test_Z, test_Y),\n",
    "    loop_num=15,\n",
    "    sample_num=5000,\n",
    "    verbose=1,\n",
    ")\n",
    "\n",
    "# Save results\n",
    "recorder.dump()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ABL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}