You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_mapper.py 6.1 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import copy
  3. import logging
  4. import numpy as np
  5. import torch
  6. from fvcore.common.file_io import PathManager
  7. from PIL import Image
  8. from . import detection_utils as utils
  9. from . import transforms as T
  10. """
  11. This file contains the default mapping that's applied to "dataset dicts".
  12. """
  13. __all__ = ["DatasetMapper"]
  class DatasetMapper:
      """
      A callable which takes a dataset dict in Detectron2 Dataset format,
      and maps it into a format used by the model.

      This is the default callable to be used to map your dataset dict into training data.
      You may need to follow it to implement your own one for customized logic.

      The callable currently does the following:

      1. Read the image from "file_name"
      2. Applies cropping/geometric transforms to the image and annotations
      3. Prepare data and annotations to Tensor and :class:`Instances`
      """

      def __init__(self, cfg, is_train=True):
          """
          Args:
              cfg: config object; this constructor reads ``INPUT.CROP.*``,
                  ``INPUT.FORMAT``, ``INPUT.MASK_FORMAT``, ``MODEL.MASK_ON``,
                  ``MODEL.KEYPOINT_ON``, ``MODEL.LOAD_PROPOSALS`` and, when
                  proposals are loaded, ``MODEL.PROPOSAL_GENERATOR.MIN_SIZE`` and
                  ``DATASETS.PRECOMPUTED_PROPOSAL_TOPK_{TRAIN,TEST}``.
              is_train (bool): whether the mapper is used for training. Cropping
                  and keypoint flip indices are only set up when this is True.
          """
          if cfg.INPUT.CROP.ENABLED and is_train:
              self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
              # Lazy %-formatting: the message is only built if the record is emitted.
              logging.getLogger(__name__).info("CropGen used in training: %s", self.crop_gen)
          else:
              self.crop_gen = None

          self.tfm_gens = utils.build_transform_gen(cfg, is_train)

          # fmt: off
          self.img_format     = cfg.INPUT.FORMAT
          self.mask_on        = cfg.MODEL.MASK_ON
          self.mask_format    = cfg.INPUT.MASK_FORMAT
          self.keypoint_on    = cfg.MODEL.KEYPOINT_ON
          self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
          # fmt: on

          if self.keypoint_on and is_train:
              # Flip only makes sense in training
              self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
          else:
              self.keypoint_hflip_indices = None

          # These two attributes only exist when pre-computed proposals are loaded;
          # __call__ reads them under the same `self.load_proposals` guard.
          if self.load_proposals:
              self.min_box_side_len = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
              self.proposal_topk = (
                  cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                  if is_train
                  else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
              )
          self.is_train = is_train

      def __call__(self, dataset_dict):
          """
          Args:
              dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

          Returns:
              dict: a format that builtin models in detectron2 accept. It is a deep
              copy of the input with an "image" tensor added. In training,
              "annotations" is replaced by "instances" and "sem_seg_file_name" by
              "sem_seg"; in inference both keys are simply dropped.
          """
          dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
          # USER: Write your own image loading if it's not from a file
          image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
          utils.check_image_size(dataset_dict, image)

          # Crop around an instance only if there actually are instances in the image.
          # An empty "annotations" list must not reach np.random.choice, which raises
          # ValueError on an empty sequence; such images get the plain random crop.
          # USER: Remove if you don't use cropping
          annotations = dataset_dict.get("annotations")
          if self.crop_gen and annotations:
              crop_tfm = utils.gen_crop_transform_with_instance(
                  self.crop_gen.get_crop_size(image.shape[:2]),
                  image.shape[:2],
                  np.random.choice(annotations),
              )
              image = crop_tfm.apply_image(image)
              image, transforms = T.apply_transform_gens(self.tfm_gens, image)
              # Prepend the crop so annotations go through the same sequence as the image.
              transforms = crop_tfm + transforms
          else:
              image, transforms = T.apply_transform_gens(
                  ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
              )

          image_shape = image.shape[:2]  # h, w

          # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
          # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
          # Therefore it's important to use torch.Tensor.
          dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
          # Can use uint8 if it turns out to be slow some day

          # USER: Remove if you don't use pre-computed proposals.
          if self.load_proposals:
              utils.transform_proposals(
                  dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
              )

          if not self.is_train:
              # Ground truth is not passed on in inference mode.
              dataset_dict.pop("annotations", None)
              dataset_dict.pop("sem_seg_file_name", None)
              return dataset_dict

          if "annotations" in dataset_dict:
              # USER: Modify this if you want to keep them for some reason.
              for anno in dataset_dict["annotations"]:
                  if not self.mask_on:
                      anno.pop("segmentation", None)
                  if not self.keypoint_on:
                      anno.pop("keypoints", None)

              # USER: Implement additional transformations if you have other types of data
              # Annotations flagged "iscrowd" are skipped.
              annos = [
                  utils.transform_instance_annotations(
                      obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                  )
                  for obj in dataset_dict.pop("annotations")
                  if obj.get("iscrowd", 0) == 0
              ]
              instances = utils.annotations_to_instances(
                  annos, image_shape, mask_format=self.mask_format
              )
              # Create a tight bounding box from masks, useful when image is cropped
              if self.crop_gen and instances.has("gt_masks"):
                  instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
              dataset_dict["instances"] = utils.filter_empty_instances(instances)

          # USER: Remove if you don't do semantic/panoptic segmentation.
          if "sem_seg_file_name" in dataset_dict:
              with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
                  sem_seg_gt = Image.open(f)
                  # Convert to an array while the file handle is still open.
                  sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
              sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
              sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
              dataset_dict["sem_seg"] = sem_seg_gt
          return dataset_dict

No Description