| @@ -57,6 +57,7 @@ class SimpleBridge(BaseBridge): | |||
| def train( | |||
| self, | |||
| train_data: Union[ListData, DataSet], | |||
| val_data: Optional[Union[ListData, DataSet]] = None, | |||
| loops: int = 50, | |||
| segment_size: Union[int, float] = -1, | |||
| eval_interval: int = 1, | |||
| @@ -86,7 +87,10 @@ class SimpleBridge(BaseBridge): | |||
| if (loop + 1) % eval_interval == 0 or loop == loops - 1: | |||
| print_log(f"Evaluation start: loop(val) [{loop + 1}]", logger="current") | |||
| self.valid(train_data) | |||
| if val_data is not None: | |||
| self.valid(val_data) | |||
| else: | |||
| self.valid(train_data) | |||
| if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): | |||
| print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") | |||
| @@ -79,6 +79,25 @@ class ABLModel: | |||
| data_X = data_samples.flatten("X") | |||
| data_y = data_samples.flatten("abduced_idx") | |||
| return self.base_model.fit(X=data_X, y=data_y) | |||
| def valid(self, data_samples: ListData) -> float: | |||
| """ | |||
| Validate the model on the given data. | |||
| Parameters | |||
| ---------- | |||
| data_samples : ListData | |||
| A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`. | |||
| Returns | |||
| ------- | |||
| float | |||
| The accuracy the trained model. | |||
| """ | |||
| data_X = data_samples.flatten("X") | |||
| data_y = data_samples.flatten("abduced_idx") | |||
| score = self.base_model.score(X=data_X, y=data_y) | |||
| return score | |||
| def _model_operation(self, operation: str, *args, **kwargs): | |||
| model = self.base_model | |||
| @@ -17,19 +17,19 @@ class ReasonerBase: | |||
| kb : class KBBase | |||
| The knowledge base to be used for reasoning. | |||
| dist_func : str, optional | |||
| The distance function to be used when determining the cost list between each | |||
| candidate and the given prediction. Valid options include: "confidence" (default) | | |||
| "hamming". For "confidence", it calculates the distance between the prediction | |||
| and candidate based on confidence derived from the predicted probability in the | |||
| data sample.For "hamming", it directly calculates the Hamming distance between | |||
| The distance function to be used when determining the cost list between each | |||
| candidate and the given prediction. Valid options include: "confidence" (default) | | |||
| "hamming". For "confidence", it calculates the distance between the prediction | |||
| and candidate based on confidence derived from the predicted probability in the | |||
| data sample.For "hamming", it directly calculates the Hamming distance between | |||
| the predicted pseudo label in the data sample and candidate. | |||
| mapping : dict, optional | |||
| A mapping from index in the base model to label. If not provided, a default | |||
| A mapping from index in the base model to label. If not provided, a default | |||
| order-based mapping is created. | |||
| max_revision : int or float, optional | |||
| The upper limit on the number of revisions for each data sample when | |||
| performing abductive reasoning. If float, denotes the fraction of the total | |||
| length that can be revised. A value of -1 implies no restriction on the | |||
| The upper limit on the number of revisions for each data sample when | |||
| performing abductive reasoning. If float, denotes the fraction of the total | |||
| length that can be revised. A value of -1 implies no restriction on the | |||
| number of revisions. Defaults to -1. | |||
| require_more_revision : int, optional | |||
| Specifies additional number of revisions permitted beyond the minimum required | |||
| @@ -37,52 +37,53 @@ class ReasonerBase: | |||
| use_zoopt : bool, optional | |||
| Whether to use the Zoopt library during abductive reasoning. Defaults to False. | |||
| """ | |||
| def __init__(self, | |||
| kb, | |||
| dist_func="confidence", | |||
| mapping=None, | |||
| max_revision=-1, | |||
| require_more_revision=0, | |||
| use_zoopt=False, | |||
| ): | |||
| def __init__( | |||
| self, | |||
| kb, | |||
| dist_func="confidence", | |||
| mapping=None, | |||
| max_revision=-1, | |||
| require_more_revision=0, | |||
| use_zoopt=False, | |||
| ): | |||
| if dist_func not in ["hamming", "confidence"]: | |||
| raise NotImplementedError("Valid options for dist_func include \"hamming\" and \"confidence\"") | |||
| raise NotImplementedError( | |||
| 'Valid options for dist_func include "hamming" and "confidence"' | |||
| ) | |||
| self.kb = kb | |||
| self.dist_func = dist_func | |||
| self.use_zoopt = use_zoopt | |||
| self.max_revision = max_revision | |||
| self.require_more_revision = require_more_revision | |||
| if mapping is None: | |||
| self.mapping = { | |||
| index: label for index, label in enumerate(self.kb.pseudo_label_list) | |||
| } | |||
| self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||
| else: | |||
| if not isinstance(mapping, dict): | |||
| raise TypeError("mapping should be dict") | |||
| for key, value in mapping.items(): | |||
| if not isinstance(key, int): | |||
| raise ValueError("All keys in the mapping must be integers") | |||
| if value not in self.kb.pseudo_label_list: | |||
| raise ValueError("All values in the mapping must be in the pseudo_label_list") | |||
| for key, value in mapping.items(): | |||
| if not isinstance(key, int): | |||
| raise ValueError("All keys in the mapping must be integers") | |||
| if value not in self.kb.pseudo_label_list: | |||
| raise ValueError("All values in the mapping must be in the pseudo_label_list") | |||
| self.mapping = mapping | |||
| self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
| def _get_one_candidate(self, data_sample, candidates): | |||
| """ | |||
| Due to the nondeterminism of abductive reasoning, there could be multiple candidates | |||
| satisfying the knowledge base. When this happens, return one candidate that has the | |||
| Due to the nondeterminism of abductive reasoning, there could be multiple candidates | |||
| satisfying the knowledge base. When this happens, return one candidate that has the | |||
| minimum cost. If no candidates are provided, an empty list is returned. | |||
| Parameters | |||
| ---------- | |||
| data_sample : ListData | |||
| Data sample. | |||
| candidates : List[List[Any]] | |||
| Multiple compatible candidates. | |||
| Multiple compatible candidates. | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| @@ -96,17 +97,17 @@ class ReasonerBase: | |||
| cost_array = self._get_cost_list(data_sample, candidates) | |||
| candidate = candidates[np.argmin(cost_array)] | |||
| return candidate | |||
| def _get_cost_list(self, data_sample, candidates): | |||
| """ | |||
| Get the list of costs between each candidate and the given data sample. The list is | |||
| Get the list of costs between each candidate and the given data sample. The list is | |||
| calculated based on one of the following distance functions: | |||
| - "hamming": Directly calculates the Hamming distance between the predicted pseudo | |||
| - "hamming": Directly calculates the Hamming distance between the predicted pseudo | |||
| label in the data sample and candidate. | |||
| - "confidence": Calculates the distance between the prediction and candidate based | |||
| on confidence derived from the predicted probability in the data | |||
| - "confidence": Calculates the distance between the prediction and candidate based | |||
| on confidence derived from the predicted probability in the data | |||
| sample. | |||
| Parameters | |||
| ---------- | |||
| data_sample : ListData | |||
| @@ -121,12 +122,9 @@ class ReasonerBase: | |||
| candidates = [[self.remapping[x] for x in c] for c in candidates] | |||
| return confidence_dist(data_sample.pred_prob, candidates) | |||
| def zoopt_get_solution( | |||
| self, symbol_num, data_sample, max_revision_num | |||
| ): | |||
| def zoopt_get_solution(self, symbol_num, data_sample, max_revision_num): | |||
| """ | |||
| Get the optimal solution using the Zoopt library. The solution is a list of | |||
| Get the optimal solution using the Zoopt library. The solution is a list of | |||
| boolean values, where '1' (True) indicates the indices chosen to be revised. | |||
| Parameters | |||
| @@ -138,9 +136,7 @@ class ReasonerBase: | |||
| max_revision_num : int | |||
| Specifies the maximum number of revisions allowed. | |||
| """ | |||
| dimension = Dimension( | |||
| size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num | |||
| ) | |||
| dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) | |||
| objective = Objective( | |||
| lambda sol: self.zoopt_revision_score(symbol_num, data_sample, sol), | |||
| dim=dimension, | |||
| @@ -149,16 +145,16 @@ class ReasonerBase: | |||
| parameter = Parameter(budget=100, intermediate_result=False, autoset=True) | |||
| solution = Opt.min(objective, parameter).get_x() | |||
| return solution | |||
| def zoopt_revision_score(self, symbol_num, data_sample, sol): | |||
| """ | |||
| Get the revision score for a solution. A lower score suggests that the Zoopt library | |||
| Get the revision score for a solution. A lower score suggests that the Zoopt library | |||
| has a higher preference for this solution. | |||
| """ | |||
| revision_idx = np.where(sol.get_x() != 0)[0] | |||
| candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, | |||
| data_sample.Y, | |||
| revision_idx) | |||
| candidates = self.kb.revise_at_idx( | |||
| data_sample.pred_pseudo_label, data_sample.Y, revision_idx | |||
| ) | |||
| if len(candidates) > 0: | |||
| return np.min(self._get_cost_list(data_sample, candidates)) | |||
| else: | |||
| @@ -166,7 +162,7 @@ class ReasonerBase: | |||
| def _constrain_revision_num(self, solution, max_revision_num): | |||
| """ | |||
| Constrain that the total number of revisions chosen by the solution does not exceed | |||
| Constrain that the total number of revisions chosen by the solution does not exceed | |||
| maximum number of revisions allowed. | |||
| """ | |||
| x = solution.get_x() | |||
| @@ -189,7 +185,7 @@ class ReasonerBase: | |||
| if max_revision < 0: | |||
| raise ValueError("If max_revision is an int, it must be non-negative.") | |||
| return max_revision | |||
| def abduce(self, data_sample): | |||
| """ | |||
| Perform abductive reasoning on the given data sample. | |||
| @@ -198,7 +194,7 @@ class ReasonerBase: | |||
| ---------- | |||
| data_sample : ListData | |||
| Data sample. | |||
| Returns | |||
| ------- | |||
| List[Any] | |||
| @@ -207,19 +203,21 @@ class ReasonerBase: | |||
| """ | |||
| symbol_num = data_sample.elements_num("pred_pseudo_label") | |||
| max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) | |||
| if self.use_zoopt: | |||
| solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num) | |||
| revision_idx = np.where(solution != 0)[0] | |||
| candidates = self.kb.revise_at_idx(data_sample.pred_pseudo_label, | |||
| data_sample.Y, | |||
| revision_idx) | |||
| candidates = self.kb.revise_at_idx( | |||
| data_sample.pred_pseudo_label, data_sample.Y, revision_idx | |||
| ) | |||
| else: | |||
| candidates = self.kb.abduce_candidates(data_sample.pred_pseudo_label, | |||
| data_sample.Y, | |||
| max_revision_num, | |||
| self.require_more_revision) | |||
| candidates = self.kb.abduce_candidates( | |||
| data_sample.pred_pseudo_label, | |||
| data_sample.Y, | |||
| max_revision_num, | |||
| self.require_more_revision, | |||
| ) | |||
| candidate = self._get_one_candidate(data_sample, candidates) | |||
| return candidate | |||
| @@ -228,9 +226,7 @@ class ReasonerBase: | |||
| Perform abductive reasoning on the given prediction data samples. | |||
| For detailed information, refer to `abduce`. | |||
| """ | |||
| abduced_pseudo_label = [ | |||
| self.abduce(data_sample) for data_sample in data_samples | |||
| ] | |||
| abduced_pseudo_label = [self.abduce(data_sample) for data_sample in data_samples] | |||
| data_samples.abduced_pseudo_label = abduced_pseudo_label | |||
| return abduced_pseudo_label | |||
| @@ -297,9 +297,15 @@ class ListData(BaseDataElement): | |||
| def __len__(self) -> int: | |||
| """int: The length of ListData.""" | |||
| if len(self._data_fields) > 0: | |||
| one_element = next(iter(self._data_fields)) | |||
| return len(getattr(self, one_element)) | |||
| # return len(self.values()[0]) | |||
| iterator = iter(self._data_fields) | |||
| data = next(iterator) | |||
| while getattr(self, data) is None: | |||
| try: | |||
| data = next(iterator) | |||
| except StopIteration: | |||
| break | |||
| if getattr(self, data) is None: | |||
| raise ValueError("All data fields are None.") | |||
| else: | |||
| return 0 | |||
| return len(getattr(self, data)) | |||
| @@ -22,10 +22,10 @@ def flatten(nested_list): | |||
| TypeError | |||
| If the input object is not a list. | |||
| """ | |||
| if not isinstance(nested_list, list): | |||
| raise TypeError("Input must be of type list.") | |||
| # if not isinstance(nested_list, list): | |||
| # raise TypeError("Input must be of type list.") | |||
| if not nested_list or not isinstance(nested_list[0], (list, tuple)): | |||
| if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)): | |||
| return nested_list | |||
| return list(chain.from_iterable(nested_list)) | |||
| @@ -3,7 +3,7 @@ MNIST Add | |||
| MNIST Add was first introduced in [1] and the inputs of this task are pairs of MNIST images and the outputs are their sums. The dataset looks like this: | |||
| .. image:: ../img/image_1.jpg | |||
| .. image:: ../img/Datasets_1.png | |||
| :width: 350px | |||
| :align: center | |||
| @@ -11,9 +11,5 @@ MNIST Add was first introduced in [1] and the inputs of this task are pairs of M | |||
| The ``gt_pseudo_label`` is only used to test the performance of the machine learning model and is not used in the training phase. | |||
| In the Abductive Learning framework, the inference process is as follows: | |||
| .. image:: ../img/image_2.jpg | |||
| :width: 700px | |||
| [1] Robin Manhaeve, Sebastijan Dumancic, Angelika Kimmig, Thomas Demeester, and Luc De Raedt. Deepproblog: Neural probabilistic logic programming. In Advances in Neural Information Processing Systems 31 (NeurIPS), pages 3749-3759.2018. | |||
| @@ -1,19 +1,18 @@ | |||
| import os | |||
| import os.path as osp | |||
| import cv2 | |||
| import torch | |||
| import torchvision | |||
| import pickle | |||
| import numpy as np | |||
| import random | |||
| from collections import defaultdict | |||
| from torch.utils.data import Dataset | |||
| from torchvision.transforms import transforms | |||
| CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) | |||
| def get_data(img_dataset, train): | |||
| transform = transforms.Compose([transforms.ToTensor()]) | |||
| X = [] | |||
| Y = [] | |||
| X, Y = [], [] | |||
| if train: | |||
| positive = img_dataset["train:positive"] | |||
| negative = img_dataset["train:negative"] | |||
| @@ -39,15 +38,12 @@ def get_data(img_dataset, train): | |||
| def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| transform = transforms.Compose([transforms.ToTensor()]) | |||
| X = [] | |||
| img_dir = osp.join(CURRENT_DIR, "mnist_images") | |||
| for label in labels: | |||
| label_path = os.path.join( | |||
| "./datasets/mnist_images", label | |||
| ) | |||
| label_path = osp.join(img_dir, label) | |||
| img_path_list = os.listdir(label_path) | |||
| for img_path in img_path_list: | |||
| img = cv2.imread( | |||
| os.path.join(label_path, img_path), cv2.IMREAD_GRAYSCALE | |||
| ) | |||
| img = cv2.imread(osp.join(label_path, img_path), cv2.IMREAD_GRAYSCALE) | |||
| img = cv2.resize(img, (image_size[1], image_size[0])) | |||
| X.append(np.array(img, dtype=np.float32)) | |||
| @@ -58,24 +54,6 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)): | |||
| return X, Y | |||
| # def get_pretrain_data(train_data, image_size=(28, 28, 1)): | |||
| # X = [] | |||
| # for label in [0, 1]: | |||
| # for _, equation_list in train_data[label].items(): | |||
| # for equation in equation_list: | |||
| # X = X + equation | |||
| # X = np.array(X) | |||
| # index = np.array(list(range(len(X)))) | |||
| # np.random.shuffle(index) | |||
| # X = X[index] | |||
| # X = [img for img in X] | |||
| # Y = [img.copy().reshape(image_size[0] * image_size[1] * image_size[2]) for img in X] | |||
| # return X, Y | |||
| def divide_equations_by_len(equations, labels): | |||
| equations_by_len = {1: defaultdict(list), 0: defaultdict(list)} | |||
| for i, equation in enumerate(equations): | |||
| @@ -104,22 +82,15 @@ def split_equation(equations_by_len, prop_train, prop_val): | |||
| def get_hed(dataset="mnist", train=True): | |||
| if dataset == "mnist": | |||
| with open( | |||
| "./datasets/mnist_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||
| "rb", | |||
| ) as f: | |||
| img_dataset = pickle.load(f) | |||
| file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
| elif dataset == "random": | |||
| with open( | |||
| "./datasets/random_equation_data_train_len_26_test_len_26_sys_2_.pk", | |||
| "rb", | |||
| ) as f: | |||
| img_dataset = pickle.load(f) | |||
| file = osp.join(CURRENT_DIR, "random_equation_data_train_len_26_test_len_26_sys_2_.pk") | |||
| else: | |||
| raise Exception("Undefined dataset") | |||
| raise ValueError("Undefined dataset") | |||
| with open(file, "rb") as f: | |||
| img_dataset = pickle.load(f) | |||
| X, _, Y = get_data(img_dataset, train) | |||
| equations_by_len = divide_equations_by_len(X, Y) | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| from collections import defaultdict | |||
| from typing import Any, List | |||
| import torch | |||
| from torch.utils.data import DataLoader | |||
| @@ -8,6 +9,7 @@ from abl.learning import ABLModel, BasicNN | |||
| from abl.bridge import SimpleBridge | |||
| from abl.evaluation import BaseMetric | |||
| from abl.dataset import BridgeDataset, RegressionDataset | |||
| from abl.structures import ListData | |||
| from abl.utils import print_log | |||
| from examples.hed.utils import gen_mappings, InfiniteSampler | |||
| @@ -59,54 +61,35 @@ class HEDBridge(SimpleBridge): | |||
| "model": cls_autoencoder.base_model.state_dict(), | |||
| } | |||
| torch.save( | |||
| save_parma_dic, os.path.join(weights_dir, "pretrain_weights.pth") | |||
| ) | |||
| torch.save(save_parma_dic, os.path.join(weights_dir, "pretrain_weights.pth")) | |||
| self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth")) | |||
| def abduce_pseudo_label( | |||
| self, | |||
| pred_label, | |||
| pred_prob, | |||
| pseudo_label, | |||
| Y, | |||
| max_revision=-1, | |||
| require_more_revision=0, | |||
| ): | |||
| return self.reasoner.abduce( | |||
| (pred_label, pred_prob, pseudo_label, Y), | |||
| max_revision, | |||
| require_more_revision, | |||
| ) | |||
| def select_mapping_and_abduce(self, pred_label, pred_prob, Y): | |||
| def abduce_pseudo_label(self, data_samples: ListData): | |||
| candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) | |||
| mapping_score = [] | |||
| pred_pseudo_label_list = [] | |||
| abduced_pseudo_label_list = [] | |||
| for _mapping in candidate_mappings: | |||
| self.reasoner.mapping = _mapping | |||
| self.reasoner.set_remapping() | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||
| pred_label, pred_prob, pred_pseudo_label, Y, 20 | |||
| ) | |||
| mapping_score.append( | |||
| len(abduced_pseudo_label) - abduced_pseudo_label.count([]) | |||
| ) | |||
| pred_pseudo_label_list.append(pred_pseudo_label) | |||
| self.reasoner.remapping = dict(zip(_mapping.values(), _mapping.keys())) | |||
| self.idx_to_pseudo_label(data_samples) | |||
| abduced_pseudo_label = self.reasoner.abduce(data_samples) | |||
| mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([])) | |||
| abduced_pseudo_label_list.append(abduced_pseudo_label) | |||
| max_revisible_instances = max(mapping_score) | |||
| return_idx = mapping_score.index(max_revisible_instances) | |||
| self.reasoner.mapping = candidate_mappings[return_idx] | |||
| self.reasoner.set_remapping() | |||
| return abduced_pseudo_label_list[return_idx] | |||
| self.reasoner.mapping = dict( | |||
| zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys()) | |||
| ) | |||
| data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx] | |||
| return data_samples.abduced_pseudo_label | |||
| def check_training_impact(self, filtered_X, filtered_abduced_label, X): | |||
| character_accuracy = self.model.valid(filtered_X, filtered_abduced_label) | |||
| revisible_ratio = len(filtered_X) / len(X) | |||
| def check_training_impact(self, filtered_data_samples, data_samples): | |||
| character_accuracy = self.model.valid(filtered_data_samples) | |||
| revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X) | |||
| print_log( | |||
| f"Revisible ratio is {revisible_ratio:.3f}, Character accuracy is {character_accuracy:.3f}", | |||
| logger="current", | |||
| @@ -136,10 +119,7 @@ class HEDBridge(SimpleBridge): | |||
| pred_label, _ = self.predict(X) | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| consistent_num = sum( | |||
| [ | |||
| self.reasoner.kb.consist_rule(instance, rule) | |||
| for instance in pred_pseudo_label | |||
| ] | |||
| [self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label] | |||
| ) | |||
| return consistent_num / len(X) | |||
| @@ -178,12 +158,13 @@ class HEDBridge(SimpleBridge): | |||
| return rules | |||
| @staticmethod | |||
| def filter_empty(X, Z): | |||
| filtered_X, filtered_Z = [], [] | |||
| for x, z in zip(X, Z): | |||
| if len(z) > 0: | |||
| filtered_X.append(x), filtered_Z.append(z) | |||
| return (filtered_X, filtered_Z) | |||
| def filter_empty(data_samples: ListData): | |||
| consistent_dix = [ | |||
| i | |||
| for i in range(len(data_samples.abduced_pseudo_label)) | |||
| if len(data_samples.abduced_pseudo_label[i]) > 0 | |||
| ] | |||
| return data_samples[consistent_dix] | |||
| @staticmethod | |||
| def select_rules(rule_dict): | |||
| @@ -203,79 +184,49 @@ class HEDBridge(SimpleBridge): | |||
| add_nums_dict[add_nums] = r | |||
| return list(rule_dict) | |||
| def train( | |||
| self, | |||
| train_data, | |||
| val_data, | |||
| select_num=10, | |||
| min_len=5, | |||
| max_len=8, | |||
| ): | |||
| def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | |||
| return super().idx_to_pseudo_label(data_samples) | |||
| def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8): | |||
| for equation_len in range(min_len, max_len): | |||
| print_log( | |||
| f"============== equation_len: {equation_len}-{equation_len + 1} ================", | |||
| logger="current", | |||
| ) | |||
| train_X = train_data[1][equation_len] + train_data[1][equation_len + 1] | |||
| train_Y = [None] * len(train_X) | |||
| dataset = BridgeDataset(train_X, None, train_Y) | |||
| sampler = InfiniteSampler(len(dataset)) | |||
| data_loader = DataLoader( | |||
| dataset, | |||
| sampler=sampler, | |||
| batch_size=select_num, | |||
| collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| ) | |||
| condition_num = 0 | |||
| for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||
| pred_label, pred_prob = self.predict(X) | |||
| if equation_len == min_len: | |||
| abduced_pseudo_label = self.select_mapping_and_abduce( | |||
| pred_label, pred_prob, Y | |||
| ) | |||
| else: | |||
| pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| abduced_pseudo_label = self.abduce_pseudo_label( | |||
| pred_label, pred_prob, pred_pseudo_label, Y, 20 | |||
| ) | |||
| filtered_X, filtered_abduced_pseudo_label = self.filter_empty( | |||
| X, abduced_pseudo_label | |||
| ) | |||
| if len(filtered_X) == 0: | |||
| continue | |||
| filtered_abduced_label = self.pseudo_label_to_label( | |||
| filtered_abduced_pseudo_label | |||
| ) | |||
| min_loss = self.model.train(filtered_X, filtered_abduced_label) | |||
| X = train_data[1][equation_len] + train_data[1][equation_len + 1] | |||
| Y = [None] * len(X) | |||
| data_samples = self.data_preprocess(X, None, Y) | |||
| sampler = InfiniteSampler(len(data_samples)) | |||
| for seg_idx, select_idx in enumerate(sampler): | |||
| sub_data_samples = data_samples[select_idx] | |||
| self.predict(sub_data_samples) | |||
| # self.idx_to_pseudo_label(sub_data_samples) | |||
| self.abduce_pseudo_label(sub_data_samples) | |||
| filtered_sub_data_samples = self.filter_empty(sub_data_samples) | |||
| self.pseudo_label_to_idx(filtered_sub_data_samples) | |||
| loss = self.model.train(filtered_sub_data_samples) | |||
| print_log( | |||
| f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] minimal_loss is {min_loss:.5f}", | |||
| f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] model loss is {loss:.5f}", | |||
| logger="current", | |||
| ) | |||
| if self.check_training_impact(filtered_X, filtered_abduced_label, X): | |||
| if self.check_training_impact(filtered_sub_data_samples, sub_data_samples): | |||
| condition_num += 1 | |||
| else: | |||
| condition_num = 0 | |||
| if condition_num >= 5: | |||
| print_log( | |||
| f"Now checking if we can go to next course", logger="current" | |||
| ) | |||
| print_log(f"Now checking if we can go to next course", logger="current") | |||
| rules = self.get_rules_from_data( | |||
| dataset, samples_per_rule=3, samples_num=50 | |||
| ) | |||
| print_log( | |||
| f"Learned rules from data: " + str(rules), logger="current" | |||
| data_samples, samples_per_rule=3, samples_num=50 | |||
| ) | |||
| print_log(f"Learned rules from data: " + str(rules), logger="current") | |||
| seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
| if seems_good: | |||
| self.model.save( | |||
| save_path=f"./weights/eq_len_{equation_len}.pth" | |||
| ) | |||
| self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") | |||
| break | |||
| else: | |||
| if equation_len == min_len: | |||
| @@ -285,8 +236,83 @@ class HEDBridge(SimpleBridge): | |||
| ) | |||
| self.model.load(load_path="./weights/pretrain_weights.pth") | |||
| else: | |||
| self.model.load( | |||
| load_path=f"./weights/eq_len_{equation_len - 1}.pth" | |||
| ) | |||
| self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") | |||
| condition_num = 0 | |||
| print_log("Reload Model and retrain", logger="current") | |||
| # def train( | |||
| # self, | |||
| # train_data, | |||
| # val_data, | |||
| # segment_size=10, | |||
| # min_len=5, | |||
| # max_len=8, | |||
| # ): | |||
| # for equation_len in range(min_len, max_len): | |||
| # print_log( | |||
| # f"============== equation_len: {equation_len}-{equation_len + 1} ================", | |||
| # logger="current", | |||
| # ) | |||
| # train_X = train_data[1][equation_len] + train_data[1][equation_len + 1] | |||
| # train_Y = [None] * len(train_X) | |||
| # # data_samples = self.data_preprocess(train_X, None, train_Y) | |||
| # dataset = BridgeDataset(train_X, None, train_Y) | |||
| # sampler = InfiniteSampler(len(dataset)) | |||
| # data_loader = DataLoader( | |||
| # dataset, | |||
| # sampler=sampler, | |||
| # batch_size=segment_size, | |||
| # collate_fn=lambda data_list: [list(data) for data in zip(*data_list)], | |||
| # ) | |||
| # condition_num = 0 | |||
| # for seg_idx, (X, Z, Y) in enumerate(data_loader): | |||
| # pred_label, pred_prob = self.predict(ListData(X=X)) | |||
| # if equation_len == min_len: | |||
| # abduced_pseudo_label = self.select_mapping_and_abduce(pred_label, pred_prob, Y) | |||
| # else: | |||
| # pred_pseudo_label = self.label_to_pseudo_label(pred_label) | |||
| # abduced_pseudo_label = self.abduce_pseudo_label( | |||
| # pred_label, pred_prob, pred_pseudo_label, Y, 20 | |||
| # ) | |||
| # filtered_X, filtered_abduced_pseudo_label = self.filter_empty( | |||
| # X, abduced_pseudo_label | |||
| # ) | |||
| # if len(filtered_X) == 0: | |||
| # continue | |||
| # filtered_abduced_label = self.pseudo_label_to_label(filtered_abduced_pseudo_label) | |||
| # min_loss = self.model.train(filtered_X, filtered_abduced_label) | |||
| # print_log( | |||
| # f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}] minimal_loss is {min_loss:.5f}", | |||
| # logger="current", | |||
| # ) | |||
| # if self.check_training_impact(filtered_X, filtered_abduced_label, X): | |||
| # condition_num += 1 | |||
| # else: | |||
| # condition_num = 0 | |||
| # if condition_num >= 5: | |||
| # print_log(f"Now checking if we can go to next course", logger="current") | |||
| # rules = self.get_rules_from_data(dataset, samples_per_rule=3, samples_num=50) | |||
| # print_log(f"Learned rules from data: " + str(rules), logger="current") | |||
| # seems_good = self.check_rule_quality(rules, val_data, equation_len) | |||
| # if seems_good: | |||
| # self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth") | |||
| # break | |||
| # else: | |||
| # if equation_len == min_len: | |||
| # print_log( | |||
| # "Learned mapping is: " + str(self.reasoner.mapping), | |||
| # logger="current", | |||
| # ) | |||
| # self.model.load(load_path="./weights/pretrain_weights.pth") | |||
| # else: | |||
| # self.model.load(load_path=f"./weights/eq_len_{equation_len - 1}.pth") | |||
| # condition_num = 0 | |||
| # print_log("Reload Model and retrain", logger="current") | |||
| @@ -280,7 +280,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.18" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| @@ -0,0 +1,159 @@ | |||
| import os.path as osp | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| from abl.evaluation import SemanticsMetric, SymbolMetric | |||
| from abl.learning import ABLModel, BasicNN | |||
| from abl.reasoning import PrologKB, ReasonerBase | |||
| from abl.utils import ABLLogger, print_log, reform_list | |||
| from examples.hed.datasets.get_hed import get_hed, split_equation | |||
| from examples.hed.hed_bridge import HEDBridge | |||
| from examples.models.nn import SymbolNet | |||
| # Build logger | |||
| print_log("Abductive Learning on the HED example.", logger="current") | |||
| # Retrieve the directory of the Log file and define the directory for saving the model weights. | |||
| log_dir = ABLLogger.get_current_instance().log_dir | |||
| weights_dir = osp.join(log_dir, "weights") | |||
| ### Logic Part | |||
| # Initialize knowledge base and abducer | |||
| class HedKB(PrologKB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| def consist_rule(self, exs, rules): | |||
| rules = str(rules).replace("'", "") | |||
| return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 | |||
| def abduce_rules(self, pred_res): | |||
| prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) | |||
| if len(prolog_result) == 0: | |||
| return None | |||
| prolog_rules = prolog_result[0]["X"] | |||
| rules = [rule.value for rule in prolog_rules] | |||
| return rules | |||
class HedReasoner(ReasonerBase):
    """Reasoner for the HED task that uses ZOOpt to search for a set of
    symbol revisions making the predictions consistent with the KB."""

    def revise_at_idx(self, data_sample):
        # Only positions whose revision_flag is non-zero may be revised by the KB.
        revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0]
        candidate = self.kb.revise_at_idx(
            data_sample.pred_pseudo_label, data_sample.Y, revision_idx
        )
        return candidate

    def zoopt_revision_score(self, symbol_num, data_sample, sol):
        # Decode the ZOOpt solution vector into per-instance revision flags.
        revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)
        data_sample.revision_flag = revision_flag
        lefted_idxs = [i for i in range(len(data_sample.pred_idx))]
        candidate_size = []
        # Greedily grow groups of example indices that the KB can revise
        # consistently together, recording the size of each group found.
        while lefted_idxs:
            idxs = []
            idxs.append(lefted_idxs.pop(0))
            max_candidate_idxs = []
            found = False
            # Starts at -1 so the first iteration evaluates the seed index
            # alone (the guard below skips appending when idx == -1).
            for idx in range(-1, len(data_sample.pred_idx)):
                if (not idx in idxs) and (idx >= 0):
                    idxs.append(idx)
                candidate = self.revise_at_idx(data_sample[idxs])
                if len(candidate) == 0:
                    # Inconsistent: back out the index just added (never the seed).
                    if len(idxs) > 1:
                        idxs.pop()
                else:
                    # Consistent: keep track of the largest consistent group so far.
                    if len(idxs) > len(max_candidate_idxs):
                        found = True
                        max_candidate_idxs = idxs.copy()
            removed = [i for i in lefted_idxs if i in max_candidate_idxs]
            if found:
                candidate_size.append(len(removed) + 1)
                lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
        candidate_size.sort()
        # ZOOpt minimizes, so larger consistent groups (weighted by exp(-rank))
        # make the score more negative, i.e. better.
        score = 0
        import math
        for i in range(0, len(candidate_size)):
            score -= math.exp(-i) * candidate_size[i]
        return score

    def abduce(self, data_sample):
        symbol_num = data_sample.elements_num("pred_pseudo_label")
        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
        # Search (via ZOOpt) for the revision mask optimizing zoopt_revision_score.
        solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)
        data_sample.revision_flag = reform_list(
            solution.astype(np.int32), data_sample.pred_pseudo_label
        )
        abduced_pseudo_label = []
        # Revise each instance independently; an empty list marks a failed revision.
        for single_instance in data_sample:
            single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]
            candidates = self.revise_at_idx(single_instance)
            if len(candidates) == 0:
                abduced_pseudo_label.append([])
            else:
                abduced_pseudo_label.append(candidates[0][0])
        data_sample.abduced_pseudo_label = abduced_pseudo_label
        return abduced_pseudo_label

    def abduce_rules(self, pred_res):
        # Delegate rule abduction to the Prolog knowledge base.
        return self.kb.abduce_rules(pred_res)
import os

# Directory containing this script; used to locate the Prolog rule file.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

# Build the knowledge base from the Prolog program describing the addition rules.
kb = HedKB(
    pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl")
)
# ZOOpt-based revision search, revising at most 20 symbols per abduction.
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20)

### Machine Learning Part
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build BasicNN
# The function of BasicNN is to wrap NN models into the form of an sklearn estimator
base_model = BasicNN(
    cls,
    criterion,
    optimizer,
    device,
    batch_size=32,
    num_epochs=1,
    save_interval=1,
    save_dir=weights_dir,
)

# Build ABLModel
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models
model = ABLModel(base_model)

### Metric
# Set up metrics
metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")]

### Bridge Machine Learning and Logic Reasoning
bridge = HEDBridge(model, reasoner, metric_list)

### Dataset
total_train_data = get_hed(train=True)
# split_equation(..., 3, 1) — presumably a 3:1 train/validation split; verify against its definition.
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)

### Train and Test
# Load pretrained symbol-net weights, then run abductive training with
# the held-out validation set (see SimpleBridge.train's val_data parameter).
bridge.pretrain("examples/hed/weights")
bridge.train(train_data, val_data)
| @@ -12,7 +12,8 @@ class InfiniteSampler(sampler.Sampler): | |||
| while True: | |||
| order = np.random.permutation(self.num_samples) | |||
| for i in range(self.num_samples): | |||
| yield order[i] | |||
| yield order[i: i + 10] | |||
| i += 10 | |||
| def __len__(self): | |||
| return None | |||
| @@ -79,7 +79,7 @@ | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "# # Build ABLModel\n", | |||
| "# Build ABLModel\n", | |||
| "# The main function of the ABL model is to serialize data and \n", | |||
| "# provide a unified interface for different machine learning models\n", | |||
| "model = ABLModel(base_model)" | |||
| @@ -193,7 +193,7 @@ | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.8.16" | |||
| "version": "3.8.18" | |||
| }, | |||
| "orig_nbformat": 4, | |||
| "vscode": { | |||
| @@ -58,9 +58,8 @@ class TestPrologKB(object): | |||
| def test_logic_forward_pl2(self, kb_hed): | |||
| consist_exs = [ | |||
| [1, 1, "+", 0, "=", 1, 1], | |||
| [1, "+", 1, "=", 1, 0], | |||
| [0, "+", 0, "=", 0], | |||
| [1, "+", 1, "=", 0], | |||
| [1, "+", 1, "=", 1], | |||
| ] | |||
| inconsist_exs = [ | |||
| [1, 1, "+", 0, "=", 1, 1], | |||