更新Optimizer: 多种初始化方法 1. SGD() 2. SGD(0.01) 3. SGD(lr=0.01) 4. SGD(lr=0.01, momentum=0.9) 5. SGD(model.parameters(), lr=0.1, momentum=0.9)tags/v0.2.0^2
| @@ -3,14 +3,41 @@ import torch | |||
| class Optimizer(object): | |||
| def __init__(self, model_params, **kwargs): | |||
| if model_params is not None and not isinstance(model_params, torch.Tensor): | |||
| raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params))) | |||
| if model_params is not None and not hasattr(model_params, "__next__"): | |||
| raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params))) | |||
| self.model_params = model_params | |||
| self.settings = kwargs | |||
| class SGD(Optimizer): | |||
| def __init__(self, model_params=None, lr=0.001, momentum=0.9): | |||
| def __init__(self, *args, **kwargs): | |||
| model_params, lr, momentum = None, 0.01, 0.9 | |||
| if len(args) == 0 and len(kwargs) == 0: | |||
| # SGD() | |||
| pass | |||
| elif len(args) == 1 and len(kwargs) == 0: | |||
| if isinstance(args[0], float) or isinstance(args[0], int): | |||
| # SGD(0.001) | |||
| lr = args[0] | |||
| elif hasattr(args[0], "__next__"): | |||
| # SGD(model.parameters()) args[0] is a generator | |||
| model_params = args[0] | |||
| else: | |||
| raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
| elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||
| # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9) | |||
| if len(args) == 1: | |||
| if hasattr(args[0], "__next__"): | |||
| model_params = args[0] | |||
| else: | |||
| raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
| if not all(key in ("lr", "momentum") for key in kwargs): | |||
| raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs)) | |||
| lr = kwargs.get("lr", 0.01) | |||
| momentum = kwargs.get("momentum", 0.9) | |||
| else: | |||
| raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||
| super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||
| def construct_from_pytorch(self, model_params): | |||
| @@ -20,7 +47,30 @@ class SGD(Optimizer): | |||
| class Adam(Optimizer): | |||
| def __init__(self, model_params=None, lr=0.001, weight_decay=0.8): | |||
| def __init__(self, *args, **kwargs): | |||
| model_params, lr, weight_decay = None, 0.01, 0.9 | |||
| if len(args) == 0 and len(kwargs) == 0: | |||
| pass | |||
| elif len(args) == 1 and len(kwargs) == 0: | |||
| if isinstance(args[0], float) or isinstance(args[0], int): | |||
| lr = args[0] | |||
| elif hasattr(args[0], "__next__"): | |||
| model_params = args[0] | |||
| else: | |||
| raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
| elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||
| if len(args) == 1: | |||
| if hasattr(args[0], "__next__"): | |||
| model_params = args[0] | |||
| else: | |||
| raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
| if not all(key in ("lr", "weight_decay") for key in kwargs): | |||
| raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs)) | |||
| lr = kwargs.get("lr", 0.01) | |||
| weight_decay = kwargs.get("weight_decay", 0.9) | |||
| else: | |||
| raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||
| super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | |||
| def construct_from_pytorch(self, model_params): | |||
| @@ -56,7 +56,10 @@ class Trainer(object): | |||
| # increase_better is True. It means the exp result gets better if the indicator increases. | |||
| # It is true by default. | |||
| self.increase_better = False if metric_key[0] == "-" else True | |||
| self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
| if metric_key is not None: | |||
| self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
| else: | |||
| self.metric_key = None | |||
| # prepare loss | |||
| losser = _prepare_losser(losser) | |||
| @@ -144,12 +147,13 @@ class Trainer(object): | |||
| del self._summary_writer | |||
| def _train_epoch(self, data_iterator, model, epoch, start): | |||
| """Training process in one epoch. | |||
| """ | |||
| kwargs should contain: | |||
| - n_print: int, print training information every n steps. | |||
| - start: time.time(), the starting time of this step. | |||
| - epoch: int, | |||
| :param data_iterator: | |||
| :param model: | |||
| :param epoch: | |||
| :param start: | |||
| :return: | |||
| """ | |||
| for batch_x, batch_y in data_iterator: | |||
| # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | |||
| @@ -188,7 +192,7 @@ class Trainer(object): | |||
| """Train mode or Test mode. This is for PyTorch currently. | |||
| :param model: a PyTorch model | |||
| :param is_test: bool, whether in test mode or not. | |||
| :param bool is_test: whether in test mode or not. | |||
| """ | |||
| if is_test: | |||
| @@ -263,7 +267,7 @@ class Trainer(object): | |||
| else: | |||
| # metric_key is set | |||
| if self.metric_key not in metric_dict: | |||
| raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}") | |||
| raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}") | |||
| indicator_val = metric_dict[self.metric_key] | |||
| is_better = True | |||
| @@ -2,20 +2,43 @@ import unittest | |||
| import torch | |||
| from fastNLP.core.optimizer import SGD | |||
| from fastNLP.core.optimizer import SGD, Adam | |||
| class TestOptim(unittest.TestCase): | |||
| def test_case(self): | |||
| optim = SGD(torch.LongTensor(10)) | |||
| print(optim.__dict__) | |||
| def test_SGD(self): | |||
| optim = SGD(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue("lr" in optim.__dict__["settings"]) | |||
| self.assertTrue("momentum" in optim.__dict__["settings"]) | |||
| optim_2 = SGD(lr=0.001) | |||
| print(optim_2.__dict__) | |||
| optim = SGD(0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| optim_2 = SGD(lr=0.002, momentum=0.989) | |||
| print(optim_2.__dict__) | |||
| optim = SGD(lr=0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| def test_case_2(self): | |||
| optim = SGD(lr=0.002, momentum=0.989) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
| self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||
| with self.assertRaises(RuntimeError): | |||
| _ = SGD("???") | |||
| with self.assertRaises(RuntimeError): | |||
| _ = SGD(0.001) | |||
| _ = SGD(0.001, lr=0.002) | |||
| with self.assertRaises(RuntimeError): | |||
| _ = SGD(lr=0.009, shit=9000) | |||
| def test_Adam(self): | |||
| optim = Adam(torch.nn.Linear(10, 3).parameters()) | |||
| self.assertTrue("lr" in optim.__dict__["settings"]) | |||
| self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||
| optim = Adam(0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| optim = Adam(lr=0.001) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
| optim = Adam(lr=0.002, weight_decay=0.989) | |||
| self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
| self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | |||