
yolo.py

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
import torch

from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # All components are built from their configs by the generic
        # single-stage detector; YOLOV3 adds no extra modules of its own.
        super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
                                     test_cfg, pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Test function for exporting to ONNX, without test time augmentation.

        Args:
            img (torch.Tensor): Input images.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        # Backbone features, passed through the neck if one is configured.
        x = self.extract_feat(img)
        # Raw multi-level outputs of the YOLOv3 head.
        outs = self.bbox_head.forward(x)
        # Get the (H, W) shape as a tensor rather than Python ints, so it is
        # recorded in the traced graph instead of being baked in as constants.
        img_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = img_shape
        det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
        return det_bboxes, det_labels
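The `onnx_export` docstring fixes the output contract: detections of shape [N, num_det, 5] and class labels of shape [N, num_det]. As a minimal sketch, assuming the five detection columns are (x1, y1, x2, y2, score) and using plain nested lists in place of tensors, the two outputs can be paired per image as follows. `split_per_image` and `score_thr` are hypothetical names for illustration, not part of this file or of MMDetection.

```python
from typing import List, Tuple

# Assumed column layout of one detection row: (x1, y1, x2, y2, score).
Box = Tuple[float, float, float, float, float]


def split_per_image(dets: List[List[Box]], labels: List[List[int]],
                    score_thr: float = 0.0) -> List[List[Tuple[Box, int]]]:
    """Pair each detection with its label per image, keeping score > score_thr.

    dets is shaped [N, num_det, 5] and labels [N, num_det] (as nested lists).
    """
    if len(dets) != len(labels):
        raise ValueError('dets and labels have different batch sizes')
    results = []
    for img_dets, img_labels in zip(dets, labels):
        if len(img_dets) != len(img_labels):
            raise ValueError('dets and labels have different num_det')
        # Index 4 is the score column under the assumed layout.
        kept = [(box, label) for box, label in zip(img_dets, img_labels)
                if box[4] > score_thr]
        results.append(kept)
    return results


# Usage: a batch of N=2 images with num_det=2 rows each.
dets = [[(0.0, 0.0, 10.0, 10.0, 0.9), (5.0, 5.0, 8.0, 8.0, 0.05)],
        [(1.0, 1.0, 2.0, 2.0, 0.7), (0.0, 0.0, 0.0, 0.0, 0.0)]]
labels = [[3, 1], [0, 0]]
per_image = split_per_image(dets, labels, score_thr=0.1)
print(per_image)
# [[((0.0, 0.0, 10.0, 10.0, 0.9), 3)], [((1.0, 1.0, 2.0, 2.0, 0.7), 0)]]
```

Keeping boxes and labels as two parallel outputs, as `onnx_export` does, means consumers must zip them back together by index; the helper makes that pairing explicit.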

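`@DETECTORS.register_module()` registers `YOLOV3` under its class name, so a config dict with `type='YOLOV3'` can be turned into an instance without importing the class directly. The sketch below is a hypothetical, stripped-down registry that only illustrates the pattern; it is not mmcv's `Registry`, and `ToyYOLO` plus the string placeholders for `backbone`, `neck` and `bbox_head` stand in for the real class and its config sub-dicts.

```python
from typing import Any, Callable, Dict, Type


class Registry:
    """Minimal name -> class registry (hypothetical, not mmcv's Registry)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self) -> Callable[[Type], Type]:
        # Returns a decorator, hence the call syntax @REGISTRY.register_module()
        def _register(cls: Type) -> Type:
            key = cls.__name__
            if key in self._modules:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._modules[key] = cls
            return cls  # the class itself is returned unchanged
        return _register

    def build(self, cfg: Dict[str, Any]) -> Any:
        # Copy so that popping 'type' does not mutate the caller's config.
        args = dict(cfg)
        cls = self._modules[args.pop('type')]
        return cls(**args)


DETECTORS = Registry('detector')


@DETECTORS.register_module()
class ToyYOLO:
    def __init__(self, backbone, neck, bbox_head, train_cfg=None,
                 test_cfg=None):
        self.backbone = backbone
        self.neck = neck
        self.bbox_head = bbox_head
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


# Placeholder strings stand in for nested backbone/neck/head configs.
cfg = dict(type='ToyYOLO', backbone='darknet53', neck='yolov3_neck',
           bbox_head='yolov3_head')
model = DETECTORS.build(cfg)
print(type(model).__name__, model.backbone)  # ToyYOLO darknet53
```

Because the decorator returns the class unchanged, registration has no effect on how the class behaves; it only adds a lookup entry, which is why `YOLOV3` itself can stay a thin subclass.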
