@@ -2,61 +2,28 @@ import torch
 
 
 class Optimizer(object):
-    """Wrapper of optimizer from framework
-
-        1. Adam: lr (float), weight_decay (float)
-        2. AdaGrad
-        3. RMSProp
-        4. SGD: lr (float), momentum (float)
-
-    """
-
-    def __init__(self, optimizer_name, **kwargs):
-        """
-        :param optimizer_name: str, the name of the optimizer
-        :param kwargs: the arguments
-
-        """
-        self.optim_name = optimizer_name
-        self.kwargs = kwargs
-
-    @property
-    def name(self):
-        """The name of the optimizer.
-
-        :return: str
-        """
-        return self.optim_name
-
-    @property
-    def params(self):
-        """The arguments used to create the optimizer.
-
-        :return: dict of (str, *)
-        """
-        return self.kwargs
+    def __init__(self, model_params, **kwargs):
+        if model_params is not None and not isinstance(model_params, torch.Tensor):
+            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        self.model_params = model_params
+        self.settings = kwargs
+
+
+class SGD(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
+
+    def construct_from_pytorch(self, model_params):
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.SGD(self.model_params, **self.settings)
+
+
+class Adam(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
-        """Construct a optimizer from framework over given model parameters."""
-        if self.optim_name in ["SGD", "sgd"]:
-            if "lr" in self.kwargs:
-                if "momentum" not in self.kwargs:
-                    self.kwargs["momentum"] = 0
-                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
-            else:
-                raise ValueError("requires learning rate for SGD optimizer")
-
-        elif self.optim_name in ["adam", "Adam"]:
-            if "lr" in self.kwargs:
-                if "weight_decay" not in self.kwargs:
-                    self.kwargs["weight_decay"] = 0
-                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
-                                             weight_decay=self.kwargs["weight_decay"])
-            else:
-                raise ValueError("requires learning rate for Adam optimizer")
-
-        else:
-            raise NotImplementedError
-
-        return optimizer
-
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.Adam(self.model_params, **self.settings)
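
A minimal usage sketch of the rewritten API, for context. The `nn.Linear` model, the random batch, and the `lr=0.01` value are illustrative assumptions, not part of this diff; `SGD` is the class added above and is assumed to be in scope. Hyperparameters are captured at construction time, while the parameters themselves are bound later through `construct_from_pytorch` (the constructor's type check accepts only `None` or a `torch.Tensor`, so the generator returned by `model.parameters()` has to come in through that method):

```python
import torch
import torch.nn as nn

# Illustrative model; any nn.Module works (assumption, not from the diff).
model = nn.Linear(10, 2)

# Hyperparameters are stored now; model_params stays None, so the
# isinstance check in Optimizer.__init__ is skipped.
wrapper = SGD(lr=0.01, momentum=0.9)

# A trainer later supplies the parameter generator and receives a real
# torch.optim.SGD configured with the stored settings.
optimizer = wrapper.construct_from_pytorch(model.parameters())

optimizer.zero_grad()
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
```

Deferring the binding this way lets a trainer own the model while user code specifies only hyperparameters, which is presumably the motivation for replacing the old string-dispatch `construct_from_pytorch` with per-optimizer subclasses.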