# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List
import torch

from detectron2.config import CfgNode

from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR


def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an SGD optimizer from config, assigning each parameter its own
    learning rate and weight decay based on its type (weight, bias, or norm).
    """
    params: List[Dict[str, Any]] = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            # Frozen parameters are excluded from optimization entirely.
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if key.endswith("norm.weight") or key.endswith("norm.bias"):
            # Normalization layers get their own (often zero) weight decay.
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM
        elif key.endswith(".bias"):
            # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
            # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
            # hyperparameters are by default exactly the same as for regular
            # weights.
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        # One param group per tensor, so each carries its own lr/weight_decay.
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    # BASE_LR is only the group default; every group above overrides it.
    # (Passing the loop variable `lr` here would raise a NameError for a
    # model with no trainable parameters.)
    optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM)
    return optimizer
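

# Illustration (not executed): for a model with parameters named `fc.weight`,
# `fc.bias`, and `norm.weight` (hypothetical names), build_optimizer() builds
# one param group per tensor, roughly:
#
#   [{"params": [fc.weight],   "lr": BASE_LR,                  "weight_decay": WEIGHT_DECAY},
#    {"params": [fc.bias],     "lr": BASE_LR * BIAS_LR_FACTOR, "weight_decay": WEIGHT_DECAY_BIAS},
#    {"params": [norm.weight], "lr": BASE_LR,                  "weight_decay": WEIGHT_DECAY_NORM}]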


def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build an LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME
    if name == "WarmupMultiStepLR":
        return WarmupMultiStepLR(
            optimizer,
            cfg.SOLVER.STEPS,
            cfg.SOLVER.GAMMA,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    elif name == "WarmupCosineLR":
        return WarmupCosineLR(
            optimizer,
            cfg.SOLVER.MAX_ITER,
            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
            warmup_method=cfg.SOLVER.WARMUP_METHOD,
        )
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))
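

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. Assumption: the
    # default config from detectron2's get_cfg() provides every SOLVER key
    # referenced above, and a toy linear model stands in for a real model.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    model = torch.nn.Linear(3, 2)
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    for _ in range(3):
        loss = model(torch.randn(4, 3)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # these schedulers are stepped once per iteration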