You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

postprocessing.py 3.1 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. from torch.nn import functional as F
  3. from detectron2.layers import paste_masks_in_image
  4. from detectron2.structures import Instances
  5. def detector_postprocess(results, output_height, output_width, mask_threshold=0.5):
  6. """
  7. Resize the output instances.
  8. The input images are often resized when entering an object detector.
  9. As a result, we often need the outputs of the detector in a different
  10. resolution from its inputs.
  11. This function will resize the raw outputs of an R-CNN detector
  12. to produce outputs according to the desired output resolution.
  13. Args:
  14. results (Instances): the raw outputs from the detector.
  15. `results.image_size` contains the input image resolution the detector sees.
  16. This object might be modified in-place.
  17. output_height, output_width: the desired output resolution.
  18. Returns:
  19. Instances: the resized output from the model, based on the output resolution
  20. """
  21. scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
  22. results = Instances((output_height, output_width), **results.get_fields())
  23. if results.has("pred_boxes"):
  24. output_boxes = results.pred_boxes
  25. elif results.has("proposal_boxes"):
  26. output_boxes = results.proposal_boxes
  27. output_boxes.scale(scale_x, scale_y)
  28. output_boxes.clip(results.image_size)
  29. results = results[output_boxes.nonempty()]
  30. if results.has("pred_masks"):
  31. results.pred_masks = paste_masks_in_image(
  32. results.pred_masks[:, 0, :, :], # N, 1, M, M
  33. results.pred_boxes,
  34. results.image_size,
  35. threshold=mask_threshold,
  36. )
  37. if results.has("pred_keypoints"):
  38. results.pred_keypoints[:, :, 0] *= scale_x
  39. results.pred_keypoints[:, :, 1] *= scale_y
  40. return results
  41. def sem_seg_postprocess(result, img_size, output_height, output_width):
  42. """
  43. Return semantic segmentation predictions in the original resolution.
  44. The input images are often resized when entering semantic segmentor. Moreover, in same
  45. cases, they also padded inside segmentor to be divisible by maximum network stride.
  46. As a result, we often need the predictions of the segmentor in a different
  47. resolution from its inputs.
  48. Args:
  49. result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W),
  50. where C is the number of classes, and H, W are the height and width of the prediction.
  51. img_size (tuple): image size that segmentor is taking as input.
  52. output_height, output_width: the desired output resolution.
  53. Returns:
  54. semantic segmentation prediction (Tensor): A tensor of the shape
  55. (C, output_height, output_width) that contains per-pixel soft predictions.
  56. """
  57. result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1)
  58. result = F.interpolate(
  59. result, size=(output_height, output_width), mode="bilinear", align_corners=False
  60. )[0]
  61. return result

No Description