@@ -664,6 +664,12 @@ class EpochBasedTrainer(BaseTrainer):
         dataset = self.to_task_dataset(torch_dataset, mode)
         return dataset
 
+    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
+        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+
+    def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
+        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+
     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler
@@ -680,7 +686,7 @@ class EpochBasedTrainer(BaseTrainer):
         optim_options = {}
         if optimizer_cfg is not None:
             optim_options = optimizer_cfg.pop('options', {})
-            optimizer = build_optimizer(self.model, cfg=optimizer_cfg)
+            optimizer = self.build_optimizer(cfg=optimizer_cfg)
 
         if lr_scheduler is None:
             lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
@@ -691,7 +697,7 @@ class EpochBasedTrainer(BaseTrainer):
         if lr_scheduler_cfg is not None:
             assert optimizer is not None
             lr_options = lr_scheduler_cfg.pop('options', {})
-            lr_scheduler = build_lr_scheduler(
+            lr_scheduler = self.build_lr_scheduler(
                 cfg=lr_scheduler_cfg, default_args={'optimizer': optimizer})
 
         self.optimizer = optimizer
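Taken together, the three hunks turn optimizer and LR-scheduler construction into overridable hooks: create_optimizer_and_scheduler now routes through self.build_optimizer and self.build_lr_scheduler instead of calling the module-level builders directly, so a subclass can swap in its own optimizer or scheduler without copying the whole method. A minimal sketch of how a subclass might use the new hook; the MyTrainer class, the AdamW choice, and the lr fallback are illustrative assumptions, not part of this patch:

import torch
from modelscope.trainers import EpochBasedTrainer
from modelscope.utils.config import ConfigDict


class MyTrainer(EpochBasedTrainer):
    """Hypothetical subclass demonstrating the new build_optimizer hook."""

    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
        # Bypass the registry-based build_optimizer(self.model, ...) and
        # construct AdamW directly; fall back to lr=1e-4 if not configured.
        lr = cfg.get('lr', 1e-4) if cfg is not None else 1e-4
        return torch.optim.AdamW(self.model.parameters(), lr=lr)

Because the base class still pops the 'options' key and passes the optimizer on to build_lr_scheduler via default_args, an override only needs to return a torch.optim.Optimizer; the rest of create_optimizer_and_scheduler works unchanged.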