You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

brick_wrappers.py 1.8 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
  5. if torch.__version__ == 'parrots':
  6. TORCH_VERSION = torch.__version__
  7. else:
  8. # torch.__version__ could be 1.3.1+cu92, we only need the first two
  9. # for comparison
  10. TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
  11. def adaptive_avg_pool2d(input, output_size):
  12. """Handle empty batch dimension to adaptive_avg_pool2d.
  13. Args:
  14. input (tensor): 4D tensor.
  15. output_size (int, tuple[int,int]): the target output size.
  16. """
  17. if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
  18. if isinstance(output_size, int):
  19. output_size = [output_size, output_size]
  20. output_size = [*input.shape[:2], *output_size]
  21. empty = NewEmptyTensorOp.apply(input, output_size)
  22. return empty
  23. else:
  24. return F.adaptive_avg_pool2d(input, output_size)
  25. class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
  26. """Handle empty batch dimension to AdaptiveAvgPool2d."""
  27. def forward(self, x):
  28. # PyTorch 1.9 does not support empty tensor inference yet
  29. if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
  30. output_size = self.output_size
  31. if isinstance(output_size, int):
  32. output_size = [output_size, output_size]
  33. else:
  34. output_size = [
  35. v if v is not None else d
  36. for v, d in zip(output_size,
  37. x.size()[-2:])
  38. ]
  39. output_size = [*x.shape[:2], *output_size]
  40. empty = NewEmptyTensorOp.apply(x, output_size)
  41. return empty
  42. return super().forward(x)

No Description

Contributors (1)