# Tutorial 6: Customize Losses

MMDetection provides users with different loss functions. But the default configuration may not be applicable to different datasets or models, so users may want to modify a specific loss to adapt it to the new situation.

This tutorial first elaborates on the computation pipeline of losses, then gives some instructions about how to modify each step. The modifications can be categorized as tweaking and weighting.

## Computation pipeline of a loss

Given the input prediction and target, as well as the weights, a loss function maps the input tensor to the final loss scalar. The mapping can be divided into five steps:

1. Set the sampling method to sample positive and negative samples.
2. Get the **element-wise** or **sample-wise** loss by the loss kernel function.
3. Weight the loss with a weight tensor **element-wise**.
4. Reduce the loss tensor to a **scalar**.
5. Weight the loss with a **scalar**.
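
Most of these steps are shared across losses. The sketch below, modeled loosely on the `weight_reduce_loss` utility in `mmdet/models/losses/utils.py`, illustrates how steps 3 and 4 typically fit together (a simplified illustration, not the exact implementation):

```python
import torch

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Sketch: element-wise weighting (step 3) followed by reduction (step 4)."""
    if weight is not None:
        loss = loss * weight  # step 3: weight tensor with the same shape as loss
    if reduction == 'mean':
        loss = loss.mean() if avg_factor is None else loss.sum() / avg_factor
    elif reduction == 'sum':
        loss = loss.sum()  # step 4: reduce the loss tensor to a scalar
    return loss

# Step 5 then scales the reduced scalar by `loss_weight`:
# final_loss = loss_weight * weight_reduce_loss(elementwise_loss, ...)
```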
## Set sampling method (step 1)

For some loss functions, sampling strategies are needed to avoid imbalance between positive and negative samples.

For example, when using `CrossEntropyLoss` in the RPN head, we need to set `RandomSampler` in `train_cfg`:
```python
train_cfg=dict(
    rpn=dict(
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False)))
```
For some other losses that have a built-in mechanism to balance positive and negative samples, such as Focal Loss, GHMC, and QualityFocalLoss, the sampler is no longer necessary.
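
For reference, such a sampler config can also be instantiated programmatically (a minimal sketch assuming MMDetection 2.x's `build_sampler` helper; the values mirror the snippet above):

```python
from mmdet.core.bbox import build_sampler

sampler_cfg = dict(
    type='RandomSampler',
    num=256,
    pos_fraction=0.5,
    neg_pos_ub=-1,
    add_gt_as_proposals=False)
sampler = build_sampler(sampler_cfg)
# sampler.sample(assign_result, bboxes, gt_bboxes) then draws a balanced
# set of positive and negative samples from the assignment result.
```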
## Tweaking loss

Tweaking a loss is more related to steps 2, 4, and 5, and most modifications can be specified in the config.
Here we take [Focal Loss (FL)](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/focal_loss.py) as an example.
The following code snippets are the constructor and the config of FL, respectively; they are in one-to-one correspondence.
```python
@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
```
```python
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0)
```
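
To see this correspondence in action, the config dict can be turned into a loss module with the `build_loss` helper and called like any `nn.Module` (a minimal sketch assuming MMDetection 2.x; the dummy tensor shapes are illustrative assumptions):

```python
import torch
from mmdet.models import build_loss

loss_cls = build_loss(
    dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=1.0))

cls_score = torch.randn(8, 80)       # 8 predictions over 80 classes
labels = torch.randint(0, 80, (8,))  # integer class targets
print(loss_cls(cls_score, labels))   # a scalar loss tensor
```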
### Tweaking hyper-parameters (step 2)

`gamma` and `alpha` are two hyper-parameters of the Focal Loss. Say we want to change the value of `gamma` to 1.5 and `alpha` to 0.5; we can then specify them in the config as follows:
```python
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=1.5,
    alpha=0.5,
    loss_weight=1.0)
```
### Tweaking the way of reduction (step 4)

The default way of reduction is `mean` for FL. Say we want to change the reduction from `mean` to `sum`; we can specify it in the config as follows:
```python
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=1.0,
    reduction='sum')
```
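
The reduction can also be overridden per call: the forward method of MMDetection losses accepts a `reduction_override` argument (a sketch; the dummy tensors are illustrative):

```python
import torch
from mmdet.models import build_loss

loss_cls = build_loss(dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0))
pred, target = torch.randn(4, 80), torch.randint(0, 80, (4,))

loss_sum = loss_cls(pred, target, reduction_override='sum')    # overrides 'mean'
loss_none = loss_cls(pred, target, reduction_override='none')  # per-element losses
```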
### Tweaking loss weight (step 5)

The loss weight here is a scalar which controls the weight of different losses in multi-task learning, e.g. the classification loss and the regression loss. Say we want to change the loss weight of the classification loss to 0.5; we can specify it in the config as follows:
```python
loss_cls=dict(
    type='FocalLoss',
    use_sigmoid=True,
    gamma=2.0,
    alpha=0.25,
    loss_weight=0.5)
```
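
Since `loss_weight` simply scales the reduced scalar (step 5), halving it halves the contribution of the classification term to the total objective. A numeric sketch with hypothetical values:

```python
import torch

raw_cls_loss = torch.tensor(1.2)   # reduced classification loss (hypothetical)
raw_bbox_loss = torch.tensor(0.8)  # reduced regression loss (hypothetical)

total = 0.5 * raw_cls_loss + 1.0 * raw_bbox_loss  # loss_weight 0.5 vs. 1.0
print(total)  # tensor(1.4000)
```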
## Weighting loss (step 3)

Weighting a loss means re-weighting it element-wise. To be more specific, we multiply the loss tensor with a weight tensor of the same shape. As a result, different entries of the loss can be scaled differently, hence the name element-wise.

The loss weight varies across different models and is highly context-dependent, but overall there are two kinds of loss weights: `label_weights` for the classification loss and `bbox_weights` for the bbox regression loss. You can find them in the `get_targets` method of the corresponding head. Here we take [ATSSHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/atss_head.py#L530) as an example, which inherits from [AnchorHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/anchor_head.py) but overrides its `get_targets` method to yield different `label_weights` and `bbox_weights`.
```python
class ATSSHead(AnchorHead):

    ...

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
```
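
Independent of any particular head, the effect of such weight tensors is easy to see on toy data (a self-contained sketch, not code from `ATSSHead`):

```python
import torch

per_element_loss = torch.tensor([0.5, 1.0, 2.0])
label_weights = torch.tensor([1.0, 0.0, 1.0])  # e.g. ignore the second sample

weighted = per_element_loss * label_weights    # tensor([0.5, 0.0, 2.0])
loss = weighted.sum() / label_weights.sum()    # average over valid entries only
print(loss)  # tensor(1.2500)
```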
