# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import division
from typing import Any, List, Sequence, Tuple, Union

import torch
from torch.nn import functional as F


class ImageList(object):
    """
    Structure that holds a list of images (of possibly varying sizes) as a
    single tensor.

    The images are padded to a common size and stacked along a new leading
    batch dimension; the original (unpadded) size of every image is kept in
    `image_sizes` so each one can be cropped back out.

    Attributes:
        tensor (Tensor): the padded batch, of shape (N, H, W) or
            (N, C_1, ..., C_K, H, W) where K >= 1
        image_sizes (list[tuple[int, int]]): each tuple is (h, w)
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]):
        """
        Arguments:
            tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
            image_sizes (list[tuple[int, int]]): Each tuple is (h, w).
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        """Return the number of images, i.e. the number of stored sizes."""
        return len(self.image_sizes)

    def __getitem__(self, idx: Union[int, slice]) -> torch.Tensor:
        """
        Access the individual image in its original size.

        Note: only an integer `idx` yields a usable result. A slice selects
        several (h, w) pairs from `image_sizes`, which cannot serve as crop
        bounds for the spatial dimensions.

        Returns:
            Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
        """
        size = self.image_sizes[idx]
        # Crop the padding away from the last two (spatial) dimensions.
        return self.tensor[idx, ..., : size[0], : size[1]]  # type: ignore

    def to(self, *args: Any, **kwargs: Any) -> "ImageList":
        """
        Forward all arguments to `Tensor.to` on the padded batch.

        Returns:
            ImageList: a new instance wrapping the converted tensor; it
            refers to the same `image_sizes` list as this instance.
        """
        cast_tensor = self.tensor.to(*args, **kwargs)
        return ImageList(cast_tensor, self.image_sizes)

    @staticmethod
    def from_tensors(
        tensors: Sequence[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
    ) -> "ImageList":
        """
        Pad a sequence of images to a common size and batch them.

        Args:
            tensors: a tuple or list of `torch.Tensors`, each of shape (Hi, Wi) or
                (C_1, ..., C_K, Hi, Wi) where K >= 1. All leading (non-spatial)
                dimensions must match across tensors. The Tensors will be padded
                with `pad_value` on the bottom/right so that they have the same shape.
            size_divisibility (int): If `size_divisibility > 0`, also adds padding to
                ensure the common height and width is divisible by `size_divisibility`
            pad_value (float): value to pad

        Returns:
            an `ImageList`.

        Raises:
            AssertionError: if `tensors` is empty, is not a tuple/list, contains a
                non-Tensor, or holds tensors whose leading dimensions differ.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        first = tensors[0]
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)
            # Compare *every* leading dimension (channels included). Only H and W
            # may vary between images; anything else would be silently broadcast
            # by the copy below.
            assert t.shape[:-2] == first.shape[:-2], t.shape

        leading_dims = tuple(first.shape[:-2])
        max_h = max(t.shape[-2] for t in tensors)
        max_w = max(t.shape[-1] for t in tensors)

        if size_divisibility > 0:
            stride = size_divisibility
            # Round each spatial extent up to the next multiple of `stride`.
            max_h = int((max_h + stride - 1) // stride * stride)
            max_w = int((max_w + stride - 1) // stride * stride)

        image_sizes = [(t.shape[-2], t.shape[-1]) for t in tensors]

        if len(tensors) == 1:
            # This seems slightly (2%) faster.
            # TODO: check whether it's faster for multiple images as well
            height, width = image_sizes[0]
            # F.pad takes the padding for the last dimension first:
            # (left, right, top, bottom).
            padded = F.pad(first, [0, max_w - width, 0, max_h - height], value=pad_value)
            # Out-of-place unsqueeze: never reshape a tensor the caller may own.
            batched_imgs = padded.unsqueeze(0)
        else:
            batch_shape = (len(tensors),) + leading_dims + (max_h, max_w)
            batched_imgs = first.new_full(batch_shape, pad_value)
            for img, slot in zip(tensors, batched_imgs):
                # Write each image into the top-left corner of its padded slot.
                slot[..., : img.shape[-2], : img.shape[-1]].copy_(img)

        return ImageList(batched_imgs.contiguous(), image_sizes)