Browse Source

更新Optimizer:

optimizer.SGD(lr=xxx)：如果构造时没有传入 parameters，则由 Trainer 自动补全模型参数
tags/v0.2.0^2
FengZiYjun 7 years ago
parent
commit
8a7077fed2
4 changed files with 44 additions and 55 deletions
  1. +18
    -51
      fastNLP/core/optimizer.py
  2. +4
    -4
      fastNLP/core/trainer.py
  3. +21
    -0
      test/core/test_optimizer.py
  4. +1
    -0
      test/core/test_trainer.py

+ 18
- 51
fastNLP/core/optimizer.py View File

@@ -2,61 +2,28 @@ import torch


class Optimizer(object):
    """Base wrapper around a ``torch.optim`` optimizer.

    Stores the (optional) model parameters and the optimizer settings so the
    real ``torch.optim`` object can be built later, once the trainer knows the
    model's parameters (see ``construct_from_pytorch`` on the subclasses).

    :param model_params: a ``torch.Tensor``, an iterable of parameters
        (e.g. ``model.parameters()``), or ``None`` to let the trainer supply
        the parameters later.
    :param kwargs: optimizer settings such as ``lr``, ``momentum``,
        ``weight_decay``; forwarded verbatim to the torch optimizer.
    :raises RuntimeError: if ``model_params`` is neither ``None``, a Tensor,
        nor an iterable of parameters.
    """

    def __init__(self, model_params, **kwargs):
        # Accept None (filled in later by the trainer), a Tensor, or any
        # iterable of parameters.  Demanding a bare torch.Tensor here would
        # wrongly reject ``model.parameters()`` (a generator) -- the very
        # input that torch optimizers consume.
        if model_params is not None and not isinstance(model_params, torch.Tensor):
            try:
                iter(model_params)
            except TypeError:
                raise RuntimeError(
                    "model parameters should be torch.Tensor or an iterable of them, "
                    "rather than {}".format(type(model_params)))
        self.model_params = model_params
        self.settings = kwargs


class SGD(Optimizer):
    """Stochastic gradient descent.

    :param model_params: parameters to optimize, or ``None`` (the trainer will
        supply them in ``construct_from_pytorch``)
    :param float lr: learning rate
    :param float momentum: momentum factor
    """

    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

    def construct_from_pytorch(self, model_params):
        """Build a ``torch.optim.SGD`` over the stored (or given) parameters.

        :param model_params: parameters used only when none were given at
            construction time (the trainer's fallback).
        :return: a ``torch.optim.SGD`` instance
        """
        if self.model_params is None:
            # No parameters were given at construction; adopt the trainer's.
            self.model_params = model_params
        return torch.optim.SGD(self.model_params, **self.settings)


class Adam(Optimizer):
    """Adam optimizer.

    :param model_params: parameters to optimize, or ``None`` (the trainer will
        supply them in ``construct_from_pytorch``)
    :param float lr: learning rate
    :param float weight_decay: L2 penalty coefficient.
        NOTE(review): 0.8 is an unusually large default for weight decay --
        confirm it is intended (the trainer overrides it with 0 explicitly).
    """

    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

    def construct_from_pytorch(self, model_params):
        """Build a ``torch.optim.Adam`` over the stored (or given) parameters.

        :param model_params: parameters used only when none were given at
            construction time (the trainer's fallback).
        :return: a ``torch.optim.Adam`` instance
        """
        if self.model_params is None:
            self.model_params = model_params
        return torch.optim.Adam(self.model_params, **self.settings)

+ 4
- 4
fastNLP/core/trainer.py View File

@@ -12,7 +12,7 @@ from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
@@ -31,7 +31,7 @@ class Trainer(object):
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
optimizer=Adam(lr=0.01, weight_decay=0), need_check_code=True,
metric_key=None,
**kwargs):
super(Trainer, self).__init__()
@@ -178,7 +178,7 @@ class Trainer(object):
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_path is not None and self._better_eval_result(res):
self.save_model(self.model,
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

def _mode(self, model, is_test=False):
@@ -225,7 +225,7 @@ class Trainer(object):
"""
return self.losser(predict, truth)

def save_model(self, model, model_name, only_param=False):
def _save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)


+ 21
- 0
test/core/test_optimizer.py View File

@@ -0,0 +1,21 @@
import unittest

import torch

from fastNLP.core.optimizer import SGD


class TestOptim(unittest.TestCase):
    def test_case(self):
        # SGD can be built from a raw tensor, from pure defaults, or from
        # explicit hyper-parameters; dump each instance's state in turn.
        for optimizer in (
                SGD(torch.LongTensor(10)),
                SGD(lr=0.001),
                SGD(lr=0.002, momentum=0.989),
        ):
            print(optimizer.__dict__)

    def test_case_2(self):
        # A plain float is not a valid parameter container.
        self.assertRaises(RuntimeError, SGD, 0.001)

+ 1
- 0
test/core/test_trainer.py View File

@@ -4,3 +4,4 @@ import unittest
class TestTrainer(unittest.TestCase):
    def test_case_1(self):
        # Placeholder: no Trainer behaviour is exercised yet.
        pass


Loading…
Cancel
Save