diff --git a/README.md b/README.md index c83ebb9..0749b0b 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ We provide several examples in `examples/`. Each example is stored in a separate + [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) + [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) + [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) ++ [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia) ## References diff --git a/ablkit/reasoning/reasoner.py b/ablkit/reasoning/reasoner.py index 2905007..5626468 100644 --- a/ablkit/reasoning/reasoner.py +++ b/ablkit/reasoning/reasoner.py @@ -180,8 +180,8 @@ class Reasoner: candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] return avg_confidence_dist(data_example.pred_prob, candidates_idxs) else: - candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] - cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) + candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] + cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results) if len(cost_list) != len(candidates): raise ValueError( "The length of the array returned by dist_func must be equal to the number " diff --git a/examples/bdd_oia/README.md b/examples/bdd_oia/README.md new file mode 100644 index 0000000..0900a4f --- /dev/null +++ b/examples/bdd_oia/README.md @@ -0,0 +1,50 @@ +# BDD-OIA + +This example shows an implementation of the [BDD-OIA](https://twizwei.github.io/bddoia_project/) task. The BDD-OIA dataset comprises frames extracted from driving scene videos, which are used for autonomous driving action prediction. Each frame is annotated with 4 binary labels, indicating the possible actions, namely $\textsf{move forward}$, $\textsf{stop}$, $\textsf{turn left}$, and $\textsf{turn right}$. Each frame is also annotated with 21 intermediate binary concepts, such as $\textsf{red light}$ and $\textsf{road clear}$, which capture the reasons behind the possible actions. + +The objective is to predict the possible actions for each frame. During training, only the action labels are used as supervision, together with a knowledge base that encodes the relations between concepts and actions, e.g., $\textsf{red light} \lor \textsf{traffic sign} \lor \textsf{obstacle} \implies \textsf{stop}$. The training set consists of 16,000 frames, while the test set contains 4,500 annotated frames. + +Before use, the dataset was pre-processed by [Marconato et al. (2023)](https://proceedings.neurips.cc/paper_files/paper/2023/file/e560202b6e779a82478edb46c6f8f4dd-Paper-Conference.pdf) using a Faster R-CNN model pretrained on BDD-100k, in conjunction with the first module of CBM-AUC [(Sawada & Nakamura, 2022)](https://arxiv.org/abs/2202.01459), resulting in 2048-dimensional embeddings. + +## Run + +```bash +pip install -r requirements.txt +cd dataset +unzip dataset.zip +cd .. 
+python main.py +``` + +## Usage + +```bash +usage: main.py [-h] [--no-cuda] [--epochs EPOCHS] [--lr LR] + [--batch-size BATCH_SIZE] [--loops LOOPS] + [--segment_size SEGMENT_SIZE] + [--save_interval SAVE_INTERVAL] + [--max-revision MAX_REVISION] + [--require-more-revision REQUIRE_MORE_REVISION] + +BDD-OIA example + +optional arguments: + -h, --help show this help message and exit + --no-cuda disables CUDA training + --epochs EPOCHS number of epochs in each learning loop iteration + (default : 1) + --lr LR base model learning rate (default : 0.002) + --batch-size BATCH_SIZE + base model batch size (default : 32) + --loops LOOPS + number of loop iterations (default : 2) + --segment_size SEGMENT_SIZE + segment size (default : 0.01) + --save_interval SAVE_INTERVAL + save interval (default : 1) + --max-revision MAX_REVISION + maximum revision in reasoner (default : 3) + --require-more-revision REQUIRE_MORE_REVISION + require more revision in reasoner (default : 3) + +``` diff --git a/examples/bdd_oia/bridge.py b/examples/bdd_oia/bridge.py new file mode 100644 index 0000000..0139277 --- /dev/null +++ b/examples/bdd_oia/bridge.py @@ -0,0 +1,23 @@ +import numpy as np +from typing import List, Any +from ablkit.data import ListData +from ablkit.bridge import SimpleBridge + +class BDDBridge(SimpleBridge): + def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: + pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] + pred_pseudo_label = [] + for sub_list in pred_idx: + sub_list = sub_list.squeeze() # 1 x nc -> nc + pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list]) + data_examples.pred_pseudo_label = pred_pseudo_label + return data_examples.pred_pseudo_label + + def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: + abduced_pseudo_label = data_examples.abduced_pseudo_label + abduced_idx = [] + for sub_list in abduced_pseudo_label: + sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) + abduced_idx.append(sub_list) + data_examples.abduced_idx = abduced_idx + return data_examples.abduced_idx \ No newline at end of file diff --git a/examples/bdd_oia/dataset/data_util.py b/examples/bdd_oia/dataset/data_util.py new file mode 100644 index 0000000..23bb9a8 --- /dev/null +++ b/examples/bdd_oia/dataset/data_util.py @@ -0,0 +1,18 @@ +import os +import numpy as np + +CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + + +def get_dataset(fname, get_pseudo_label=True): + fname = os.path.join(CURRENT_DIR, fname) + data = np.load(fname) + X = data["X"] + X = [[emb.astype(np.float32)] for emb in X] + pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None + Y = data["Y"][:, :4].astype(int).tolist() + Y = [tuple(y) for y in Y] + return X, pseudo_label, Y + +if __name__ == '__main__': + dataset = get_dataset("val.npz") \ No newline at end of file diff --git a/examples/bdd_oia/dataset/dataset.zip b/examples/bdd_oia/dataset/dataset.zip new file mode 100644 index 0000000..82f6914 Binary files /dev/null and b/examples/bdd_oia/dataset/dataset.zip differ diff --git a/examples/bdd_oia/main.py b/examples/bdd_oia/main.py new file mode 100644 index 0000000..219f76e --- /dev/null +++ b/examples/bdd_oia/main.py @@ -0,0 +1,147 @@ +import argparse +import os.path as osp +import numpy as np +import torch +from torch import optim + +from ablkit.data.evaluation import SymbolAccuracy +from ablkit.reasoning import Reasoner +from ablkit.utils import ABLLogger, print_log + +from models.nn import * 
+from models.bdd_nn import BDDNN +from models.bdd_model import BDDABLModel +from reasoning.bddkb import BDDKB +from dataset.data_util import get_dataset +from bridge import BDDBridge +from metric import BDDReasoningMetric + + +def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): + pred_prob = data_example.pred_prob.T # nc x 1 + pred_prob = np.concatenate([1-pred_prob, pred_prob], axis=1) # nc x 2 + cols = np.arange(len(candidates_idxs[0]))[None, :] + corr_prob = pred_prob[cols, candidates_idxs] # prob. of each concept taking the candidate's value + costs = - np.sum(np.log(corr_prob + 1e-6), axis=1) # negative log-likelihood per candidate + return costs + +def get_args(): + parser = argparse.ArgumentParser(description="BDD-OIA example") + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="number of epochs in each learning loop iteration (default : 1)", + ) + parser.add_argument( + "--lr", type=float, default=2e-3, help="base model learning rate (default : 0.002)" + ) + parser.add_argument( + "--batch-size", type=int, default=32, help="base model batch size (default : 32)" + ) + parser.add_argument( + "--loops", type=int, default=2, help="number of loop iterations (default : 2)" + ) + parser.add_argument( + "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)" + ) + parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") + parser.add_argument( + "--max-revision", type=int, default=3, help="maximum revision in reasoner (default : 3)" + ) + parser.add_argument( + "--require-more-revision", + type=int, + default=3, + help="require more revision in reasoner (default : 3)", + ) + + args = parser.parse_args() + return args + +def main(): + args = get_args() + + # Build logger + print_log("Abductive Learning on the BDD-OIA example.", logger="current") + + # -- Working with Data ------------------------------ + print_log("Working with Data.", logger="current") + train_data = get_dataset(fname="train.npz", get_pseudo_label=True) + val_data = get_dataset(fname="val.npz", get_pseudo_label=True) + test_data = get_dataset(fname="test.npz", get_pseudo_label=True) + + # -- Building the Learning Part --------------------- + print_log("Building the Learning Part.", logger="current") + + # Build necessary components for BDDNN + cls = ConceptNet() + loss_fn = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(cls.parameters(), lr=args.lr) + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + scheduler = optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=args.lr, + pct_start=0.15, + epochs=args.loops, + steps_per_epoch=int(1 / args.segment_size) + 1, + ) + + # Build BDDNN + base_model = BDDNN( + cls, + loss_fn, + optimizer, + scheduler=scheduler, + device=device, + batch_size=args.batch_size, + num_epochs=args.epochs, + ) + + # Build ABLModel + model = BDDABLModel(base_model) + + # -- Building the Reasoning Part -------------------- + print_log("Building the Reasoning Part.", logger="current") + + # Build knowledge base + kb = BDDKB() + + # Create reasoner + reasoner = Reasoner( + kb, + dist_func=multi_label_confidence_dist, + max_revision=args.max_revision, + require_more_revision=args.require_more_revision + ) + + # -- Building Evaluation Metrics -------------------- + print_log("Building Evaluation Metrics.", logger="current") + metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, 
prefix="bdd_oia")] + + # -- Bridging Learning and Reasoning ---------------- + print_log("Bridge Learning and Reasoning.", logger="current") + bridge = BDDBridge(model, reasoner, metric_list) + + # Retrieve the directory of the Log file and define the directory for saving the model weights. + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + + # Train and Test + bridge.train( + train_data=train_data, + val_data=val_data, + loops=args.loops, + segment_size=args.segment_size, + save_interval=args.save_interval, + save_dir=weights_dir, + ) + bridge.test(test_data) + + +if __name__ == "__main__": + main() diff --git a/examples/bdd_oia/metric.py b/examples/bdd_oia/metric.py new file mode 100644 index 0000000..c73c2f9 --- /dev/null +++ b/examples/bdd_oia/metric.py @@ -0,0 +1,24 @@ +from typing import Optional + +from ablkit.reasoning import KBBase +from ablkit.data import BaseMetric, ListData + +class BDDReasoningMetric(BaseMetric): + def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: + super().__init__(prefix) + self.kb = kb + + def process(self, data_examples: ListData) -> None: + pred_pseudo_label_list = data_examples.pred_pseudo_label + y_list = data_examples.Y + x_list = data_examples.X + for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): + pred_y = self.kb.logic_forward(pred_pseudo_label, *(x,) if self.kb._num_args == 2 else ()) + for py, yy in zip(pred_y, y): + self.results.append(int(py == yy)) + + def compute_metrics(self) -> dict: + results = self.results + metrics = dict() + metrics["reasoning_accuracy"] = sum(results) / len(results) + return metrics \ No newline at end of file diff --git a/examples/bdd_oia/models/bdd_model.py b/examples/bdd_oia/models/bdd_model.py new file mode 100644 index 0000000..bdddb7e --- /dev/null +++ b/examples/bdd_oia/models/bdd_model.py @@ -0,0 +1,24 @@ +from typing import Dict + +import numpy as np +from ablkit.data import ListData +from ablkit.learning import ABLModel +from ablkit.utils import reform_list + +class BDDABLModel(ABLModel): + def predict(self, data_examples: ListData) -> Dict: + model = self.base_model + data_X = data_examples.flatten("X") + if hasattr(model, "predict_proba"): + prob = model.predict_proba(X=data_X) + label = np.where(prob > 0.5, 1, 0).astype(int) + prob = reform_list(prob, data_examples.X) + else: + prob = None + label = model.predict(X=data_X) + label = reform_list(label, data_examples.X) + + data_examples.pred_idx = label + data_examples.pred_prob = prob + + return {"label": label, "prob": prob} \ No newline at end of file diff --git a/examples/bdd_oia/models/bdd_nn.py b/examples/bdd_oia/models/bdd_nn.py new file mode 100644 index 0000000..5fdd378 --- /dev/null +++ b/examples/bdd_oia/models/bdd_nn.py @@ -0,0 +1,94 @@ +import logging +import os +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy +import torch +from torch.utils.data import DataLoader, Dataset + +from ablkit.learning import BasicNN, PredictionDataset, ClassificationDataset +from ablkit.utils.logger import print_log + + +class MultiLabelClassificationDataset(ClassificationDataset): + def __init__(self, X: List[Any], Y: List[int], transform: Optional[Callable[..., Any]] = None): + if (not isinstance(X, list)) or (not isinstance(Y, list)): + raise ValueError("X and Y should be of type list.") + self.X = X + self.Y = torch.FloatTensor(numpy.stack(Y, axis=0)) # float32 for BCELoss + self.transform = transform + +class BDDNN(BasicNN): + + def 
predict( + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, + ) -> numpy.ndarray: + if data_loader is not None and X is not None: + print_log( + "Predict the class of input data in data_loader instead of X.", + logger="current", + level=logging.WARNING, + ) + + if data_loader is None: + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + pred_probs = self._predict(data_loader).sigmoid() + pred = torch.where(pred_probs > 0.5, 1, 0).int() + return pred.cpu().numpy() + + def predict_proba( + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, + ) -> numpy.ndarray: + if data_loader is not None and X is not None: + print_log( + "Predict the class probability of input data in data_loader instead of X.", + logger="current", + level=logging.WARNING, + ) + + if data_loader is None: + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + pred_probs = self._predict(data_loader).sigmoid() # B x NC + return pred_probs.cpu().numpy() + + def _data_loader( + self, + X: Optional[List[Any]], + y: Optional[List[int]] = None, + shuffle: Optional[bool] = True, + ) -> DataLoader: + if X is None: + raise ValueError("X should not be None.") + if y is None: + y = [0] * len(X) + if not len(y) == len(X): + raise ValueError("X and y should have equal length.") + + dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + return data_loader diff --git a/examples/bdd_oia/models/nn.py b/examples/bdd_oia/models/nn.py new file mode 100644 index 0000000..216be0e --- /dev/null +++ b/examples/bdd_oia/models/nn.py @@ -0,0 +1,22 @@ +from torch import nn + +class SimpleNet(nn.Module): + def __init__(self, num_features=2048, num_concepts=21): + super(SimpleNet, self).__init__() + self.fc = nn.Linear(num_features, num_concepts) + + def forward(self, x): + return self.fc(x) + +class ConceptNet(nn.Module): + def __init__(self, num_features=2048, num_concepts=21): + super(ConceptNet, self).__init__() + intermediate_dim = 256 + self.fc = nn.Sequential( + nn.Linear(num_features, intermediate_dim), + nn.SiLU(), + nn.Linear(intermediate_dim, num_concepts) + ) + + def forward(self, x): + return self.fc(x) diff --git a/examples/bdd_oia/reasoning/bddkb.py b/examples/bdd_oia/reasoning/bddkb.py new file mode 100644 index 0000000..3548c69 --- /dev/null +++ b/examples/bdd_oia/reasoning/bddkb.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +from ablkit.reasoning import KBBase + +class BDDKB(KBBase): + def __init__(self, pseudo_label_list=None): + if pseudo_label_list is None: + pseudo_label_list = [0, 1] + super().__init__(pseudo_label_list) + + def logic_forward(self, attrs): + """ + Abduction space + (0, 1, 0, 0) 610812 + (0, 1, 0, 1) 75012 + (0, 1, 1, 0) 75012 + (0, 1, 1, 1) 9212 + (1, 0, 0, 0) 12996 + (1, 0, 0, 1) 1596 + (1, 0, 1, 0) 1596 + (1, 0, 1, 1) 196 + """ + assert len(attrs) == 21 + green_light, follow, road_clear, red_light, traffic_sign, car, person, rider, other_obstacle, \ + left_lane, left_green_light, 
left_follow, no_left_lane, left_obstacle, left_solid_line, \ + right_lane, right_green_light, right_follow, no_right_lane, right_obstacle, right_solid_line = attrs + + illegal_return = (0, 0, 0, 0) + if red_light == green_light == 1: + return illegal_return + obstacle = car or person or rider or other_obstacle + if road_clear == obstacle: + return illegal_return + move_forward = (green_light or follow or road_clear) + stop = (red_light or traffic_sign or obstacle) + if stop: + move_forward = 0 + + can_turn_left = left_lane or left_green_light or left_follow + cannot_turn_left = no_left_lane or left_obstacle or left_solid_line + turn_left = can_turn_left and int(not cannot_turn_left) + + can_turn_right = right_lane or right_green_light or right_follow + cannot_turn_right = no_right_lane or right_obstacle or right_solid_line + turn_right = can_turn_right and int(not cannot_turn_right) + + return move_forward, stop, turn_left, turn_right \ No newline at end of file diff --git a/examples/bdd_oia/requirements.txt b/examples/bdd_oia/requirements.txt new file mode 100644 index 0000000..c238889 --- /dev/null +++ b/examples/bdd_oia/requirements.txt @@ -0,0 +1,2 @@ +torch +ablkit
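
As a quick sanity check of the knowledge base added in `examples/bdd_oia/reasoning/bddkb.py`, the minimal sketch below instantiates `BDDKB` and calls `logic_forward` on a hand-written 21-dimensional concept vector. It is a sketch, not part of the patch: it assumes `ablkit` is installed and that it is run from `examples/bdd_oia/`; the concept ordering and the expected output follow directly from the unpacking and rules in `logic_forward`.

```python
# Minimal sketch: exercise the BDD-OIA knowledge base standalone.
# Assumed to run from examples/bdd_oia/ with ablkit installed.
from reasoning.bddkb import BDDKB

kb = BDDKB()

# 21 binary concepts, in the order unpacked by logic_forward:
# green_light, follow, road_clear, red_light, traffic_sign, car, person, rider,
# other_obstacle, left_lane, left_green_light, left_follow, no_left_lane,
# left_obstacle, left_solid_line, right_lane, right_green_light, right_follow,
# no_right_lane, right_obstacle, right_solid_line
concepts = [0] * 21
concepts[2] = 1  # road_clear = 1, consistent with all obstacle concepts being 0
concepts[3] = 1  # red_light = 1

# Returns (move_forward, stop, turn_left, turn_right); a red light forces stop.
print(kb.logic_forward(concepts))  # expected: (0, 1, 0, 0)
```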