
detection_utils.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Common data processing utilities that are used in a
typical object detection data pipeline.
"""
import logging
import numpy as np
import torch
from fvcore.common.file_io import PathManager
from PIL import Image, ImageOps

from detectron2.structures import (
    BitMasks,
    Boxes,
    BoxMode,
    Instances,
    Keypoints,
    PolygonMasks,
    RotatedBoxes,
)

from . import transforms as T
from .catalog import MetadataCatalog


class SizeMismatchError(ValueError):
    """
    When the loaded image has a different width/height than specified in the annotation.
    """


def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR"

    Returns:
        image (np.ndarray): an HWC image
    """
    with PathManager.open(file_name, "rb") as f:
        image = Image.open(f)

        # capture and ignore this bug: https://github.com/python-pillow/Pillow/issues/3973
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        if format is not None:
            # PIL only supports RGB, so convert to RGB and flip channels over below
            conversion_format = format
            if format == "BGR":
                conversion_format = "RGB"
            image = image.convert(conversion_format)
        image = np.asarray(image)
        if format == "BGR":
            # flip channels if needed
            image = image[:, :, ::-1]
        # PIL squeezes out the channel dimension for "L", so make it HWC
        if format == "L":
            image = np.expand_dims(image, -1)
        return image


def check_image_size(dataset_dict, image):
    """
    Raise an error if the image does not match the size specified in the dict.
    """
    if "width" in dataset_dict or "height" in dataset_dict:
        image_wh = (image.shape[1], image.shape[0])
        expected_wh = (dataset_dict["width"], dataset_dict["height"])
        if not image_wh == expected_wh:
            raise SizeMismatchError(
                "Mismatched (W,H){}, got {}, expect {}".format(
                    " for image " + dataset_dict["file_name"]
                    if "file_name" in dataset_dict
                    else "",
                    image_wh,
                    expected_wh,
                )
            )

    # To ensure bbox always remap to original image size
    if "width" not in dataset_dict:
        dataset_dict["width"] = image.shape[1]
    if "height" not in dataset_dict:
        dataset_dict["height"] = image.shape[0]
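

# A minimal usage sketch (hypothetical file path): read an image in the "BGR" format
# expected by most builtin models, then validate it against the dataset dict.
# check_image_size() also fills in "width"/"height" when they are missing.
def _example_read_and_check_image():
    dataset_dict = {"file_name": "datasets/coco/val2017/000000000139.jpg"}  # hypothetical path
    image = read_image(dataset_dict["file_name"], format="BGR")  # HWC uint8 array, BGR channels
    check_image_size(dataset_dict, image)
    return image, dataset_dict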


def transform_proposals(dataset_dict, image_shape, transforms, min_box_side_len, proposal_topk):
    """
    Apply transformations to the proposals in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly
            contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        min_box_side_len (int): keep proposals with at least this size
        proposal_topk (int): only keep top-K scoring proposals

    The input dict is modified in-place, with abovementioned keys removed. A new
    key "proposals" will be added. Its value is an `Instances`
    object which contains the transformed proposals in its field
    "proposal_boxes" and "objectness_logits".
    """
    if "proposal_boxes" in dataset_dict:
        # Transform proposal boxes
        boxes = transforms.apply_box(
            BoxMode.convert(
                dataset_dict.pop("proposal_boxes"),
                dataset_dict.pop("proposal_bbox_mode"),
                BoxMode.XYXY_ABS,
            )
        )
        boxes = Boxes(boxes)
        objectness_logits = torch.as_tensor(
            dataset_dict.pop("proposal_objectness_logits").astype("float32")
        )

        boxes.clip(image_shape)
        keep = boxes.nonempty(threshold=min_box_side_len)
        boxes = boxes[keep]
        objectness_logits = objectness_logits[keep]

        proposals = Instances(image_shape)
        proposals.proposal_boxes = boxes[:proposal_topk]
        proposals.objectness_logits = objectness_logits[:proposal_topk]
        dataset_dict["proposals"] = proposals
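

# A minimal usage sketch (all values are made up): attach precomputed proposals to a
# dataset dict, apply the same augmentations to the image and the proposal boxes, and
# read the resulting Instances back from dataset_dict["proposals"].
def _example_transform_proposals():
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical blank image, (H, W) = (480, 640)
    dataset_dict = {
        "proposal_boxes": np.array([[10.0, 10.0, 100.0, 120.0]]),  # XYXY_ABS
        "proposal_objectness_logits": np.array([0.9]),
        "proposal_bbox_mode": BoxMode.XYXY_ABS,
    }
    # apply_transform_gens returns the augmented image and the TransformList that produced it
    image, transforms = T.apply_transform_gens([T.ResizeShortestEdge(400, 600), T.RandomFlip()], image)
    transform_proposals(
        dataset_dict, image.shape[:2], transforms, min_box_side_len=0, proposal_topk=1000
    )
    return dataset_dict["proposals"]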


def transform_instance_annotations(
    annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
    """
    Apply transforms to box, segmentation and keypoints of annotations of a single instance.

    It will use `transforms.apply_box` for the box, and
    `transforms.apply_coords` for segmentation polygons & keypoints.
    If you need anything more specially designed for each data structure,
    you'll need to implement your own version of this function or the transforms.

    Args:
        annotation (dict): dict of instance annotations for a single instance.
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.

    Returns:
        dict:
            the same input dict with fields "bbox", "segmentation", "keypoints"
            transformed according to `transforms`.
            The "bbox_mode" field will be set to XYXY_ABS.
    """
    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
    # Note that bbox is 1d (per-instance bounding box)
    annotation["bbox"] = transforms.apply_box([bbox])[0]
    annotation["bbox_mode"] = BoxMode.XYXY_ABS

    if "segmentation" in annotation:
        # each instance contains 1 or more polygons
        polygons = [np.asarray(p).reshape(-1, 2) for p in annotation["segmentation"]]
        annotation["segmentation"] = [p.reshape(-1) for p in transforms.apply_polygons(polygons)]

    if "keypoints" in annotation:
        keypoints = transform_keypoint_annotations(
            annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
        )
        annotation["keypoints"] = keypoints

    return annotation
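

# A minimal usage sketch (hypothetical annotation values): transform one COCO-style
# instance annotation with the same TransformList that was applied to its image.
def _example_transform_instance_annotations():
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical image
    anno = {
        "bbox": [10.0, 20.0, 50.0, 80.0],
        "bbox_mode": BoxMode.XYWH_ABS,
        "segmentation": [[10.0, 20.0, 60.0, 20.0, 60.0, 100.0, 10.0, 100.0]],
        "category_id": 0,
    }
    image, transforms = T.apply_transform_gens([T.ResizeShortestEdge(400, 600), T.RandomFlip()], image)
    # "bbox" comes back in XYXY_ABS, resized/flipped; the polygon is transformed the same way
    return transform_instance_annotations(anno, transforms, image.shape[:2])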


def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
    """
    Transform keypoint annotations of an image.

    Args:
        keypoints (list[float]): Nx3 float in Detectron2 Dataset format.
        transforms (TransformList):
        image_size (tuple): the height, width of the transformed image
        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
    """
    # (N*3,) -> (N, 3)
    keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
    keypoints[:, :2] = transforms.apply_coords(keypoints[:, :2])

    # This assumes that HFlipTransform is the only transform that does horizontal flipping
    do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1

    # Alternative way: check whether probe points were horizontally flipped.
    # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
    # probe_aug = transforms.apply_coords(probe.copy())
    # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0])  # noqa

    # If flipped, swap each keypoint with its opposite-handed equivalent
    if do_hflip:
        assert keypoint_hflip_indices is not None
        keypoints = keypoints[keypoint_hflip_indices, :]

    # Maintain COCO convention that if visibility == 0, then x, y = 0
    # TODO may need to reset visibility for cropped keypoints,
    # but it does not matter for our existing algorithms
    keypoints[keypoints[:, 2] == 0] = 0
    return keypoints


def annotations_to_instances(annos, image_size, mask_format="polygon"):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            It will contain fields "gt_boxes", "gt_classes",
            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = Boxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    if len(annos) and "segmentation" in annos[0]:
        polygons = [obj["segmentation"] for obj in annos]
        if mask_format == "polygon":
            masks = PolygonMasks(polygons)
        else:
            assert mask_format == "bitmask", mask_format
            masks = BitMasks.from_polygon_masks(polygons, *image_size)
        target.gt_masks = masks

    if len(annos) and "keypoints" in annos[0]:
        kpts = [obj.get("keypoints", []) for obj in annos]
        target.gt_keypoints = Keypoints(kpts)

    return target
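

# A minimal usage sketch (hypothetical annotation values): convert already-transformed
# annotations into the Instances format consumed by builtin models.
def _example_annotations_to_instances():
    annos = [
        {
            "bbox": [10.0, 20.0, 60.0, 100.0],
            "bbox_mode": BoxMode.XYXY_ABS,
            "segmentation": [[10.0, 20.0, 60.0, 20.0, 60.0, 100.0, 10.0, 100.0]],
            "category_id": 3,
        }
    ]
    # The result carries gt_boxes, gt_classes and gt_masks fields
    return annotations_to_instances(annos, image_size=(480, 640), mask_format="polygon")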


def annotations_to_instances_rotated(annos, image_size):
    """
    Create an :class:`Instances` object used by the models,
    from instance annotations in the dataset dict.
    Compared to `annotations_to_instances`, this function is for rotated boxes only.

    Args:
        annos (list[dict]): a list of instance annotations in one image, each
            element for one instance.
        image_size (tuple): height, width

    Returns:
        Instances:
            Containing fields "gt_boxes", "gt_classes",
            if they can be obtained from `annos`.
            This is the format that builtin models expect.
    """
    boxes = [obj["bbox"] for obj in annos]
    target = Instances(image_size)
    boxes = target.gt_boxes = RotatedBoxes(boxes)
    boxes.clip(image_size)

    classes = [obj["category_id"] for obj in annos]
    classes = torch.tensor(classes, dtype=torch.int64)
    target.gt_classes = classes

    return target


def filter_empty_instances(instances, by_box=True, by_mask=True):
    """
    Filter out empty instances in an `Instances` object.

    Args:
        instances (Instances):
        by_box (bool): whether to filter out instances with empty boxes
        by_mask (bool): whether to filter out instances with empty masks

    Returns:
        Instances: the filtered instances.
    """
    assert by_box or by_mask
    r = []
    if by_box:
        r.append(instances.gt_boxes.nonempty())
    if instances.has("gt_masks") and by_mask:
        r.append(instances.gt_masks.nonempty())

    # TODO: can also filter visible keypoints

    if not r:
        return instances
    m = r[0]
    for x in r[1:]:
        m = m & x
    return instances[m]
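

# A minimal usage sketch (hypothetical values): after augmentation some ground-truth
# boxes may collapse to zero area; drop those instances before feeding the model.
def _example_filter_empty_instances():
    annos = [
        {"bbox": [10.0, 20.0, 60.0, 100.0], "bbox_mode": BoxMode.XYXY_ABS, "category_id": 0},
        {"bbox": [30.0, 30.0, 30.0, 30.0], "bbox_mode": BoxMode.XYXY_ABS, "category_id": 1},  # zero-area box
    ]
    instances = annotations_to_instances(annos, image_size=(480, 640))
    return filter_empty_instances(instances)  # keeps only the first instance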


def create_keypoint_hflip_indices(dataset_names):
    """
    Args:
        dataset_names (list[str]): list of dataset names

    Returns:
        ndarray[int]: a vector of size=#keypoints, storing the
        horizontally-flipped keypoint indices.
    """
    check_metadata_consistency("keypoint_names", dataset_names)
    check_metadata_consistency("keypoint_flip_map", dataset_names)

    meta = MetadataCatalog.get(dataset_names[0])
    names = meta.keypoint_names
    # TODO flip -> hflip
    flip_map = dict(meta.keypoint_flip_map)
    flip_map.update({v: k for k, v in flip_map.items()})
    flipped_names = [i if i not in flip_map else flip_map[i] for i in names]
    flip_indices = [names.index(i) for i in flipped_names]
    return np.asarray(flip_indices)
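

# A minimal usage sketch, assuming a registered dataset whose metadata defines
# "keypoint_names" and "keypoint_flip_map" (e.g. the builtin COCO person-keypoint
# datasets). The dataset name below is an assumption, not a requirement.
def _example_create_keypoint_hflip_indices():
    hflip_indices = create_keypoint_hflip_indices(["keypoints_coco_2017_train"])
    # For COCO person keypoints this swaps e.g. "left_eye" <-> "right_eye"; the array
    # is passed to transform_instance_annotations() as keypoint_hflip_indices.
    return hflip_indices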


def gen_crop_transform_with_instance(crop_size, image_size, instance):
    """
    Generate a CropTransform so that the cropping region contains
    the center of the given instance.

    Args:
        crop_size (tuple): h, w in pixels
        image_size (tuple): h, w
        instance (dict): an annotation dict of one instance, in Detectron2's
            dataset format.
    """
    crop_size = np.asarray(crop_size, dtype=np.int32)
    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5

    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))

    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
    return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
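

# A minimal usage sketch (hypothetical sizes and coordinates): build a crop that is
# guaranteed to contain the center of a chosen instance, then apply it to the image.
def _example_gen_crop_transform_with_instance():
    image = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical image, (H, W) = (480, 640)
    instance = {"bbox": [100.0, 100.0, 200.0, 220.0], "bbox_mode": BoxMode.XYXY_ABS}
    crop_tfm = gen_crop_transform_with_instance(
        crop_size=(300, 300), image_size=image.shape[:2], instance=instance
    )
    return crop_tfm.apply_image(image)  # a 300x300 crop containing the instance center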


def check_metadata_consistency(key, dataset_names):
    """
    Check that the datasets have consistent metadata.

    Args:
        key (str): a metadata key
        dataset_names (list[str]): a list of dataset names

    Raises:
        AttributeError: if the key does not exist in the metadata
        ValueError: if the given datasets do not have the same metadata values defined by key
    """
    if len(dataset_names) == 0:
        return
    logger = logging.getLogger(__name__)
    entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
    for idx, entry in enumerate(entries_per_dataset):
        if entry != entries_per_dataset[0]:
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
            )
            logger.error(
                "Metadata '{}' for dataset '{}' is '{}'".format(
                    key, dataset_names[0], str(entries_per_dataset[0])
                )
            )
            raise ValueError("Datasets have different metadata '{}'!".format(key))


def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.
    Now it includes resizing and flipping.

    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
            len(min_size)
        )

    logger = logging.getLogger(__name__)
    tfm_gens = []
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
    if is_train:
        tfm_gens.append(T.RandomFlip())
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens
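

# A minimal end-to-end sketch (hypothetical file path and annotation), roughly mirroring
# what the default DatasetMapper does with the utilities in this file: build the
# TransformGens from a config, read and augment the image, transform the annotations,
# and package everything into the fields that builtin models expect.
def _example_mapper_pipeline():
    from detectron2.config import get_cfg  # local import to keep the sketch self-contained

    cfg = get_cfg()
    tfm_gens = build_transform_gen(cfg, is_train=True)

    dataset_dict = {
        "file_name": "datasets/coco/train2017/example.jpg",  # hypothetical path
        "annotations": [
            {"bbox": [10.0, 20.0, 50.0, 80.0], "bbox_mode": BoxMode.XYWH_ABS, "category_id": 0}
        ],
    }
    image = read_image(dataset_dict["file_name"], format=cfg.INPUT.FORMAT)
    check_image_size(dataset_dict, image)

    image, transforms = T.apply_transform_gens(tfm_gens, image)
    annos = [
        transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations")
    ]
    instances = annotations_to_instances(annos, image.shape[:2])

    dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    dataset_dict["instances"] = filter_empty_instances(instances)
    return dataset_dict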
