You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

base_semantic_head.py 2.8 kB

2 years ago
Lines 1–86
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. import torch.nn.functional as F
  4. from mmcv.runner import BaseModule, force_fp32
  5. from ..builder import build_loss
  6. from ..utils import interpolate_as
  7. class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
  8. """Base module of Semantic Head.
  9. Args:
  10. num_classes (int): the number of classes.
  11. init_cfg (dict): the initialization config.
  12. loss_seg (dict): the loss of the semantic head.
  13. """
  14. def __init__(self,
  15. num_classes,
  16. init_cfg=None,
  17. loss_seg=dict(
  18. type='CrossEntropyLoss',
  19. ignore_index=255,
  20. loss_weight=1.0)):
  21. super(BaseSemanticHead, self).__init__(init_cfg)
  22. self.loss_seg = build_loss(loss_seg)
  23. self.num_classes = num_classes
  24. @force_fp32(apply_to=('seg_preds', ))
  25. def loss(self, seg_preds, gt_semantic_seg):
  26. """Get the loss of semantic head.
  27. Args:
  28. seg_preds (Tensor): The input logits with the shape (N, C, H, W).
  29. gt_semantic_seg: The ground truth of semantic segmentation with
  30. the shape (N, H, W).
  31. label_bias: The starting number of the semantic label.
  32. Default: 1.
  33. Returns:
  34. dict: the loss of semantic head.
  35. """
  36. if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
  37. seg_preds = interpolate_as(seg_preds, gt_semantic_seg)
  38. seg_preds = seg_preds.permute((0, 2, 3, 1))
  39. loss_seg = self.loss_seg(
  40. seg_preds.reshape(-1, self.num_classes), # => [NxHxW, C]
  41. gt_semantic_seg.reshape(-1).long())
  42. return dict(loss_seg=loss_seg)
  43. @abstractmethod
  44. def forward(self, x):
  45. """Placeholder of forward function.
  46. Returns:
  47. dict[str, Tensor]: A dictionary, including features
  48. and predicted scores. Required keys: 'seg_preds'
  49. and 'feats'.
  50. """
  51. pass
  52. def forward_train(self, x, gt_semantic_seg):
  53. output = self.forward(x)
  54. seg_preds = output['seg_preds']
  55. return self.loss(seg_preds, gt_semantic_seg)
  56. def simple_test(self, x, img_metas, rescale=False):
  57. output = self.forward(x)
  58. seg_preds = output['seg_preds']
  59. seg_preds = F.interpolate(
  60. seg_preds,
  61. size=img_metas[0]['pad_shape'][:2],
  62. mode='bilinear',
  63. align_corners=False)
  64. if rescale:
  65. h, w, _ = img_metas[0]['img_shape']
  66. seg_preds = seg_preds[:, :, :h, :w]
  67. h, w, _ = img_metas[0]['ori_shape']
  68. seg_preds = F.interpolate(
  69. seg_preds, size=(h, w), mode='bilinear', align_corners=False)
  70. return seg_preds

No Description

Contributors (3)