From 112c5a7e24c72aafe25607d2e7bc75306176ec4b Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 30 Mar 2023 20:21:55 +0800 Subject: [PATCH] [MNT] change the parameter of BasicModel, device, to optional --- abl/models/basic_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/abl/models/basic_model.py b/abl/models/basic_model.py index f26a11a..4e9f870 100644 --- a/abl/models/basic_model.py +++ b/abl/models/basic_model.py @@ -146,8 +146,8 @@ class BasicModel: The loss function used for training. optimizer : torch.nn.Module The optimizer used for training. - device : torch.device - The device on which the model will be trained or used for prediction. + device : torch.device, optional + The device on which the model will be trained or used for prediction, by default torch.decive("cpu"). batch_size : int, optional The batch size used for training, by default 1. num_epochs : int, optional @@ -223,7 +223,7 @@ class BasicModel: model: torch.nn.Module, criterion: torch.nn.Module, optimizer: torch.nn.Module, - device: torch.device, + device: torch.device = torch.device("cpu"), batch_size: int = 1, num_epochs: int = 1, stop_loss: Optional[float] = 0.01,