You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_optimizer.py 1.3 kB

123456789101112131415161718192021222324252627282930313233343536
  1. import unittest
  2. import torch
  3. from fastNLP.core.optimizer import SGD, Adam
  4. class TestOptim(unittest.TestCase):
  def test_SGD(self):
      """Check how the ``SGD`` wrapper records its constructor arguments.

      Exercises three accepted call forms (a parameter generator, ``lr=``
      alone, and ``lr=`` together with ``momentum=``) and asserts that the
      values end up in the instance's ``settings`` dict. Also checks two
      call forms that must be rejected with ``RuntimeError``.
      """
      # Built only from a module's parameter generator: the "lr" and
      # "momentum" keys must be present even though neither was passed.
      optim = SGD(torch.nn.Linear(10, 3).parameters())
      settings = optim.__dict__["settings"]
      self.assertIn("lr", settings)
      self.assertIn("momentum", settings)

      # Keyword construction: the given learning rate is stored as-is.
      optim = SGD(lr=0.001)
      self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

      # Both hyper-parameters supplied by keyword are stored as-is.
      optim = SGD(lr=0.002, momentum=0.989)
      settings = optim.__dict__["settings"]
      self.assertEqual(settings["lr"], 0.002)
      self.assertEqual(settings["momentum"], 0.989)

      # A string positional argument is not an accepted input type.
      with self.assertRaises(RuntimeError):
          SGD("???")

      # A positional value combined with an ``lr`` keyword is rejected.
      with self.assertRaises(RuntimeError):
          SGD(0.001, lr=0.002)
  18. def test_Adam(self):
  19. optim = Adam(torch.nn.Linear(10, 3).parameters())
  20. self.assertTrue("lr" in optim.__dict__["settings"])
  21. self.assertTrue("weight_decay" in optim.__dict__["settings"])
  22. optim = Adam(lr=0.001)
  23. self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
  24. optim = Adam(lr=0.002, weight_decay=0.989)
  25. self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
  26. self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)