You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_basic_nn.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import DataLoader, TensorDataset
  6. class TestBasicNN(object):
  7. # Test initialization
  8. def test_initialization(self, basic_nn_instance):
  9. assert basic_nn_instance.model is not None
  10. assert isinstance(basic_nn_instance.criterion, nn.Module)
  11. assert isinstance(basic_nn_instance.optimizer, optim.Optimizer)
  12. # Test training epoch
  13. def test_train_epoch(self, basic_nn_instance):
  14. X = torch.randn(32, 1, 28, 28)
  15. y = torch.randint(0, 10, (32,))
  16. data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
  17. loss = basic_nn_instance.train_epoch(data_loader)
  18. assert isinstance(loss, float)
  19. # Test fit method
  20. def test_fit(self, basic_nn_instance):
  21. X = torch.randn(32, 1, 28, 28)
  22. y = torch.randint(0, 10, (32,))
  23. data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
  24. loss = basic_nn_instance.fit(data_loader)
  25. assert isinstance(loss, float)
  26. # Test predict method
  27. def test_predict(self, basic_nn_instance):
  28. X = list(torch.randn(32, 1, 28, 28))
  29. predictions = basic_nn_instance.predict(X=X)
  30. assert len(predictions) == len(X)
  31. assert numpy.isin(predictions, list(range(10))).all()
  32. # Test predict_proba method
  33. def test_predict_proba(self, basic_nn_instance):
  34. X = list(torch.randn(32, 1, 28, 28))
  35. predict_proba = basic_nn_instance.predict_proba(X=X)
  36. assert len(predict_proba) == len(X)
  37. assert ((0 <= predict_proba) & (predict_proba <= 1)).all()
  38. # Test score method
  39. def test_score(self, basic_nn_instance):
  40. X = torch.randn(32, 1, 28, 28)
  41. y = torch.randint(0, 10, (32,))
  42. data_loader = DataLoader(TensorDataset(X, y), batch_size=4)
  43. accuracy = basic_nn_instance.score(data_loader)
  44. assert 0 <= accuracy <= 1
  45. # Test save and load methods
  46. def test_save_load(self, basic_nn_instance, tmp_path):
  47. model_path = tmp_path / "model.pth"
  48. basic_nn_instance.save(epoch_id=1, save_path=str(model_path))
  49. assert model_path.exists()
  50. basic_nn_instance.load(load_path=str(model_path))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.