
transformer.py (46 kB)

# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, xavier_init)
from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence,
                                         build_transformer_layer_sequence)
from mmcv.runner.base_module import BaseModule
from mmcv.utils import to_2tuple
from torch.nn.init import normal_

from mmdet.models.utils.builder import TRANSFORMER

try:
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
    warnings.warn(
        '`MultiScaleDeformableAttention` in MMCV has been moved to '
        '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
    from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention


def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()
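
# Illustrative sketch (not part of the original file): round-tripping the two
# layout helpers above on a hypothetical 4x6 feature map.
#
#     x = torch.rand(2, 4 * 6, 256)           # [N, L, C]
#     y = nlc_to_nchw(x, (4, 6))              # [N, C, H, W] -> (2, 256, 4, 6)
#     assert torch.equal(nchw_to_nlc(y), x)   # flattening restores [N, L, C]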


class AdaptivePadding(nn.Module):
    """Applies padding to the input (if needed) so that the input can be
    fully covered by the specified filter. It supports two modes, "same" and
    "corner". The "same" mode is the same as the "SAME" padding mode in
    TensorFlow and pads zeros around the input. The "corner" mode pads zeros
    to the bottom right of the input.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner". "corner" mode
            would pad zero to bottom right, and "same" mode would
            pad zero around input. Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
        super(AdaptivePadding, self).__init__()
        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(BaseModule):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The config dict for embedding
            conv layer type selection. Default: "Conv2d".
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The stride of embedding conv.
            Default: None (Would be set as `kernel_size`).
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        input_size (int | tuple | None): The size of input, which will be
            used to calculate the out size. Only works when `dynamic_size`
            is False. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
            Default: None.
    """

    def __init__(
        self,
        in_channels=3,
        embed_dims=768,
        conv_type='Conv2d',
        kernel_size=16,
        stride=16,
        padding='corner',
        dilation=1,
        bias=True,
        norm_cfg=None,
        input_size=None,
        init_cfg=None,
    ):
        super(PatchEmbed, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        self.projection = build_conv_layer(
            dict(type=conv_type),
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
                - out_size (tuple[int]): Spatial shape of x, arranged as
                    (out_h, out_w).
        """
        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size
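
# Illustrative sketch (not part of the original file): embedding a
# hypothetical 224x224 RGB image into 16x16 patches with the class above.
#
#     patch_embed = PatchEmbed(in_channels=3, embed_dims=768, kernel_size=16,
#                              stride=16, norm_cfg=dict(type='LN'))
#     img = torch.rand(2, 3, 224, 224)
#     tokens, hw = patch_embed(img)   # tokens: (2, 14 * 14, 768), hw: (14, 14)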


class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the original
    implementation. However, we need to modify pretrained models for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string): The padding length of
            embedding conv. When it is a string, it means the mode
            of adaptive padding, support "same" and "corner" now.
            Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        if stride:
            stride = stride
        else:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W). Default: None.

        Returns:
            tuple: Contains merged results and its spatial shape.

                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
                - out_size (tuple[int]): Spatial shape of x, arranged as
                    (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
            f'input_size is ' \
            f'`Sequence` ' \
            f'but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size
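
# Illustrative sketch (not part of the original file): a Swin-style 2x2 merge
# that halves the spatial resolution while expanding the channel dim.
#
#     merge = PatchMerging(in_channels=96, out_channels=192)
#     x = torch.rand(2, 56 * 56, 96)        # (B, H*W, C_in)
#     y, hw = merge(x, (56, 56))            # y: (2, 28 * 28, 192), hw: (28, 28)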


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to do the
            inverse.
        eps (float): EPS to avoid numerical
            overflow. Defaults to 1e-5.

    Returns:
        Tensor: The result of applying the inverse of sigmoid
            to x, with the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
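
# Illustrative sketch (not part of the original file): inverse_sigmoid undoes
# torch.sigmoid up to the clamping eps.
#
#     p = torch.tensor([0.1, 0.5, 0.9])
#     z = inverse_sigmoid(p)                  # logit: log(p / (1 - p))
#     assert torch.allclose(torch.sigmoid(z), p, atol=1e-6)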


@TRANSFORMER_LAYER.register_module()
class DetrTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expanded to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(DetrTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`. Only used when `self.pre_norm` is `True`.
    """

    def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
        super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(
                post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
        else:
            assert not self.pre_norm, f'Use prenorm in ' \
                f'{self.__class__.__name__},' \
                f'Please specify post_norm_cfg'
            self.post_norm = None

    def forward(self, *args, **kwargs):
        """Forward function for `TransformerCoder`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`.
    """

    def __init__(self,
                 *args,
                 post_norm_cfg=dict(type='LN'),
                 return_intermediate=False,
                 **kwargs):
        super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg,
                                              self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm:
                x = self.post_norm(x)[None]
            return x

        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            if self.return_intermediate:
                if self.post_norm is not None:
                    intermediate.append(self.post_norm(query))
                else:
                    intermediate.append(query)
        return torch.stack(intermediate)


@TRANSFORMER.register_module()
class Transformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None.
        decoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerDecoder. Defaults to None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None):
        super(Transformer, self).__init__(init_cfg=init_cfg)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.encoder.embed_dims

    def init_weights(self):
        # follow the official DETR to init parameters
        for m in self.modules():
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec
                      is True output has shape [num_dec_layers, bs,
                      num_query, embed_dims], else has shape [1, bs,
                      num_query, embed_dims].
                - memory: Output results from encoder, with shape
                      [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(
            1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]
        memory = self.encoder(
            query=x,
            key=None,
            value=None,
            query_pos=pos_embed,
            query_key_padding_mask=mask)
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask)
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory
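
# Illustrative sketch (not part of the original file): the tensor shapes that
# flow through Transformer.forward for a hypothetical DETR setup with
# embed_dims=256 and 100 object queries on a 32x32 feature map.
#
#     x:           (bs, 256, 32, 32)   backbone/neck feature
#     mask:        (bs, 32, 32)        True at padded pixels
#     query_embed: (100, 256)          learned object queries
#     pos_embed:   (bs, 256, 32, 32)   positional encoding
#
#     out_dec, memory = transformer(x, mask, query_embed, pos_embed)
#     # out_dec: (num_dec_layers, bs, 100, 256) with return_intermediate=True
#     # memory:  (bs, 256, 32, 32)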


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in Deformable DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

    def forward(self,
                query,
                *args,
                reference_points=None,
                valid_ratios=None,
                reg_branches=None,
                **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference
                points of offset, has shape
                (bs, num_query, 4) when as_two_stage,
                otherwise has shape (bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for
                refining the regression results. Only would
                be passed when with_box_refine is True,
                otherwise would be passed a `None`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] * \
                    torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * \
                    valid_ratios[:, None]
            output = layer(
                output,
                *args,
                reference_points=reference_points_input,
                **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(
                        reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[
                        ..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(
                intermediate_reference_points)

        return output, reference_points


@TRANSFORMER.register_module()
class DeformableDetrTransformer(Transformer):
    """Implements the DeformableDETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """

    def __init__(self,
                 as_two_stage=False,
                 num_feature_levels=4,
                 two_stage_num_proposals=300,
                 **kwargs):
        super(DeformableDetrTransformer, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dims = self.encoder.embed_dims
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))

        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2,
                                       self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            self.reference_points = nn.Linear(self.embed_dims, 2)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points, distribution='uniform', bias=0.)
        normal_(self.level_embeds)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask,
                                     spatial_shapes):
        """Generate proposals from encoded memory.

        Args:
            memory (Tensor): The output of encoder,
                has shape (bs, num_key, embed_dim). num_key is
                equal to the number of points on feature maps from
                all levels.
            memory_padding_mask (Tensor): Padding mask for memory,
                has shape (bs, num_key).
            spatial_shapes (Tensor): The shape of all feature maps,
                has shape (num_level, 2).

        Returns:
            tuple: A tuple of feature map and bbox prediction.

                - output_memory (Tensor): The input of decoder,
                      has shape (bs, num_key, embed_dim). num_key is
                      equal to the number of points on feature maps from
                      all levels.
                - output_proposals (Tensor): The normalized proposal
                      after an inverse sigmoid, has shape
                      (bs, num_keys, 4).
        """
        N, S, C = memory.shape
        proposals = []
        _cur = 0
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
                N, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(
                    0, W - 1, W, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1),
                               valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
            proposals.append(proposal)
            _cur += (H * W)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) &
                                  (output_proposals < 0.99)).all(
                                      -1, keepdim=True)
        output_proposals = torch.log(output_proposals /
                                     (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(
            ~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(
            memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid,
                                                  float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all
                feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2).
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points used in decoder, has
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(
                    0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(
                    0.5, W - 0.5, W, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (
                valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
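
    # Illustrative sketch (not part of the original file): a batch padded on
    # the right half of the image gives valid_ratio_w = 0.5, valid_ratio_h = 1.
    #
    #     mask = torch.zeros(1, 32, 64, dtype=torch.bool)
    #     mask[:, :, 32:] = True              # right half is padding
    #     self.get_valid_ratio(mask)          # tensor([[0.5000, 1.0000]])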

    def get_proposal_pos_embed(self,
                               proposals,
                               num_pos_feats=128,
                               temperature=10000):
        """Get the position embedding of proposal."""
        scale = 2 * math.pi
        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
                          dim=4).flatten(2)
        return pos

    def forward(self,
                mlvl_feats,
                mlvl_masks,
                query_embed,
                mlvl_pos_embeds,
                reg_branches=None,
                cls_branches=None,
                **kwargs):
        """Forward function for `Transformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from
                different levels. Each element has shape
                [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from
                different levels used for encoder and decoder,
                each element has shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            mlvl_pos_embeds (list(Tensor)): The positional encoding
                of feats from different levels, has the shape
                [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only would
                be passed when `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads
                for feature maps from each decoder layer. Only would
                be passed when `as_two_stage` is True. Default to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - inter_states: Outputs from decoder. If
                      return_intermediate_dec is True output has shape
                      (num_dec_layers, bs, num_query, embed_dims), else has
                      shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference
                      points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference
                      points in decoder, has shape
                      (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification score of
                      proposals generated from encoder's feature maps,
                      has shape (batch, h*w, num_classes).
                      Only would be returned when `as_two_stage` is True,
                      otherwise None.
                - enc_outputs_coord_unact: The regression results
                      generated from encoder's feature maps, has shape
                      (batch, h*w, 4). Only would
                      be returned when `as_two_stage` is True,
                      otherwise None.
        """
        assert self.as_two_stage or query_embed is not None

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        reference_points = \
            self.get_reference_points(spatial_shapes,
                                      valid_ratios,
                                      device=feat.device)

        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
            1, 0, 2)  # (H*W, bs, embed_dims)
        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs)

        memory = memory.permute(1, 0, 2)
        bs, _, c = memory.shape
        if self.as_two_stage:
            output_memory, output_proposals = \
                self.gen_encoder_output_proposals(
                    memory, mask_flatten, spatial_shapes)
            enc_outputs_class = cls_branches[self.decoder.num_layers](
                output_memory)
            enc_outputs_coord_unact = \
                reg_branches[
                    self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(
                enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(
                self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
            query = query.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_pos).sigmoid()
            init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        if self.as_two_stage:
            return inter_states, init_reference_out,\
                inter_references_out, enc_outputs_class,\
                enc_outputs_coord_unact
        return inter_states, init_reference_out, \
            inter_references_out, None, None


@TRANSFORMER.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1*1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default.
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project two-dimensional feature to
            one-dimensional feature. Default to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default:
            layer normalization.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels=256,
                 feat_channels=64,
                 out_channels=None,
                 input_feat_shape=7,
                 with_proj=True,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(DynamicConv, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature, input_feature):
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature that can be used
                to generate the parameters, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interacts with parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
                (num_all_proposals, out_channels).
        """
        input_feature = input_feature.flatten(2).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, :self.num_params_in].view(
            -1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out:].view(
            -1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # feature has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features
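
# Illustrative sketch (not part of the original file): the Sparse R-CNN style
# interaction between 100 proposal features and their 7x7 RoI features.
#
#     dyn_conv = DynamicConv(in_channels=256, feat_channels=64,
#                            input_feat_shape=7, with_proj=True)
#     proposal_feat = torch.rand(100, 256)       # one vector per proposal
#     roi_feat = torch.rand(100, 256, 7, 7)      # pooled RoI features
#     out = dyn_conv(proposal_feat, roi_feat)    # (100, 256)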
