Browse Source

[MNT] change the parameter of BasicModel, device, to optional

pull/3/head
Gao Enhao 3 years ago
parent
commit
112c5a7e24
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      abl/models/basic_model.py

+ 3
- 3
abl/models/basic_model.py View File

@@ -146,8 +146,8 @@ class BasicModel:
The loss function used for training.
optimizer : torch.nn.Module
The optimizer used for training.
device : torch.device
The device on which the model will be trained or used for prediction.
device : torch.device, optional
The device on which the model will be trained or used for prediction, by default torch.decive("cpu").
batch_size : int, optional
The batch size used for training, by default 1.
num_epochs : int, optional
@@ -223,7 +223,7 @@ class BasicModel:
model: torch.nn.Module,
criterion: torch.nn.Module,
optimizer: torch.nn.Module,
device: torch.device,
device: torch.device = torch.device("cpu"),
batch_size: int = 1,
num_epochs: int = 1,
stop_loss: Optional[float] = 0.01,


Loading…
Cancel
Save