@@ -84,17 +84,17 @@ To build the machine learning part, we need to wrap our machine learning model i
# The number of pseudo labels is 10
cls = LeNet5(num_classes=10)
Aside from the network, we need to define a criterion, an optimizer, and a device so as to create a ``BasicNN`` object. This class implements ``fit``, ``predict``, ``predict_proba`` and several other methods to enable the PyTorch-based neural network to work as a scikit-learn model.
Aside from the network, we need to define a loss_fn, an optimizer, and a device so as to create a ``BasicNN`` object. This class implements ``fit``, ``predict``, ``predict_proba`` and several other methods to enable the PyTorch-based neural network to work as a scikit-learn model.