import torch


class Optimizer(object):
    """Base class wrapping a ``torch.optim`` optimizer.

    Stores the parameters to optimize and the optimizer hyper-parameters;
    subclasses build the concrete ``torch.optim`` optimizer in
    ``construct_from_pytorch``.
    """

    def __init__(self, model_params, **kwargs):
        """
        :param model_params: the parameters to optimize. Either ``None``
            (supplied later via ``construct_from_pytorch``), a single
            ``torch.Tensor``, or an iterable of tensors such as
            ``model.parameters()``.
        :param kwargs: hyper-parameters forwarded verbatim to the pytorch
            optimizer constructor.
        :raises RuntimeError: if ``model_params`` is neither ``None``, a
            tensor, nor an iterable.
        """
        # Accept an iterable as well as a single tensor: the usual call site
        # passes ``model.parameters()``, which is a generator, not a Tensor,
        # and torch.optim itself expects an iterable of parameters.
        if model_params is not None \
                and not isinstance(model_params, torch.Tensor) \
                and not hasattr(model_params, "__iter__"):
            raise RuntimeError("model parameters should be torch.Tensor or an iterable of "
                               "torch.Tensor, rather than {}".format(type(model_params)))
        self.model_params = model_params
        self.settings = kwargs


class SGD(Optimizer):
    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
        """
        :param model_params: parameters to optimize, or ``None`` to supply
            them later through ``construct_from_pytorch``.
        :param float lr: learning rate.
        :param float momentum: momentum factor.
        """
        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

    def construct_from_pytorch(self, model_params):
        """Build and return the actual ``torch.optim.SGD`` instance.

        ``model_params`` is used only when no parameters were given to the
        constructor.
        """
        if self.model_params is None:
            self.model_params = model_params
        return torch.optim.SGD(self.model_params, **self.settings)


class Adam(Optimizer):
    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
        """
        :param model_params: parameters to optimize, or ``None`` to supply
            them later through ``construct_from_pytorch``.
        :param float lr: learning rate.
        :param float weight_decay: L2 penalty coefficient.
            NOTE(review): 0.8 is an unusually large default — torch's own
            default is 0, and Trainer passes ``weight_decay=0`` explicitly;
            likely a typo, confirm before relying on it.
        """
        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

    def construct_from_pytorch(self, model_params):
        """Build and return the actual ``torch.optim.Adam`` instance.

        ``model_params`` is used only when no parameters were given to the
        constructor.
        """
        if self.model_params is None:
            self.model_params = model_params
        return torch.optim.Adam(self.model_params, **self.settings)
import unittest

import torch

from fastNLP.core.optimizer import SGD


class TestOptim(unittest.TestCase):
    """Smoke tests for the SGD optimizer wrapper."""

    def test_case(self):
        # A single tensor is accepted as the parameter argument.
        optim = SGD(torch.LongTensor(10))
        print(optim.__dict__)

        # Parameters may be omitted and supplied later.
        optim_2 = SGD(lr=0.001)
        print(optim_2.__dict__)

        # Hyper-parameters are stored as given.
        optim_2 = SGD(lr=0.002, momentum=0.989)
        print(optim_2.__dict__)

    def test_case_2(self):
        # A bare float is not a valid parameter container.
        with self.assertRaises(RuntimeError):
            _ = SGD(0.001)