You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

misc.py 1.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from torch.nn import functional as F
  3. def interpolate_as(source, target, mode='bilinear', align_corners=False):
  4. """Interpolate the `source` to the shape of the `target`.
  5. The `source` must be a Tensor, but the `target` can be a Tensor or a
  6. np.ndarray with the shape (..., target_h, target_w).
  7. Args:
  8. source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
  9. (N, C, H, W).
  10. target (Tensor | np.ndarray): The interpolation target with the shape
  11. (..., target_h, target_w).
  12. mode (str): Algorithm used for interpolation. The options are the
  13. same as those in F.interpolate(). Default: ``'bilinear'``.
  14. align_corners (bool): The same as the argument in F.interpolate().
  15. Returns:
  16. Tensor: The interpolated source Tensor.
  17. """
  18. assert len(target.shape) >= 2
  19. def _interpolate_as(source, target, mode='bilinear', align_corners=False):
  20. """Interpolate the `source` (4D) to the shape of the `target`."""
  21. target_h, target_w = target.shape[-2:]
  22. source_h, source_w = source.shape[-2:]
  23. if target_h != source_h or target_w != source_w:
  24. source = F.interpolate(
  25. source,
  26. size=(target_h, target_w),
  27. mode=mode,
  28. align_corners=align_corners)
  29. return source
  30. if len(source.shape) == 3:
  31. source = source[:, None, :, :]
  32. source = _interpolate_as(source, target, mode, align_corners)
  33. return source[:, 0, :, :]
  34. else:
  35. return _interpolate_as(source, target, mode, align_corners)

No Description

Contributors (1)