# Copyright (c) OpenMMLab. All rights reserved. import numpy as np import torch from mmdet.models.utils import interpolate_as def test_interpolate_as(): source = torch.rand((1, 5, 4, 4)) target = torch.rand((1, 1, 16, 16)) # Test 4D source and target result = interpolate_as(source, target) assert result.shape == torch.Size((1, 5, 16, 16)) # Test 3D target result = interpolate_as(source, target.squeeze(0)) assert result.shape == torch.Size((1, 5, 16, 16)) # Test 3D source result = interpolate_as(source.squeeze(0), target) assert result.shape == torch.Size((5, 16, 16)) # Test type(target) == np.ndarray target = np.random.rand(16, 16) result = interpolate_as(source.squeeze(0), target) assert result.shape == torch.Size((5, 16, 16))