You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

detr.py 2.5 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch
  4. from ..builder import DETECTORS
  5. from .single_stage import SingleStageDetector
  6. @DETECTORS.register_module()
  7. class DETR(SingleStageDetector):
  8. r"""Implementation of `DETR: End-to-End Object Detection with
  9. Transformers <https://arxiv.org/pdf/2005.12872>`_"""
  10. def __init__(self,
  11. backbone,
  12. bbox_head,
  13. train_cfg=None,
  14. test_cfg=None,
  15. pretrained=None,
  16. init_cfg=None):
  17. super(DETR, self).__init__(backbone, None, bbox_head, train_cfg,
  18. test_cfg, pretrained, init_cfg)
  19. # over-write `forward_dummy` because:
  20. # the forward of bbox_head requires img_metas
  21. def forward_dummy(self, img):
  22. """Used for computing network flops.
  23. See `mmdetection/tools/analysis_tools/get_flops.py`
  24. """
  25. warnings.warn('Warning! MultiheadAttention in DETR does not '
  26. 'support flops computation! Do not use the '
  27. 'results in your papers!')
  28. batch_size, _, height, width = img.shape
  29. dummy_img_metas = [
  30. dict(
  31. batch_input_shape=(height, width),
  32. img_shape=(height, width, 3)) for _ in range(batch_size)
  33. ]
  34. x = self.extract_feat(img)
  35. outs = self.bbox_head(x, dummy_img_metas)
  36. return outs
  37. # over-write `onnx_export` because:
  38. # (1) the forward of bbox_head requires img_metas
  39. # (2) the different behavior (e.g. construction of `masks`) between
  40. # torch and ONNX model, during the forward of bbox_head
  41. def onnx_export(self, img, img_metas):
  42. """Test function for exporting to ONNX, without test time augmentation.
  43. Args:
  44. img (torch.Tensor): input images.
  45. img_metas (list[dict]): List of image information.
  46. Returns:
  47. tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
  48. and class labels of shape [N, num_det].
  49. """
  50. x = self.extract_feat(img)
  51. # forward of this head requires img_metas
  52. outs = self.bbox_head.forward_onnx(x, img_metas)
  53. # get shape as tensor
  54. img_shape = torch._shape_as_tensor(img)[2:]
  55. img_metas[0]['img_shape_for_onnx'] = img_shape
  56. det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
  57. return det_bboxes, det_labels

No Description

Contributors (2)