
normed_predictor.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS

from .builder import LINEAR_LAYERS


@LINEAR_LAYERS.register_module(name='NormedLinear')
class NormedLinear(nn.Linear):
    """Normalized Linear Layer.

    Args:
        temperature (float, optional): Temperature term. Defaults to 20.
        power (float, optional): Power term. Defaults to 1.0.
        eps (float, optional): The minimal value of divisor to
            keep numerical stability. Defaults to 1e-6.
    """

    def __init__(self, *args, temperature=20, power=1.0, eps=1e-6, **kwargs):
        super(NormedLinear, self).__init__(*args, **kwargs)
        self.temperature = temperature
        self.power = power
        self.eps = eps
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        # Normalize the weight rows and the input features before the usual
        # linear projection, scaling the normalized input by the temperature.
        weight_ = self.weight / (
            self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.temperature
        return F.linear(x_, weight_, self.bias)
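With power=1.0, both the input and each weight row are normalized to roughly unit length, so each output logit is, up to the bias, the temperature-scaled cosine similarity between the input and one class weight vector. A minimal usage sketch; the feature and class counts below are arbitrary:

# Hypothetical sizes, chosen only for illustration.
layer = NormedLinear(256, 80, temperature=20)
feats = torch.randn(4, 256)   # a batch of 4 feature vectors
logits = layer(feats)         # shape: (4, 80)
# Each logit is roughly temperature * cos(feats, weight_row) + bias.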
@CONV_LAYERS.register_module(name='NormedConv2d')
class NormedConv2d(nn.Conv2d):
    """Normalized Conv2d Layer.

    Args:
        temperature (float, optional): Temperature term. Defaults to 20.
        power (float, optional): Power term. Defaults to 1.0.
        eps (float, optional): The minimal value of divisor to
            keep numerical stability. Defaults to 1e-6.
        norm_over_kernel (bool, optional): Normalize over the whole kernel
            instead of only the input-channel dimension. Defaults to False.
    """

    def __init__(self,
                 *args,
                 temperature=20,
                 power=1.0,
                 eps=1e-6,
                 norm_over_kernel=False,
                 **kwargs):
        super(NormedConv2d, self).__init__(*args, **kwargs)
        self.temperature = temperature
        self.power = power
        self.norm_over_kernel = norm_over_kernel
        self.eps = eps

    def forward(self, x):
        if not self.norm_over_kernel:
            # Normalize each filter over the input-channel dimension only.
            weight_ = self.weight / (
                self.weight.norm(dim=1, keepdim=True).pow(self.power) +
                self.eps)
        else:
            # Normalize each filter over its flattened kernel
            # (in_channels * kernel_h * kernel_w).
            weight_ = self.weight / (
                self.weight.view(self.weight.size(0), -1).norm(
                    dim=1, keepdim=True).pow(self.power)[..., None, None] +
                self.eps)
        x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
        x_ = x_ * self.temperature
        if hasattr(self, 'conv2d_forward'):
            # PyTorch < 1.5 exposes the convolution helper as conv2d_forward.
            x_ = self.conv2d_forward(x_, weight_)
        else:
            # _conv_forward gained a bias argument in PyTorch 1.8. Compare
            # parsed version numbers rather than the raw strings, since
            # e.g. '1.10' < '1.8' lexicographically.
            major, minor = (int(v) for v in torch.__version__.split('.')[:2])
            if (major, minor) >= (1, 8):
                x_ = self._conv_forward(x_, weight_, self.bias)
            else:
                x_ = self._conv_forward(x_, weight_)
        return x_
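Because both classes are registered, they can also be built from config dicts through the mmcv registries. A minimal sketch, assuming mmcv 1.x where build_conv_layer dispatches through CONV_LAYERS; the channel sizes are arbitrary:

from mmcv.cnn import build_conv_layer

# 'type' selects the registered class; the remaining cfg keys and kwargs
# are forwarded to NormedConv2d's constructor.
conv = build_conv_layer(
    dict(type='NormedConv2d', norm_over_kernel=True),
    in_channels=256,
    out_channels=80,
    kernel_size=3,
    padding=1)
x = torch.randn(2, 256, 7, 7)
out = conv(x)   # shape: (2, 80, 7, 7)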
