# Tutorial 5: Customize Runtime Settings

## Customize optimization settings

### Customize optimizer supported by PyTorch

We already support using all the optimizers implemented by PyTorch, and the only modification needed is to change the `optimizer` field of config files.
For example, if you want to use `Adam` (note that the performance could drop a lot), the modification could be as the following.

```python
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
```

To modify the learning rate of the model, the users only need to modify the `lr` in the config of the optimizer. The users can directly set arguments following the [API doc](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) of PyTorch.
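
Since all optimizers in `torch.optim` are registered, other choices work the same way. Here is a minimal sketch using `AdamW`; the hyper-parameter values below are placeholders, not recommended settings.

```python
# AdamW is picked up from torch.optim like any other registered optimizer;
# lr, betas and weight_decay here are illustrative placeholders.
optimizer = dict(type='AdamW', lr=1e-4, betas=(0.9, 0.999), weight_decay=0.05)
```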

### Customize self-implemented optimizer

#### 1. Define a new optimizer

A customized optimizer could be defined as follows.
Assume you want to add an optimizer named `MyOptimizer`, which has arguments `a`, `b`, and `c`.
You need to create a new directory named `mmdet/core/optimizer`.
And then implement the new optimizer in a file, e.g., in `mmdet/core/optimizer/my_optimizer.py`:

```python
from torch.optim import Optimizer

from .registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c):
        ...
```
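
To make the skeleton above concrete, here is a minimal sketch of what a complete implementation could look like. The SGD-style update rule and the meanings given to `a`, `b`, and `c` are assumptions made for illustration only; they are not part of this tutorial.

```python
import torch
from torch.optim import Optimizer

from .registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):
    """Sketch: `a` acts as a learning rate, `b` as a decay factor and
    `c` as a lower bound on the step size (all hypothetical)."""

    def __init__(self, params, a=0.01, b=0.9, c=1e-6):
        defaults = dict(a=a, b=b, c=c)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            # Hypothetical update: a decayed, lower-bounded SGD step.
            step_size = max(group['a'] * group['b'], group['c'])
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-step_size)
        return loss
```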

#### 2. Add the optimizer to registry

To find the module defined above, it should be imported into the main namespace first. There are two options to achieve it.

- Modify `mmdet/core/optimizer/__init__.py` to import it.

  The newly defined module should be imported in `mmdet/core/optimizer/__init__.py` so that the registry will find the new module and add it:

```python
from .my_optimizer import MyOptimizer
```

- Use `custom_imports` in the config to manually import it.

```python
custom_imports = dict(imports=['mmdet.core.optimizer.my_optimizer'], allow_failed_imports=False)
```

The module `mmdet.core.optimizer.my_optimizer` will be imported at the beginning of the program and the class `MyOptimizer` is then automatically registered.
Note that only the package containing the class `MyOptimizer` should be imported.
`mmdet.core.optimizer.my_optimizer.MyOptimizer` **cannot** be imported directly.
Actually users can use a totally different file directory structure with this importing method, as long as the module root can be located in `PYTHONPATH`.
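
For example, a sketch assuming a hypothetical package `my_project` that is importable from `PYTHONPATH` (the module path below is made up for illustration):

```python
# The optimizer may just as well live outside mmdet,
# e.g. in my_project/optimizers/my_optimizer.py (hypothetical path).
custom_imports = dict(
    imports=['my_project.optimizers.my_optimizer'],
    allow_failed_imports=False)
```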

#### 3. Specify the optimizer in the config file

Then you can use `MyOptimizer` in the `optimizer` field of config files.
In the configs, the optimizers are defined by the field `optimizer` like the following:

```python
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
```

To use your own optimizer, the field can be changed to

```python
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
```

### Customize optimizer constructor

Some models may have parameter-specific settings for optimization, e.g. weight decay for BatchNorm layers.
The users can do such fine-grained parameter tuning through customizing the optimizer constructor.

```python
from mmcv.utils import build_from_cfg
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmdet.utils import get_root_logger
from .my_optimizer import MyOptimizer


@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        pass

    def __call__(self, model):
        ...
        return my_optimizer
```

The default optimizer constructor is implemented [here](https://github.com/open-mmlab/mmcv/blob/9ecd6b0d5ff9d2172c49a182eaa669e9f27bb8e7/mmcv/runner/optimizer/default_constructor.py#L11), which could also serve as a template for new optimizer constructors.
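
As a concrete illustration, below is a minimal sketch of a constructor that turns off weight decay for normalization layers. The class name is made up, and selecting it via the `constructor` key of the `optimizer` config follows the behavior of `mmcv.runner.build_optimizer`; treat the details as an assumption rather than a reference implementation.

```python
import torch.nn as nn
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS
from mmcv.utils import build_from_cfg


@OPTIMIZER_BUILDERS.register_module()
class NoDecayNormConstructor:
    """Sketch: build the optimizer with zero weight decay on norm layers."""

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        self.optimizer_cfg = optimizer_cfg

    def __call__(self, model):
        base_wd = self.optimizer_cfg.get('weight_decay', 0.0)
        decay, no_decay = [], []
        for module in model.modules():
            is_norm = isinstance(
                module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm))
            for param in module.parameters(recurse=False):
                (no_decay if is_norm else decay).append(param)
        optimizer_cfg = self.optimizer_cfg.copy()
        # Parameter groups override the global weight_decay per group.
        optimizer_cfg['params'] = [
            dict(params=decay, weight_decay=base_wd),
            dict(params=no_decay, weight_decay=0.0),
        ]
        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
```

It could then be selected in the config with something like `optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001, constructor='NoDecayNormConstructor')`.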

### Additional settings

Tricks not implemented by the optimizer should be implemented through the optimizer constructor (e.g., setting parameter-wise learning rates) or hooks. We list some common settings that could stabilize or accelerate the training. Feel free to create a PR or an issue for more settings.

- __Use gradient clip to stabilize training__:
  Some models need gradient clip to clip the gradients to stabilize the training process. An example is as below:

```python
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
```

If your config inherits a base config which already sets the `optimizer_config`, you might need `_delete_=True` to override the unnecessary settings. See the [config documentation](https://mmdetection.readthedocs.io/en/latest/tutorials/config.html) for more details.

- __Use momentum schedule to accelerate model convergence__:
  We support the momentum scheduler to modify the model's momentum according to the learning rate, which could make the model converge faster.
  The momentum scheduler is usually used together with the LR scheduler; for example, the following config is used in 3D detection to accelerate convergence.
  For more details, please refer to the implementation of [CyclicLrUpdater](https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/lr_updater.py#L327) and [CyclicMomentumUpdater](https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/momentum_updater.py#L130).

```python
lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 1e-4),
    cyclic_times=1,
    step_ratio_up=0.4,
)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1,
    step_ratio_up=0.4,
)
```

## Customize training schedules

By default we use the step learning rate with the 1x schedule, which calls [`StepLRHook`](https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/lr_updater.py#L153) in MMCV.
We support many other learning rate schedules [here](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py), such as the `CosineAnnealing` and `Poly` schedules. Here are some examples:

- Poly schedule:

```python
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
```

- CosineAnnealing schedule:

```python
lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 10,
    min_lr_ratio=1e-5)
```

## Customize workflow

Workflow is a list of (phase, epochs) to specify the running order and epochs.
By default it is set to be

```python
workflow = [('train', 1)]
```

which means running 1 epoch for training.
Sometimes users may want to check some metrics (e.g. loss, accuracy) of the model on the validation set.
In such a case, we can set the workflow as

```python
[('train', 1), ('val', 1)]
```

so that 1 epoch for training and 1 epoch for validation will be run iteratively.

**Note**:

1. The parameters of the model will not be updated during the val epoch.
2. Keyword `total_epochs` in the config only controls the number of training epochs and will not affect the validation workflow.
3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EvalHook` because `EvalHook` is called by `after_train_epoch` and the validation workflow only affects hooks that are called through `after_val_epoch`. Therefore, the only difference between `[('train', 1), ('val', 1)]` and `[('train', 1)]` is that the runner will calculate losses on the validation set after each training epoch.
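
For example, a minimal sketch of the relevant config fields when validation losses should be computed after every training epoch (the rest of the config, including `data.val`, is assumed to be defined elsewhere):

```python
# One training epoch followed by one validation epoch, repeated.
workflow = [('train', 1), ('val', 1)]
# Controls only the number of training epochs; val epochs run in addition.
total_epochs = 12
```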

## Customize hooks

### Customize self-implemented hooks

#### 1. Implement a new hook

There are some occasions when the users might need to implement a new hook. MMDetection supports customized hooks in training (#3395) since v2.3.0. Thus the users could implement a hook directly in mmdet or their mmdet-based codebases and use the hook by only modifying the config in training.
Before v2.3.0, the users needed to modify the code to get the hook registered before training starts.
Here we give an example of creating a new hook in mmdet and using it in training.

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class MyHook(Hook):

    def __init__(self, a, b):
        pass

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass
```

Depending on the functionality of the hook, the users need to specify what the hook will do at each stage of the training in `before_run`, `after_run`, `before_epoch`, `after_epoch`, `before_iter`, and `after_iter`.
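
As a concrete (and entirely hypothetical) illustration, the hook below only fills in `after_iter` to report progress every `interval` iterations; `every_n_iters` is a helper provided by the base `Hook` class:

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class IterReportHook(Hook):
    """Hypothetical hook that logs progress every `interval` iterations."""

    def __init__(self, interval=50):
        self.interval = interval

    def after_iter(self, runner):
        if self.every_n_iters(runner, self.interval):
            runner.logger.info(f'Finished iteration {runner.iter + 1}')
```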

#### 2. Register the new hook

Then we need to make `MyHook` imported. Assuming the file is in `mmdet/core/utils/my_hook.py`, there are two ways to do that:

- Modify `mmdet/core/utils/__init__.py` to import it.

  The newly defined module should be imported in `mmdet/core/utils/__init__.py` so that the registry will find the new module and add it:

```python
from .my_hook import MyHook
```

- Use `custom_imports` in the config to manually import it.

```python
custom_imports = dict(imports=['mmdet.core.utils.my_hook'], allow_failed_imports=False)
```

#### 3. Modify the config

```python
custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value)
]
```

You can also set the priority of the hook by adding the key `priority` with value `'NORMAL'` or `'HIGHEST'` as below

```python
custom_hooks = [
    dict(type='MyHook', a=a_value, b=b_value, priority='NORMAL')
]
```

By default the hook's priority is set as `NORMAL` during registration.

### Use hooks implemented in MMCV

If the hook is already implemented in MMCV, you can directly modify the config to use the hook as below.

#### 4. Example: `NumClassCheckHook`

We implement a customized hook named [NumClassCheckHook](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/utils.py) to check whether the `num_classes` in head matches the length of `CLASSES` in `dataset`.
We set it in [default_runtime.py](https://github.com/open-mmlab/mmdetection/blob/master/configs/_base_/default_runtime.py).

```python
custom_hooks = [dict(type='NumClassCheckHook')]
```

### Modify default runtime hooks

There are some common hooks that are not registered through `custom_hooks`; they are

- log_config
- checkpoint_config
- evaluation
- lr_config
- optimizer_config
- momentum_config

Among those hooks, only the logger hook has the `VERY_LOW` priority; the others' priority is `NORMAL`.
The above-mentioned tutorials already cover how to modify `optimizer_config`, `momentum_config`, and `lr_config`.
Here we reveal what we can do with `log_config`, `checkpoint_config`, and `evaluation`.

#### Checkpoint config

The MMCV runner will use `checkpoint_config` to initialize the [`CheckpointHook`](https://github.com/open-mmlab/mmcv/blob/9ecd6b0d5ff9d2172c49a182eaa669e9f27bb8e7/mmcv/runner/hooks/checkpoint.py#L9).

```python
checkpoint_config = dict(interval=1)
```

The users could set `max_keep_ckpts` to save only a small number of checkpoints or decide whether to store the state dict of the optimizer by `save_optimizer`. More details of the arguments are [here](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.CheckpointHook).
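
For instance, a sketch that keeps only the three most recent checkpoints and skips saving the optimizer state (the values are illustrative):

```python
checkpoint_config = dict(interval=1, max_keep_ckpts=3, save_optimizer=False)
```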

#### Log config

The `log_config` wraps multiple logger hooks and enables setting intervals. Now MMCV supports `WandbLoggerHook`, `MlflowLoggerHook`, and `TensorboardLoggerHook`.
The detailed usages can be found in the [doc](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.LoggerHook).

```python
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
```
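
As another sketch, `WandbLoggerHook` can be added alongside the text logger; the `init_kwargs` content below (the project name) is a placeholder that gets passed through to `wandb.init()`:

```python
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # init_kwargs is forwarded to wandb.init(); the project name is a placeholder.
        dict(type='WandbLoggerHook', init_kwargs=dict(project='my-project'))
    ])
```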

#### Evaluation config

The config of `evaluation` will be used to initialize the [`EvalHook`](https://github.com/open-mmlab/mmdetection/blob/7a404a2c000620d52156774a5025070d9e00d918/mmdet/core/evaluation/eval_hooks.py#L8).
Except for the key `interval`, other arguments such as `metric` will be passed to `dataset.evaluate()`.

```python
evaluation = dict(interval=1, metric='bbox')
```
