You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_optimizer.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import unittest
  2. import torch
  3. from fastNLP.core.optimizer import SGD, Adam
  4. class TestOptim(unittest.TestCase):
  5. def test_SGD(self):
  6. optim = SGD(model_params=torch.nn.Linear(10, 3).parameters())
  7. self.assertTrue("lr" in optim.__dict__["settings"])
  8. self.assertTrue("momentum" in optim.__dict__["settings"])
  9. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  10. self.assertTrue(isinstance(res, torch.optim.SGD))
  11. optim = SGD(lr=0.001)
  12. self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
  13. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  14. self.assertTrue(isinstance(res, torch.optim.SGD))
  15. optim = SGD(lr=0.002, momentum=0.989)
  16. self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
  17. self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)
  18. optim = SGD(0.001)
  19. self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
  20. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  21. self.assertTrue(isinstance(res, torch.optim.SGD))
  22. with self.assertRaises(TypeError):
  23. _ = SGD("???")
  24. with self.assertRaises(TypeError):
  25. _ = SGD(0.001, lr=0.002)
  26. def test_Adam(self):
  27. optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
  28. self.assertTrue("lr" in optim.__dict__["settings"])
  29. self.assertTrue("weight_decay" in optim.__dict__["settings"])
  30. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  31. self.assertTrue(isinstance(res, torch.optim.Adam))
  32. optim = Adam(lr=0.001)
  33. self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
  34. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  35. self.assertTrue(isinstance(res, torch.optim.Adam))
  36. optim = Adam(lr=0.002, weight_decay=0.989)
  37. self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
  38. self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)
  39. optim = Adam(0.001)
  40. self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
  41. res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
  42. self.assertTrue(isinstance(res, torch.optim.Adam))