`Learn the Basics `_ ||
`Quick Start `_ ||
`Dataset & Data Structure `_ ||
**Machine Learning Part** ||
`Reasoning Part `_ ||
`Evaluation Metrics `_ ||
`Bridge `_
Machine Learning Part
=====================
``ABLModel`` class serves as a unified interface to all machine learning models. Its constructor, the ``__init__`` method, takes a singular argument, ``base_model``. This argument denotes the fundamental machine learning model, which must implement the ``fit`` and ``predict`` methods.
.. code:: python
class ABLModel:
def __init__(self, base_model: Any) -> None:
if not (hasattr(base_model, "fit") and hasattr(base_model, "predict")):
raise NotImplementedError("The base_model should implement fit and predict methods.")
self.base_model = base_model
All scikit-learn models satisify this requiremnts, so we can directly use the model to create an instance of ``ABLModel``. For example, we can customize our machine learning model by
.. code:: python
import sklearn
from abl.learning import ABLModel
base_model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3)
model = ABLModel(base_model)
For a PyTorch-based neural network, we first need to encapsulate it within a ``BasicNN`` object and then use this object to instantiate an instance of ``ABLModel``. For example, we can customize our machine learning model by
.. code:: python
# Load a PyTorch-based neural network
cls = torchvision.models.resnet18(pretrained=True)
# criterion and optimizer are used for training
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters())
base_model = BasicNN(cls, criterion, optimizer)
model = ABLModel(base_model)
Besides ``fit`` and ``predict``, ``BasicNN`` also implements the following methods:
+---------------------------+----------------------------------------+
| Method | Function |
+===========================+========================================+
| train_epoch(data_loader) | Train the neural network for one epoch.|
+---------------------------+----------------------------------------+
| predict_proba(X) | Predict the class probabilities of X. |
+---------------------------+----------------------------------------+
| score(X, y) | Calculate the accuracy of the model on |
| | test data. |
+---------------------------+----------------------------------------+
| save(epoch_id, save_path) | Save the model. |
+---------------------------+----------------------------------------+
| load(load_path) | Load the model. |
+---------------------------+----------------------------------------+