From cfc00d10b3f12e4dc7006aa859086767ee10c40b Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Wed, 6 Dec 2023 23:07:50 +0800 Subject: [PATCH] [MNT] partial hed --- abl/bridge/simple_bridge.py | 6 +- abl/learning/abl_model.py | 19 ++ abl/reasoning/reasoner.py | 130 ++++++------ abl/structures/list_data.py | 16 +- abl/utils/utils.py | 6 +- docs/Examples/MNISTAdd.rst | 6 +- examples/hed/datasets/get_hed.py | 55 ++--- examples/hed/hed_bridge.py | 224 ++++++++++++--------- examples/hed/hed_example.ipynb | 2 +- examples/hed/hed_tmp.py | 159 +++++++++++++++ examples/hed/utils.py | 3 +- examples/mnist_add/mnist_add_example.ipynb | 4 +- tests/test_reasoning.py | 5 +- 13 files changed, 406 insertions(+), 229 deletions(-) create mode 100644 examples/hed/hed_tmp.py diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index 49daeab..2706394 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -57,6 +57,7 @@ class SimpleBridge(BaseBridge): def train( self, train_data: Union[ListData, DataSet], + val_data: Optional[Union[ListData, DataSet]] = None, loops: int = 50, segment_size: Union[int, float] = -1, eval_interval: int = 1, @@ -86,7 +87,10 @@ class SimpleBridge(BaseBridge): if (loop + 1) % eval_interval == 0 or loop == loops - 1: print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") - self.valid(train_data) + if val_data is not None: + self.valid(val_data) + else: + self.valid(train_data) if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index b433c64..d517c93 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -79,6 +79,25 @@ class ABLModel: data_X = data_samples.flatten("X") data_y = data_samples.flatten("abduced_idx") return self.base_model.fit(X=data_X, y=data_y) + + def valid(self, data_samples: ListData) -> float: + """ + Validate the model on the given data. + + Parameters + ---------- + data_samples : ListData + A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`. + + Returns + ------- + float + The accuracy the trained model. + """ + data_X = data_samples.flatten("X") + data_y = data_samples.flatten("abduced_idx") + score = self.base_model.score(X=data_X, y=data_y) + return score def _model_operation(self, operation: str, *args, **kwargs): model = self.base_model diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index fd54286..ff3ecc3 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -17,19 +17,19 @@ class ReasonerBase: kb : class KBBase The knowledge base to be used for reasoning. dist_func : str, optional - The distance function to be used when determining the cost list between each - candidate and the given prediction. Valid options include: "confidence" (default) | - "hamming". For "confidence", it calculates the distance between the prediction - and candidate based on confidence derived from the predicted probability in the - data sample.For "hamming", it directly calculates the Hamming distance between + The distance function to be used when determining the cost list between each + candidate and the given prediction. Valid options include: "confidence" (default) | + "hamming". For "confidence", it calculates the distance between the prediction + and candidate based on confidence derived from the predicted probability in the + data sample.For "hamming", it directly calculates the Hamming distance between the predicted pseudo label in the data sample and candidate. mapping : dict, optional - A mapping from index in the base model to label. If not provided, a default + A mapping from index in the base model to label. If not provided, a default order-based mapping is created. max_revision : int or float, optional - The upper limit on the number of revisions for each data sample when - performing abductive reasoning. If float, denotes the fraction of the total - length that can be revised. A value of -1 implies no restriction on the + The upper limit on the number of revisions for each data sample when + performing abductive reasoning. If float, denotes the fraction of the total + length that can be revised. A value of -1 implies no restriction on the number of revisions. Defaults to -1. require_more_revision : int, optional Specifies additional number of revisions permitted beyond the minimum required @@ -37,52 +37,53 @@ class ReasonerBase: use_zoopt : bool, optional Whether to use the Zoopt library during abductive reasoning. Defaults to False. """ - - def __init__(self, - kb, - dist_func="confidence", - mapping=None, - max_revision=-1, - require_more_revision=0, - use_zoopt=False, - ): + + def __init__( + self, + kb, + dist_func="confidence", + mapping=None, + max_revision=-1, + require_more_revision=0, + use_zoopt=False, + ): if dist_func not in ["hamming", "confidence"]: - raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"") + raise NotImplementedError( + 'Valid options for dist_func include "hamming" and "confidence"' + ) self.kb = kb self.dist_func = dist_func self.use_zoopt = use_zoopt self.max_revision = max_revision self.require_more_revision = require_more_revision - + if mapping is None: - self.mapping = { - index: label for index, label in enumerate(self.kb.pseudo_label_list) - } + self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} else: if not isinstance(mapping, dict): raise TypeError("mapping should be dict") - for key, value in mapping.items(): - if not isinstance(key, int): - raise ValueError("All keys in the mapping must be integers") - if value not in self.kb.pseudo_label_list: - raise ValueError("All values in the mapping must be in the pseudo_label_list") + for key, value in mapping.items(): + if not isinstance(key, int): + raise ValueError("All keys in the mapping must be integers") + if value not in self.kb.pseudo_label_list: + raise ValueError("All values in the mapping must be in the pseudo_label_list") self.mapping = mapping self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) def _get_one_candidate(self, data_sample, candidates): """ - Due to the nondeterminism of abductive reasoning, there could be multiple candidates - satisfying the knowledge base. When this happens, return one candidate that has the + Due to the nondeterminism of abductive reasoning, there could be multiple candidates + satisfying the knowledge base. When this happens, return one candidate that has the minimum cost. If no candidates are provided, an empty list is returned. - + Parameters ---------- data_sample : ListData Data sample. candidates : List[List[Any]] - Multiple compatible candidates. - + Multiple compatible candidates. + Returns ------- List[Any] @@ -96,17 +97,17 @@ class ReasonerBase: cost_array = self._get_cost_list(data_sample, candidates) candidate = candidates[np.argmin(cost_array)] return candidate - + def _get_cost_list(self, data_sample, candidates): """ - Get the list of costs between each candidate and the given data sample. The list is + Get the list of costs between each candidate and the given data sample. The list is calculated based on one of the following distance functions: - - "hamming": Directly calculates the Hamming distance between the predicted pseudo + - "hamming": Directly calculates the Hamming distance between the predicted pseudo label in the data sample and candidate. - - "confidence": Calculates the distance between the prediction and candidate based - on confidence derived from the predicted probability in the data + - "confidence": Calculates the distance between the prediction and candidate based + on confidence derived from the predicted probability in the data sample. - + Parameters ---------- data_sample : ListData @@ -121,12 +122,9 @@ class ReasonerBase: candidates = [[self.remapping[x] for x in c] for c in candidates] return confidence_dist(data_sample.pred_prob, candidates) - - def zoopt_get_solution( - self, symbol_num, data_sample, max_revision_num - ): + def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num): """ - Get the optimal solution using the Zoopt library. The solution is a list of + Get the optimal solution using the Zoopt library. The solution is a list of boolean values, where '1' (True) indicates the indices chosen to be revised. Parameters @@ -138,9 +136,7 @@ class ReasonerBase: max_revision_num : int Specifies the maximum number of revisions allowed. """ - dimension = Dimension( - size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num - ) + dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) objective = Objective( lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol), dim=dimension, @@ -149,16 +145,16 @@ class ReasonerBase: parameter = Parameter(budget=100, intermediate_result=False, autoset=True) solution = Opt.min(objective, parameter).get_x() return solution - + def zoopt_revision_score(self, symbol_num, data_sample, sol): """ - Get the revision score for a solution. A lower score suggests that the Zoopt library + Get the revision score for a solution. A lower score suggests that the Zoopt library has a higher preference for this solution. """ revision_idx = np.where(sol.get_x() != 0)[0] - candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, - data_sample.Y, - revision_idx) + candidates = self.kb.revise_at_idx( + data_sample.pred_pseudo_label, data_sample.Y, revision_idx + ) if len(candidates) > 0: return np.min(self._get_cost_list(data_sample, candidates)) else: @@ -166,7 +162,7 @@ class ReasonerBase: def _constrain_revision_num(self, solution, max_revision_num): """ - Constrain that the total number of revisions chosen by the solution does not exceed + Constrain that the total number of revisions chosen by the solution does not exceed maximum number of revisions allowed. """ x = solution.get_x() @@ -189,7 +185,7 @@ class ReasonerBase: if max_revision < 0: raise ValueError("If max_revision is an int, it must be non-negative.") return max_revision - + def abduce(self, data_sample): """ Perform abductive reasoning on the given data sample. @@ -198,7 +194,7 @@ class ReasonerBase: ---------- data_sample : ListData Data sample. - + Returns ------- List[Any] @@ -207,19 +203,21 @@ class ReasonerBase: """ symbol_num = data_sample.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) - + if self.use_zoopt: solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) revision_idx = np.where(solution != 0)[0] - candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, - data_sample.Y, - revision_idx) + candidates = self.kb.revise_at_idx( + data_sample.pred_pseudo_label, data_sample.Y, revision_idx + ) else: - candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label, - data_sample.Y, - max_revision_num, - self.require_more_revision) - + candidates = self.kb.abduce_candidates( + data_sample.pred_pseudo_label, + data_sample.Y, + max_revision_num, + self.require_more_revision, + ) + candidate = self._get_one_candidate(data_sample, candidates) return candidate @@ -228,9 +226,7 @@ class ReasonerBase: Perform abductive reasoning on the given prediction data samples. For detailed information, refer to `abduce`. """ - abduced_pseudo_label = [ - self.abduce(data_sample) for data_sample in data_samples - ] + abduced_pseudo_label = [self.abduce(data_sample) for data_sample in data_samples] data_samples.abduced_pseudo_label = abduced_pseudo_label return abduced_pseudo_label diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py index a53ffc5..d6dad04 100644 --- a/abl/structures/list_data.py +++ b/abl/structures/list_data.py @@ -297,9 +297,15 @@ class ListData(BaseDataElement): def __len__(self) -> int: """int: The length of ListData.""" - if len(self._data_fields) > 0: - one_element = next(iter(self._data_fields)) - return len(getattr(self, one_element)) - # return len(self.values()[0]) + iterator = iter(self._data_fields) + data = next(iterator) + + while getattr(self, data) is None: + try: + data = next(iterator) + except StopIteration: + break + if getattr(self, data) is None: + raise ValueError("All data fields are None.") else: - return 0 + return len(getattr(self, data)) diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 6e3bb4f..3966506 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -22,10 +22,10 @@ def flatten(nested_list): TypeError If the input object is not a list. """ - if not isinstance(nested_list, list): - raise TypeError("Input must be of type list.") + # if not isinstance(nested_list, list): + # raise TypeError("Input must be of type list.") - if not nested_list or not isinstance(nested_list[0], (list, tuple)): + if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)): return nested_list return list(chain.from_iterable(nested_list)) diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index fc52de5..afdff2d 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -3,7 +3,7 @@ MNIST Add MNIST Add was first introduced in [1] and the inputs of this task are pairs of MNIST images and the outputs are their sums. The dataset looks like this: -.. image:: ../img/image_1.jpg +.. image:: ../img/Datasets_1.png :width: 350px :align: center @@ -11,9 +11,5 @@ MNIST Add was first introduced in [1] and the inputs of this task are pairs of M The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase. -In the Abductive Learning framework, the inference process is as follows: - -.. image:: ../img/image_2.jpg - :width: 700px [1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. Deepproblog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759.2018. \ No newline at end of file diff --git a/examples/hed/datasets/get_hed.py b/examples/hed/datasets/get_hed.py index 5bac060..3a0b34b 100644 --- a/examples/hed/datasets/get_hed.py +++ b/examples/hed/datasets/get_hed.py @@ -1,19 +1,18 @@ import os +import os.path as osp import cv2 -import torch -import torchvision import pickle import numpy as np import random + from collections import defaultdict -from torch.utils.data import Dataset from torchvision.transforms import transforms +CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + def get_data(img_dataset, train): - transform = transforms.Compose([transforms.ToTensor()]) - X = [] - Y = [] + X, Y = [], [] if train: positive = img_dataset["train:positive"] negative = img_dataset["train:negative"] @@ -39,15 +38,12 @@ def get_data(img_dataset, train): def get_pretrain_data(labels, image_size=(28, 28, 1)): transform = transforms.Compose([transforms.ToTensor()]) X = [] + img_dir = osp.join(CURRENT_DIR, "mnist_images") for label in labels: - label_path = os.path.join( - "./datasets/mnist_images", label - ) + label_path = osp.join(img_dir, label) img_path_list = os.listdir(label_path) for img_path in img_path_list: - img = cv2.imread( - os.path.join(label_path, img_path), cv2.IMREAD_GRAYSCALE - ) + img = cv2.imread(osp.join(label_path, img_path), cv2.IMREAD_GRAYSCALE) img = cv2.resize(img, (image_size[1], image_size[0])) X.append(np.array(img, dtype=np.float32)) @@ -58,24 +54,6 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)): return X, Y -# def get_pretrain_data(train_data, image_size=(28, 28, 1)): -# X = [] -# for label in [0, 1]: -# for _, equation_list in train_data[label].items(): -# for equation in equation_list: -# X = X + equation - -# X = np.array(X) -# index = np.array(list(range(len(X)))) -# np.random.shuffle(index) -# X = X[index] - -# X = [img for img in X] -# Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] - -# return X, Y - - def divide_equations_by_len(equations, labels): equations_by_len = {1: defaultdict(list), 0: defaultdict(list)} for i, equation in enumerate(equations): @@ -104,22 +82,15 @@ def split_equation(equations_by_len, prop_train, prop_val): def get_hed(dataset="mnist", train=True): - if dataset == "mnist": - with open( - "./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk", - "rb", - ) as f: - img_dataset = pickle.load(f) + file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") elif dataset == "random": - with open( - "./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk", - "rb", - ) as f: - img_dataset = pickle.load(f) + file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk") else: - raise Exception("Undefined dataset") + raise ValueError("Undefined dataset") + with open(file, "rb") as f: + img_dataset = pickle.load(f) X, _, Y = get_data(img_dataset, train) equations_by_len = divide_equations_by_len(X, Y) diff --git a/examples/hed/hed_bridge.py b/examples/hed/hed_bridge.py index b0f401f..8ec57d3 100644 --- a/examples/hed/hed_bridge.py +++ b/examples/hed/hed_bridge.py @@ -1,5 +1,6 @@ import os from collections import defaultdict +from typing import Any, List import torch from torch.utils.data import DataLoader @@ -8,6 +9,7 @@ from abl.learning import ABLModel, BasicNN from abl.bridge import SimpleBridge from abl.evaluation import BaseMetric from abl.dataset import BridgeDataset, RegressionDataset +from abl.structures import ListData from abl.utils import print_log from examples.hed.utils import gen_mappings, InfiniteSampler @@ -59,54 +61,35 @@ class HEDBridge(SimpleBridge): "model": cls_autoencoder.base_model.state_dict(), } - torch.save( - save_parma_dic, os.path.join(weights_dir, "pretrain_weights.pth") - ) + torch.save(save_parma_dic, os.path.join(weights_dir, "pretrain_weights.pth")) self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth")) - def abduce_pseudo_label( - self, - pred_label, - pred_prob, - pseudo_label, - Y, - max_revision=-1, - require_more_revision=0, - ): - return self.reasoner.abduce( - (pred_label, pred_prob, pseudo_label, Y), - max_revision, - require_more_revision, - ) - - def select_mapping_and_abduce(self, pred_label, pred_prob, Y): + def abduce_pseudo_label(self, data_samples: ListData): candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) mapping_score = [] - pred_pseudo_label_list = [] abduced_pseudo_label_list = [] for _mapping in candidate_mappings: self.reasoner.mapping = _mapping - self.reasoner.set_remapping() - pred_pseudo_label = self.label_to_pseudo_label(pred_label) - abduced_pseudo_label = self.abduce_pseudo_label( - pred_label, pred_prob, pred_pseudo_label, Y, 20 - ) - mapping_score.append( - len(abduced_pseudo_label) - abduced_pseudo_label.count([]) - ) - pred_pseudo_label_list.append(pred_pseudo_label) + self.reasoner.remapping = dict(zip(_mapping.values(), _mapping.keys())) + self.idx_to_pseudo_label(data_samples) + abduced_pseudo_label = self.reasoner.abduce(data_samples) + mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([])) abduced_pseudo_label_list.append(abduced_pseudo_label) max_revisible_instances = max(mapping_score) return_idx = mapping_score.index(max_revisible_instances) self.reasoner.mapping = candidate_mappings[return_idx] - self.reasoner.set_remapping() - return abduced_pseudo_label_list[return_idx] + self.reasoner.mapping = dict( + zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys()) + ) + data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] + + return data_samples.abduced_pseudo_label - def check_training_impact(self, filtered_X, filtered_abduced_label, X): - character_accuracy = self.model.valid(filtered_X, filtered_abduced_label) - revisible_ratio = len(filtered_X) / len(X) + def check_training_impact(self, filtered_data_samples, data_samples): + character_accuracy = self.model.valid(filtered_data_samples) + revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) print_log( f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}", logger="current", @@ -136,10 +119,7 @@ class HEDBridge(SimpleBridge): pred_label, _ = self.predict(X) pred_pseudo_label = self.label_to_pseudo_label(pred_label) consistent_num = sum( - [ - self.reasoner.kb.consist_rule(instance, rule) - for instance in pred_pseudo_label - ] + [self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label] ) return consistent_num / len(X) @@ -178,12 +158,13 @@ class HEDBridge(SimpleBridge): return rules @staticmethod - def filter_empty(X, Z): - filtered_X, filtered_Z = [], [] - for x, z in zip(X, Z): - if len(z) > 0: - filtered_X.append(x), filtered_Z.append(z) - return (filtered_X, filtered_Z) + def filter_empty(data_samples: ListData): + consistent_dix = [ + i + for i in range(len(data_samples.abduced_pseudo_label)) + if len(data_samples.abduced_pseudo_label[i]) > 0 + ] + return data_samples[consistent_dix] @staticmethod def select_rules(rule_dict): @@ -203,79 +184,49 @@ class HEDBridge(SimpleBridge): add_nums_dict[add_nums] = r return list(rule_dict) - def train( - self, - train_data, - val_data, - select_num=10, - min_len=5, - max_len=8, - ): + def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: + return super().idx_to_pseudo_label(data_samples) + + def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8): for equation_len in range(min_len, max_len): print_log( f"============== equation_len: {equation_len}-{equation_len + 1} ================", logger="current", ) - train_X = train_data[1][equation_len] + train_data[1][equation_len + 1] - train_Y = [None] * len(train_X) - dataset = BridgeDataset(train_X, None, train_Y) - sampler = InfiniteSampler(len(dataset)) - data_loader = DataLoader( - dataset, - sampler=sampler, - batch_size=select_num, - collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], - ) - - condition_num = 0 - for seg_idx, (X, Z, Y) in enumerate(data_loader): - pred_label, pred_prob = self.predict(X) - if equation_len == min_len: - abduced_pseudo_label = self.select_mapping_and_abduce( - pred_label, pred_prob, Y - ) - else: - pred_pseudo_label = self.label_to_pseudo_label(pred_label) - abduced_pseudo_label = self.abduce_pseudo_label( - pred_label, pred_prob, pred_pseudo_label, Y, 20 - ) - filtered_X, filtered_abduced_pseudo_label = self.filter_empty( - X, abduced_pseudo_label - ) - if len(filtered_X) == 0: - continue - filtered_abduced_label = self.pseudo_label_to_label( - filtered_abduced_pseudo_label - ) - min_loss = self.model.train(filtered_X, filtered_abduced_label) + X = train_data[1][equation_len] + train_data[1][equation_len + 1] + Y = [None] * len(X) + data_samples = self.data_preprocess(X, None, Y) + sampler = InfiniteSampler(len(data_samples)) + for seg_idx, select_idx in enumerate(sampler): + sub_data_samples = data_samples[select_idx] + self.predict(sub_data_samples) + # self.idx_to_pseudo_label(sub_data_samples) + self.abduce_pseudo_label(sub_data_samples) + filtered_sub_data_samples = self.filter_empty(sub_data_samples) + self.pseudo_label_to_idx(filtered_sub_data_samples) + loss = self.model.train(filtered_sub_data_samples) print_log( - f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] minimal_loss is {min_loss:.5f}", + f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}", logger="current", ) - if self.check_training_impact(filtered_X, filtered_abduced_label, X): + if self.check_training_impact(filtered_sub_data_samples, sub_data_samples): condition_num += 1 else: condition_num = 0 if condition_num >= 5: - print_log( - f"Now checking if we can go to next course", logger="current" - ) + print_log(f"Now checking if we can go to next course", logger="current") rules = self.get_rules_from_data( - dataset, samples_per_rule=3, samples_num=50 - ) - print_log( - f"Learned rules from data: " + str(rules), logger="current" + data_samples, samples_per_rule=3, samples_num=50 ) + print_log(f"Learned rules from data: " + str(rules), logger="current") seems_good = self.check_rule_quality(rules, val_data, equation_len) if seems_good: - self.model.save( - save_path=f"./weights/eq_len_{equation_len}.pth" - ) + self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") break else: if equation_len == min_len: @@ -285,8 +236,83 @@ class HEDBridge(SimpleBridge): ) self.model.load(load_path="./weights/pretrain_weights.pth") else: - self.model.load( - load_path=f"./weights/eq_len_{equation_len - 1}.pth" - ) + self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") condition_num = 0 print_log("Reload Model and retrain", logger="current") + + # def train( + # self, + # train_data, + # val_data, + # segment_size=10, + # min_len=5, + # max_len=8, + # ): + # for equation_len in range(min_len, max_len): + # print_log( + # f"============== equation_len: {equation_len}-{equation_len + 1} ================", + # logger="current", + # ) + + # train_X = train_data[1][equation_len] + train_data[1][equation_len + 1] + # train_Y = [None] * len(train_X) + # # data_samples = self.data_preprocess(train_X, None, train_Y) + + # dataset = BridgeDataset(train_X, None, train_Y) + # sampler = InfiniteSampler(len(dataset)) + # data_loader = DataLoader( + # dataset, + # sampler=sampler, + # batch_size=segment_size, + # collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], + # ) + + # condition_num = 0 + + # for seg_idx, (X, Z, Y) in enumerate(data_loader): + # pred_label, pred_prob = self.predict(ListData(X=X)) + # if equation_len == min_len: + # abduced_pseudo_label = self.select_mapping_and_abduce(pred_label, pred_prob, Y) + # else: + # pred_pseudo_label = self.label_to_pseudo_label(pred_label) + # abduced_pseudo_label = self.abduce_pseudo_label( + # pred_label, pred_prob, pred_pseudo_label, Y, 20 + # ) + # filtered_X, filtered_abduced_pseudo_label = self.filter_empty( + # X, abduced_pseudo_label + # ) + # if len(filtered_X) == 0: + # continue + # filtered_abduced_label = self.pseudo_label_to_label(filtered_abduced_pseudo_label) + # min_loss = self.model.train(filtered_X, filtered_abduced_label) + + # print_log( + # f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] minimal_loss is {min_loss:.5f}", + # logger="current", + # ) + + # if self.check_training_impact(filtered_X, filtered_abduced_label, X): + # condition_num += 1 + # else: + # condition_num = 0 + + # if condition_num >= 5: + # print_log(f"Now checking if we can go to next course", logger="current") + # rules = self.get_rules_from_data(dataset, samples_per_rule=3, samples_num=50) + # print_log(f"Learned rules from data: " + str(rules), logger="current") + + # seems_good = self.check_rule_quality(rules, val_data, equation_len) + # if seems_good: + # self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") + # break + # else: + # if equation_len == min_len: + # print_log( + # "Learned mapping is: " + str(self.reasoner.mapping), + # logger="current", + # ) + # self.model.load(load_path="./weights/pretrain_weights.pth") + # else: + # self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") + # condition_num = 0 + # print_log("Reload Model and retrain", logger="current") diff --git a/examples/hed/hed_example.ipynb b/examples/hed/hed_example.ipynb index 52d9b91..14b2f93 100644 --- a/examples/hed/hed_example.ipynb +++ b/examples/hed/hed_example.ipynb @@ -280,7 +280,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/hed/hed_tmp.py b/examples/hed/hed_tmp.py new file mode 100644 index 0000000..6b315c9 --- /dev/null +++ b/examples/hed/hed_tmp.py @@ -0,0 +1,159 @@ +import os.path as osp + +import numpy as np +import torch +import torch.nn as nn + +from abl.evaluation import SemanticsMetric, SymbolMetric +from abl.learning import ABLModel, BasicNN +from abl.reasoning import PrologKB, ReasonerBase +from abl.utils import ABLLogger, print_log, reform_list +from examples.hed.datasets.get_hed import get_hed, split_equation +from examples.hed.hed_bridge import HEDBridge +from examples.models.nn import SymbolNet + +# Build logger +print_log("Abductive Learning on the HED example.", logger="current") + +# Retrieve the directory of the Log file and define the directory for saving the model weights. +log_dir = ABLLogger.get_current_instance().log_dir +weights_dir = osp.join(log_dir, "weights") + + +### Logic Part +# Initialize knowledge base and abducer +class HedKB(PrologKB): + def __init__(self, pseudo_label_list, pl_file): + super().__init__(pseudo_label_list, pl_file) + + def consist_rule(self, exs, rules): + rules = str(rules).replace("'", "") + return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 + + def abduce_rules(self, pred_res): + prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) + if len(prolog_result) == 0: + return None + prolog_rules = prolog_result[0]["X"] + rules = [rule.value for rule in prolog_rules] + return rules + + +class HedReasoner(ReasonerBase): + def revise_at_idx(self, data_sample): + revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0] + candidate = self.kb.revise_at_idx( + data_sample.pred_pseudo_label, data_sample.Y, revision_idx + ) + return candidate + + def zoopt_revision_score(self, symbol_num, data_sample, sol): + revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label) + data_sample.revision_flag = revision_flag + + lefted_idxs = [i for i in range(len(data_sample.pred_idx))] + candidate_size = [] + while lefted_idxs: + idxs = [] + idxs.append(lefted_idxs.pop(0)) + max_candidate_idxs = [] + found = False + for idx in range(-1, len(data_sample.pred_idx)): + if (not idx in idxs) and (idx >= 0): + idxs.append(idx) + candidate = self.revise_at_idx(data_sample[idxs]) + if len(candidate) == 0: + if len(idxs) > 1: + idxs.pop() + else: + if len(idxs) > len(max_candidate_idxs): + found = True + max_candidate_idxs = idxs.copy() + removed = [i for i in lefted_idxs if i in max_candidate_idxs] + if found: + candidate_size.append(len(removed) + 1) + lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] + candidate_size.sort() + score = 0 + import math + + for i in range(0, len(candidate_size)): + score -= math.exp(-i) * candidate_size[i] + return score + + def abduce(self, data_sample): + symbol_num = data_sample.elements_num("pred_pseudo_label") + max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) + + solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) + + data_sample.revision_flag = reform_list( + solution.astype(np.int32), data_sample.pred_pseudo_label + ) + + abduced_pseudo_label = [] + + for single_instance in data_sample: + single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label] + candidates = self.revise_at_idx(single_instance) + if len(candidates) == 0: + abduced_pseudo_label.append([]) + else: + abduced_pseudo_label.append(candidates[0][0]) + data_sample.abduced_pseudo_label = abduced_pseudo_label + return abduced_pseudo_label + + def abduce_rules(self, pred_res): + return self.kb.abduce_rules(pred_res) + + +import os + +CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + +kb = HedKB( + pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl") +) +reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20) + +### Machine Learning Part +# Build necessary components for BasicNN +cls = SymbolNet(num_classes=4) +criterion = nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(cls.parameters(), lr=0.001) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# Build BasicNN +# The function of BasicNN is to wrap NN models into the form of an sklearn estimator +base_model = BasicNN( + cls, + criterion, + optimizer, + device, + batch_size=32, + num_epochs=1, + save_interval=1, + save_dir=weights_dir, +) + +# Build ABLModel +# The main function of the ABL model is to serialize data and +# provide a unified interface for different machine learning models +model = ABLModel(base_model) + +### Metric +# Set up metrics +metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")] + +### Bridge Machine Learning and Logic Reasoning +bridge = HEDBridge(model, reasoner, metric_list) + +### Dataset +total_train_data = get_hed(train=True) +train_data, val_data = split_equation(total_train_data, 3, 1) +test_data = get_hed(train=False) + +### Train and Test +bridge.pretrain("examples/hed/weights") +bridge.train(train_data, val_data) diff --git a/examples/hed/utils.py b/examples/hed/utils.py index 42b7316..cc76e93 100644 --- a/examples/hed/utils.py +++ b/examples/hed/utils.py @@ -12,7 +12,8 @@ class InfiniteSampler(sampler.Sampler): while True: order = np.random.permutation(self.num_samples) for i in range(self.num_samples): - yield order[i] + yield order[i: i + 10] + i += 10 def __len__(self): return None diff --git a/examples/mnist_add/mnist_add_example.ipynb b/examples/mnist_add/mnist_add_example.ipynb index 409eb97..d689954 100644 --- a/examples/mnist_add/mnist_add_example.ipynb +++ b/examples/mnist_add/mnist_add_example.ipynb @@ -79,7 +79,7 @@ "metadata": {}, "outputs": [], "source": [ - "# # Build ABLModel\n", + "# Build ABLModel\n", "# The main function of the ABL model is to serialize data and \n", "# provide a unified interface for different machine learning models\n", "model = ABLModel(base_model)" @@ -193,7 +193,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.18" }, "orig_nbformat": 4, "vscode": { diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index d6b0a78..0eed8cb 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -58,9 +58,8 @@ class TestPrologKB(object): def test_logic_forward_pl2(self, kb_hed): consist_exs = [ - [1, 1, "+", 0, "=", 1, 1], - [1, "+", 1, "=", 1, 0], - [0, "+", 0, "=", 0], + [1, "+", 1, "=", 0], + [1, "+", 1, "=", 1], ] inconsist_exs = [ [1, 1, "+", 0, "=", 1, 1],