|
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import copy
- import numpy as np
- from contextlib import contextmanager
- from itertools import count
- import torch
- from torch import nn
- from torch.nn.parallel import DistributedDataParallel
-
- from detectron2.data.detection_utils import read_image
- from detectron2.data.transforms import ResizeShortestEdge
- from detectron2.structures import Instances
-
- from .meta_arch import GeneralizedRCNN
- from .postprocessing import detector_postprocess
- from .roi_heads.fast_rcnn import fast_rcnn_inference_single_image
-
- __all__ = ["DatasetMapperTTA", "GeneralizedRCNNWithTTA"]
-
-
class DatasetMapperTTA:
    """
    Implement test-time augmentation for detection data.

    It is a callable which takes a dataset dict from a detection dataset,
    and returns a list of dataset dicts where the images are augmented from
    the input image by the transformations defined in the config:
    shortest-edge resizing to every configured size, each optionally followed
    by a horizontally flipped copy.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): config from which ``TEST.AUG.MIN_SIZES``,
                ``TEST.AUG.MAX_SIZE``, ``TEST.AUG.FLIP`` and ``INPUT.FORMAT`` are read.
        """
        self.min_sizes = cfg.TEST.AUG.MIN_SIZES
        self.max_size = cfg.TEST.AUG.MAX_SIZE
        self.flip = cfg.TEST.AUG.FLIP
        self.image_format = cfg.INPUT.FORMAT

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): a detection dataset dict. If it has an "image"
                entry, that tensor is used (it is permuted with ``(1, 2, 0)`` and
                cast to uint8); otherwise the image is read from
                ``dataset_dict["file_name"]`` using the configured format.

        Returns:
            list[dict]:
                a list of dataset dicts, which contain augmented version of the input image.
                The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``.
                Each dict is an independent deep copy of the input's other fields, with
                "image" replaced by the augmented float32 tensor and a boolean
                "horiz_flip" entry added.
        """
        if "image" in dataset_dict:
            numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy().astype("uint8")
        else:
            numpy_image = read_image(dataset_dict["file_name"], self.image_format)

        # "image" is overwritten in every returned dict, so leave it out of the
        # per-augmentation deep copies instead of duplicating the full input
        # tensor only to throw it away.
        fields_without_image = {k: v for k, v in dataset_dict.items() if k != "image"}

        ret = []
        for min_size in self.min_sizes:
            # Work on a fresh copy so a transform can never touch the source image.
            image = np.copy(numpy_image)
            tfm = ResizeShortestEdge(min_size, self.max_size).get_transform(image)
            resized = tfm.apply_image(image)
            # Move the channel axis first and convert to a float32 tensor.
            resized = torch.as_tensor(resized.transpose(2, 0, 1).astype("float32"))

            dic = copy.deepcopy(fields_without_image)
            dic["horiz_flip"] = False
            dic["image"] = resized
            ret.append(dic)

            if self.flip:
                dic = copy.deepcopy(fields_without_image)
                dic["horiz_flip"] = True
                # dim 2 is the last axis of the channel-first tensor built above.
                dic["image"] = torch.flip(resized, dims=[2])
                ret.append(dic)
        return ret
-
-
class GeneralizedRCNNWithTTA(nn.Module):
    """
    A GeneralizedRCNN with test-time augmentation enabled.
    Its :meth:`__call__` method has the same interface as :meth:`GeneralizedRCNN.forward`.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode):
            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
                A ``DistributedDataParallel`` wrapper is unwrapped first.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        super().__init__()
        if isinstance(model, DistributedDataParallel):
            model = model.module
        assert isinstance(
            model, GeneralizedRCNN
        ), "TTA is only supported on GeneralizedRCNN. Got a model of type {}".format(type(model))
        self.cfg = cfg.clone()
        assert not self.cfg.MODEL.KEYPOINT_ON, "TTA for keypoint is not supported yet"
        assert (
            not self.cfg.MODEL.LOAD_PROPOSALS
        ), "TTA for pre-computed proposals is not supported yet"

        self.model = model

        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size

    @contextmanager
    def _turn_off_roi_head(self, attr):
        """
        Open a context where one head in `model.roi_heads` is temporarily turned off.

        The previous value is restored when the context exits, including when the
        body raises, so a failed inference cannot leave the head disabled.

        Args:
            attr (str): the attribute in `model.roi_heads` which can be used
                to turn off a specific head, e.g., "mask_on", "keypoint_on".
        """
        roi_heads = self.model.roi_heads
        # The head may not be implemented in certain ROIHeads; a missing (or None)
        # attribute means there is nothing to switch off.
        old = getattr(roi_heads, attr, None)
        if old is None:
            yield
            return

        setattr(roi_heads, attr, False)
        try:
            yield
        finally:
            setattr(roi_heads, attr, old)

    def _batch_inference(self, batched_inputs, detected_instances=None, do_postprocess=True):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.

        Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`
        """
        if detected_instances is None:
            detected_instances = [None] * len(batched_inputs)

        outputs = []
        inputs, instances = [], []
        last_idx = len(batched_inputs) - 1
        for idx, (one_input, instance) in enumerate(zip(batched_inputs, detected_instances)):
            inputs.append(one_input)
            instances.append(instance)
            # Flush when the chunk is full, or when the final input has been queued.
            if len(inputs) == self.batch_size or idx == last_idx:
                outputs.extend(
                    self.model.inference(
                        inputs,
                        # A chunk of None placeholders means "no given detections".
                        instances if instances[0] is not None else None,
                        do_postprocess=do_postprocess,
                    )
                )
                inputs, instances = [], []
        return outputs

    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`GeneralizedRCNN.forward`
        """
        return [self._inference_one_image(x) for x in batched_inputs]

    def _inference_one_image(self, input):
        """
        Run augmented inference on a single image and merge the results.

        Args:
            input (dict): one dataset dict

        Returns:
            dict: one output dict, with the merged predictions under "instances"
        """
        augmented_inputs = self.tta_mapper(input)

        # "horiz_flip" is bookkeeping for this method only; remove it from the
        # dicts before they are handed to the model.
        do_hflip = [k.pop("horiz_flip", False) for k in augmented_inputs]
        heights = [k["height"] for k in augmented_inputs]
        widths = [k["width"] for k in augmented_inputs]
        assert (
            len(set(heights)) == 1 and len(set(widths)) == 1
        ), "Augmented version of the inputs should have the same original resolution!"
        height = heights[0]
        width = widths[0]

        # 1. Detect boxes from all augmented versions
        # 1.1: forward with all augmented images
        with self._turn_off_roi_head("mask_on"), self._turn_off_roi_head("keypoint_on"):
            # temporarily disable mask/keypoint head
            outputs = self._batch_inference(augmented_inputs, do_postprocess=False)
        # 1.2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        for idx, output in enumerate(outputs):
            rescaled_output = detector_postprocess(output, height, width)
            pred_boxes = rescaled_output.pred_boxes.tensor
            if do_hflip[idx]:
                # Mirror x-coordinates back; x1/x2 swap so that x1 <= x2 still holds.
                pred_boxes[:, [0, 2]] = width - pred_boxes[:, [2, 0]]
            all_boxes.append(pred_boxes)
            all_scores.append(rescaled_output.scores)
            all_classes.append(rescaled_output.pred_classes)
        all_boxes = torch.cat(all_boxes, dim=0).cpu()
        num_boxes = len(all_boxes)

        # 1.3: select from the union of all results
        num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        # Row i holds box i's score in the column of its predicted class and zero
        # elsewhere. One indexed assignment replaces a per-box Python loop; the
        # casts keep the device/dtype conversion the element-wise copy performed.
        flat_scores = torch.cat(all_scores, dim=0).to(
            device=all_scores_2d.device, dtype=all_scores_2d.dtype
        )
        flat_classes = torch.cat(all_classes, dim=0).to(
            device=all_scores_2d.device, dtype=torch.long
        )
        all_scores_2d[torch.arange(num_boxes), flat_classes] = flat_scores

        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            (height, width),
            # NOTE(review): tiny positive value, presumably the score threshold that
            # drops the zero-filled entries -- confirm against the callee's signature.
            1e-8,
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )

        if not self.cfg.MODEL.MASK_ON:
            return {"instances": merged_instances}

        # 2. Use the detected boxes to obtain masks
        # 2.1: rescale the detected boxes
        augmented_instances = []
        for idx, aug_input in enumerate(augmented_inputs):
            actual_height, actual_width = aug_input["image"].shape[1:3]
            scale_x = actual_width * 1.0 / width
            scale_y = actual_height * 1.0 / height
            pred_boxes = merged_instances.pred_boxes.clone()
            pred_boxes.tensor[:, 0::2] *= scale_x
            pred_boxes.tensor[:, 1::2] *= scale_y
            if do_hflip[idx]:
                pred_boxes.tensor[:, [0, 2]] = actual_width - pred_boxes.tensor[:, [2, 0]]

            aug_instances = Instances(
                image_size=(actual_height, actual_width),
                pred_boxes=pred_boxes,
                pred_classes=merged_instances.pred_classes,
                scores=merged_instances.scores,
            )
            augmented_instances.append(aug_instances)
        # 2.2: run forward on the detected boxes
        outputs = self._batch_inference(augmented_inputs, augmented_instances, do_postprocess=False)
        for idx, output in enumerate(outputs):
            if do_hflip[idx]:
                # Undo the horizontal flip on the predicted masks (last axis).
                output.pred_masks = output.pred_masks.flip(dims=[3])
        # 2.3: average the predictions
        all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0)
        avg_pred_masks = torch.mean(all_pred_masks, dim=0)
        # Reuse the first augmented output as the carrier for the averaged masks.
        output = outputs[0]
        output.pred_masks = avg_pred_masks
        output = detector_postprocess(output, height, width)
        return {"instances": output}
|