|
|
|
@@ -146,8 +146,8 @@ class BasicModel: |
|
|
|
The loss function used for training. |
|
|
|
optimizer : torch.nn.Module |
|
|
|
The optimizer used for training. |
|
|
|
device : torch.device |
|
|
|
The device on which the model will be trained or used for prediction. |
|
|
|
device : torch.device, optional |
|
|
|
The device on which the model will be trained or used for prediction, by default torch.decive("cpu"). |
|
|
|
batch_size : int, optional |
|
|
|
The batch size used for training, by default 1. |
|
|
|
num_epochs : int, optional |
|
|
|
@@ -223,7 +223,7 @@ class BasicModel: |
|
|
|
model: torch.nn.Module, |
|
|
|
criterion: torch.nn.Module, |
|
|
|
optimizer: torch.nn.Module, |
|
|
|
device: torch.device, |
|
|
|
device: torch.device = torch.device("cpu"), |
|
|
|
batch_size: int = 1, |
|
|
|
num_epochs: int = 1, |
|
|
|
stop_loss: Optional[float] = 0.01, |
|
|
|
|