@@ -33,8 +33,9 @@ class Optimizer(object):
     def construct_from_pytorch(self, model_params):
         raise NotImplementedError

-    def _get_require_grads_param(self, params):
+    @staticmethod
+    def _get_require_grads_param(params):
         """
         Remove the params that do not require gradients.
@@ -43,6 +44,7 @@ class Optimizer(object):
         """
         return [param for param in params if param.requires_grad]

+
 class NullOptimizer(Optimizer):
     """
     Pass in this optimizer when you do not want the Trainer to perform parameter updates itself, but make sure the parameters are updated through callbacks.
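The NullOptimizer docstring above implies the Trainer performs no optimizer step of its own, so the actual update has to happen in a callback. A minimal sketch of that pattern, assuming fastNLP exposes a Callback base class with an on_backward_end hook and a model property (check the installed version; the callback name below is hypothetical):

    import torch
    from fastNLP import Callback  # assumed top-level export

    class ManualSGDCallback(Callback):
        """Hypothetical callback that applies a plain SGD update by hand."""

        def __init__(self, lr=0.01):
            super().__init__()
            self.lr = lr

        def on_backward_end(self):
            # Runs after loss.backward(); with NullOptimizer the Trainer's own
            # step is a no-op, so update and clear the gradients here.
            with torch.no_grad():
                for p in self.model.parameters():
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-self.lr)
                        p.grad.zero_()

    # Usage sketch: Trainer(..., optimizer=NullOptimizer(), callbacks=[ManualSGDCallback()])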
@@ -113,7 +115,8 @@ class Adam(Optimizer):
 class AdamW(TorchOptimizer):
     r"""
-    Implementation of AdamW. This implementation was expected to appear in a later pytorch release, https://github.com/pytorch/pytorch/pull/21250, so it was added here ahead of time.
+    Implementation of AdamW. The implementation is already available in pytorch 1.2.0, https://github.com/pytorch/pytorch/pull/21250.
+    It is kept here for compatibility with older pytorch versions.

     .. todo::
         Translate into Chinese
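For background on the rewritten docstring: what separates AdamW from Adam with L2 regularization is decoupled weight decay, i.e. the decay multiplies the weights directly instead of being folded into the gradient. An illustrative sketch of that single step (not fastNLP's exact code; lr and weight_decay are just local variables here):

    import torch

    p = torch.nn.Parameter(torch.randn(3))
    p.grad = torch.randn(3)            # stand-in for a gradient produced by backward()
    lr, weight_decay = 1e-3, 1e-2

    with torch.no_grad():
        p.mul_(1 - lr * weight_decay)  # decoupled decay applied to the weights
        p.add_(p.grad, alpha=-lr)      # then the (here simplified) gradient step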
@@ -2,7 +2,7 @@ import unittest
 import torch

-from fastNLP import SGD, Adam
+from fastNLP import SGD, Adam, AdamW


 class TestOptim(unittest.TestCase):
@@ -52,3 +52,12 @@ class TestOptim(unittest.TestCase):
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
         res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
         self.assertTrue(isinstance(res, torch.optim.Adam))
+
+    def test_AdamW(self):
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters())
+        self.assertTrue('lr' in optim.defaults)
+        self.assertTrue('weight_decay' in optim.defaults)
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters(), lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.defaults['lr'], 0.002)
+        self.assertEqual(optim.defaults['weight_decay'], 0.989)
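Beyond the assertions on optim.defaults, the backported class should behave like any standard torch optimizer (which the use of defaults suggests); a quick smoke run mirroring the test's constructor arguments:

    import torch
    from fastNLP import AdamW

    model = torch.nn.Linear(10, 3)
    optim = AdamW(params=model.parameters(), lr=0.002, weight_decay=0.01)

    x = torch.randn(4, 10)
    loss = model(x).sum()   # dummy objective for one forward/backward/step cycle
    optim.zero_grad()
    loss.backward()
    optim.step()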