|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import numpy
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader, TensorDataset
-
-
class TestBasicNN:
    """Smoke tests for the object supplied by the ``basic_nn_instance`` fixture.

    Each test feeds randomly generated tensors to the instance and checks only
    the type, length, or value range of the result, so the assertions do not
    depend on the particular random values drawn.
    """

    # Synthetic data layout shared by every test: 32 samples shaped 1x28x28.
    NUM_SAMPLES = 32
    SAMPLE_SHAPE = (1, 28, 28)
    # Labels are drawn from [0, NUM_CLASSES); test_predict checks predictions
    # against the same range, so both must come from this single constant.
    NUM_CLASSES = 10
    BATCH_SIZE = 4

    @classmethod
    def _random_inputs(cls):
        """Return a fresh random tensor of shape (NUM_SAMPLES, *SAMPLE_SHAPE)."""
        return torch.randn(cls.NUM_SAMPLES, *cls.SAMPLE_SHAPE)

    @classmethod
    def _random_loader(cls):
        """Return a DataLoader over random inputs paired with random labels."""
        features = cls._random_inputs()
        labels = torch.randint(0, cls.NUM_CLASSES, (cls.NUM_SAMPLES,))
        return DataLoader(
            TensorDataset(features, labels), batch_size=cls.BATCH_SIZE
        )

    def test_initialization(self, basic_nn_instance):
        """The instance exposes a model, a criterion module and an optimizer."""
        assert basic_nn_instance.model is not None
        assert isinstance(basic_nn_instance.criterion, nn.Module)
        assert isinstance(basic_nn_instance.optimizer, optim.Optimizer)

    def test_train_epoch(self, basic_nn_instance):
        """``train_epoch`` returns a float when given a data loader."""
        loss = basic_nn_instance.train_epoch(self._random_loader())
        assert isinstance(loss, float)

    def test_fit(self, basic_nn_instance):
        """``fit`` returns a float when given a data loader."""
        loss = basic_nn_instance.fit(self._random_loader())
        assert isinstance(loss, float)

    def test_predict(self, basic_nn_instance):
        """``predict`` yields one in-range class index per input sample."""
        # predict is called with a list of per-sample tensors, not a batch.
        X = list(self._random_inputs())
        predictions = basic_nn_instance.predict(X=X)
        assert len(predictions) == len(X)
        assert numpy.isin(predictions, list(range(self.NUM_CLASSES))).all()

    def test_predict_proba(self, basic_nn_instance):
        """``predict_proba`` yields one row per input with values in [0, 1]."""
        X = list(self._random_inputs())
        predict_proba = basic_nn_instance.predict_proba(X=X)
        assert len(predict_proba) == len(X)
        assert ((0 <= predict_proba) & (predict_proba <= 1)).all()

    def test_score(self, basic_nn_instance):
        """``score`` returns a value in the closed interval [0, 1]."""
        accuracy = basic_nn_instance.score(self._random_loader())
        assert 0 <= accuracy <= 1

    def test_save_load(self, basic_nn_instance, tmp_path):
        """``save`` creates the checkpoint file and ``load`` accepts it."""
        model_path = tmp_path / "model.pth"
        basic_nn_instance.save(epoch_id=1, save_path=str(model_path))
        assert model_path.exists()
        # Only checks that loading the freshly written file does not raise;
        # the restored state itself is not inspected here.
        basic_nn_instance.load(load_path=str(model_path))
|