
pvt.py
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
                      constant_init, normal_init, trunc_normal_init)
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
                         load_state_dict)
from torch.nn.modules.utils import _pair as to_2tuple

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert

class MixFFN(BaseModule):
    """An implementation of MixFFN of PVT.

    The differences between MixFFN & FFN:
        1. Use 1x1 Conv to replace the Linear layer.
        2. Introduce a 3x3 depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        use_conv (bool): If True, add a 3x3 DWConv between the two Linear
            layers. Default: False.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 dropout_layer=None,
                 use_conv=False,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg
        activate = build_activation_layer(act_cfg)

        in_channels = embed_dims
        fc1 = Conv2d(
            in_channels=in_channels,
            out_channels=feedforward_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        if use_conv:
            # 3x3 depth-wise conv to provide positional encoding information
            dw_conv = Conv2d(
                in_channels=feedforward_channels,
                out_channels=feedforward_channels,
                kernel_size=3,
                stride=1,
                padding=(3 - 1) // 2,
                bias=True,
                groups=feedforward_channels)
        fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True)
        drop = nn.Dropout(ffn_drop)
        layers = [fc1, activate, drop, fc2, drop]
        if use_conv:
            layers.insert(1, dw_conv)
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()

    def forward(self, x, hw_shape, identity=None):
        out = nlc_to_nchw(x, hw_shape)
        out = self.layers(out)
        out = nchw_to_nlc(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
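
# Illustrative shape walkthrough for MixFFN (hypothetical numbers, not from the
# upstream source): assuming embed_dims=8 and hw_shape=(2, 3), an input x of
# shape (B, 6, 8) = (B, L, C) is reshaped by nlc_to_nchw to (B, 8, 2, 3) so the
# 1x1 convs (and the optional 3x3 depth-wise conv) can operate spatially, then
# flattened back to (B, 6, 8) by nchw_to_nlc before the residual shortcut is
# added.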

class SpatialReductionAttention(MultiheadAttention):
    """An implementation of Spatial Reduction Attention of PVT.

    This module is modified from MultiheadAttention, which is a module from
    mmcv.cnn.bricks.transformer.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. Default: None.
        batch_first (bool): Key, Query and Value are of shape
            (batch, n, embed_dim) or (n, batch, embed_dim). Default: True.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=None,
                 batch_first=True,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 init_cfg=None):
        super().__init__(
            embed_dims,
            num_heads,
            attn_drop,
            proj_drop,
            batch_first=batch_first,
            dropout_layer=dropout_layer,
            bias=qkv_bias,
            init_cfg=init_cfg)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
        from mmdet import mmcv_version, digit_version
        if mmcv_version < digit_version('1.3.17'):
            warnings.warn('The legacy version of the forward function in '
                          'SpatialReductionAttention is deprecated in '
                          'mmcv>=1.3.17 and will no longer be supported in '
                          'the future. Please upgrade your mmcv.')
            self.forward = self.legacy_forward

    def forward(self, x, hw_shape, identity=None):
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), we should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            x_q = x_q.transpose(0, 1)
            x_kv = x_kv.transpose(0, 1)

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))

    def legacy_forward(self, x, hw_shape, identity=None):
        """multi head attention forward in mmcv version < 1.3.17."""
        x_q = x
        if self.sr_ratio > 1:
            x_kv = nlc_to_nchw(x, hw_shape)
            x_kv = self.sr(x_kv)
            x_kv = nchw_to_nlc(x_kv)
            x_kv = self.norm(x_kv)
        else:
            x_kv = x

        if identity is None:
            identity = x_q

        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]

        return identity + self.dropout_layer(self.proj_drop(out))
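
# Illustrative cost sketch for SpatialReductionAttention (hypothetical numbers):
# with sr_ratio=8 on a 56x56 stage-1 feature map, the strided `sr` conv reduces
# the key/value sequence from 56*56 = 3136 tokens to 7*7 = 49 tokens, so the
# attention map shrinks from 3136x3136 to 3136x49 while the query (and hence
# the output) length stays at 3136.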

class PVTEncoderLayer(BaseModule):
    """Implements one encoder layer in PVT.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.0.
        qkv_bias (bool): Enable bias for qkv if True.
            Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
            Attention of PVT. Default: 1.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1,
                 use_conv_ffn=False,
                 init_cfg=None):
        super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)

        # The ret[0] of build_norm_layer is norm name.
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.attn = SpatialReductionAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        # The ret[0] of build_norm_layer is norm name.
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        self.ffn = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            use_conv=use_conv_ffn,
            act_cfg=act_cfg)

    def forward(self, x, hw_shape):
        x = self.attn(self.norm1(x), hw_shape, identity=x)
        x = self.ffn(self.norm2(x), hw_shape, identity=x)

        return x
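
# Note (descriptive only): PVTEncoderLayer follows the pre-norm residual
# pattern. The shortcut passed via `identity=x` is the un-normalized input, so
# LayerNorm is applied only on the branch that feeds the attention / FFN
# sub-module.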

class AbsolutePositionEmbedding(BaseModule):
    """An implementation of the absolute position embedding in PVT.

    Args:
        pos_shape (int): The shape of the absolute position embedding.
        pos_dim (int): The dimension of the absolute position embedding.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
    """

    def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(pos_shape, int):
            pos_shape = to_2tuple(pos_shape)
        elif isinstance(pos_shape, tuple):
            if len(pos_shape) == 1:
                pos_shape = to_2tuple(pos_shape[0])
            assert len(pos_shape) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pos_shape)}'
        self.pos_shape = pos_shape
        self.pos_dim = pos_dim

        self.pos_embed = nn.Parameter(
            torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
        self.drop = nn.Dropout(p=drop_rate)

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)

    def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
        """Resize pos_embed weights.

        Resize pos_embed using bilinear interpolate method.

        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'bilinear'``.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C].
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h, pos_w = self.pos_shape
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
        pos_embed_weight = F.interpolate(
            pos_embed_weight, size=input_shape, mode=mode)
        pos_embed_weight = torch.flatten(pos_embed_weight,
                                         2).transpose(1, 2).contiguous()
        pos_embed = pos_embed_weight

        return pos_embed

    def forward(self, x, hw_shape, mode='bilinear'):
        pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
        return self.drop(x + pos_embed)
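
# Illustrative resize example for AbsolutePositionEmbedding (hypothetical
# numbers): a pos_embed learned for a 56x56 stage-1 grid (pos_shape=(56, 56),
# i.e. 3136 tokens) can be bilinearly resized at runtime to, e.g., a 200x304
# grid for an 800x1216 detection input, yielding 200*304 = 60800 tokens with
# the same channel dimension.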

@BACKBONES.register_module()
class PyramidVisionTransformer(BaseModule):
    """Pyramid Vision Transformer (PVT).

    Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
    Dense Prediction without Convolutions
    <https://arxiv.org/pdf/2102.12122.pdf>`_.

    Args:
        pretrain_img_size (int | tuple[int]): The size of the input image
            during pretraining. Default: 224.
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 64.
        num_stages (int): The number of stages. Default: 4.
        num_layers (Sequence[int]): The number of transformer encoder layers
            in each stage. Default: [3, 4, 6, 3].
        num_heads (Sequence[int]): The number of attention heads in each
            stage. Default: [1, 2, 5, 8].
        patch_sizes (Sequence[int]): The patch_size of each patch embedding.
            Default: [4, 2, 2, 2].
        strides (Sequence[int]): The stride of each patch embedding.
            Default: [4, 2, 2, 2].
        paddings (Sequence[int]): The padding of each patch embedding.
            Default: [0, 0, 0, 0].
        sr_ratios (Sequence[int]): The spatial reduction ratio of each stage.
            Default: [8, 4, 2, 1].
        out_indices (Sequence[int] | int): Output from which stages.
            Default: (0, 1, 2, 3).
        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
            embedding dim in each stage. Default: [8, 8, 4, 4].
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): stochastic depth rate. Default: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Default: True.
        norm_after_stage (bool): If True, apply an extra norm layer after
            each stage. Default: False.
        use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
            Default: False.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        pretrained (str, optional): Path to the pretrained model.
            Default: None.
        convert_weights (bool): Whether the pre-trained model comes from the
            original PVT repo, in which case some keys need to be converted
            for compatibility. Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=64,
                 num_stages=4,
                 num_layers=[3, 4, 6, 3],
                 num_heads=[1, 2, 5, 8],
                 patch_sizes=[4, 2, 2, 2],
                 strides=[4, 2, 2, 2],
                 paddings=[0, 0, 0, 0],
                 sr_ratios=[8, 4, 2, 1],
                 out_indices=(0, 1, 2, 3),
                 mlp_ratios=[8, 8, 4, 4],
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=True,
                 norm_after_stage=False,
                 use_conv_ffn=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 pretrained=None,
                 convert_weights=True,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.convert_weights = convert_weights
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        self.embed_dims = embed_dims

        self.num_stages = num_stages
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.patch_sizes = patch_sizes
        self.strides = strides
        self.sr_ratios = sr_ratios
        assert num_stages == len(num_layers) == len(num_heads) \
               == len(patch_sizes) == len(strides) == len(sr_ratios)

        self.out_indices = out_indices
        assert max(out_indices) < self.num_stages
        self.pretrained = pretrained

        # transformer encoder
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]  # stochastic depth decay rule

        cur = 0
        self.layers = ModuleList()
        for i, num_layer in enumerate(num_layers):
            embed_dims_i = embed_dims * num_heads[i]
            patch_embed = PatchEmbed(
                in_channels=in_channels,
                embed_dims=embed_dims_i,
                kernel_size=patch_sizes[i],
                stride=strides[i],
                padding=paddings[i],
                bias=True,
                norm_cfg=norm_cfg)
            layers = ModuleList()
            if use_abs_pos_embed:
                pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
                pos_embed = AbsolutePositionEmbedding(
                    pos_shape=pos_shape,
                    pos_dim=embed_dims_i,
                    drop_rate=drop_rate)
                layers.append(pos_embed)
            layers.extend([
                PVTEncoderLayer(
                    embed_dims=embed_dims_i,
                    num_heads=num_heads[i],
                    feedforward_channels=mlp_ratios[i] * embed_dims_i,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=dpr[cur + idx],
                    qkv_bias=qkv_bias,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    sr_ratio=sr_ratios[i],
                    use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
            ])
            in_channels = embed_dims_i
            # The ret[0] of build_norm_layer is norm name.
            if norm_after_stage:
                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
            else:
                norm = nn.Identity()
            self.layers.append(ModuleList([patch_embed, layers, norm]))
            cur += num_layer

    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
                elif isinstance(m, AbsolutePositionEmbedding):
                    m.init_weights()
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            checkpoint = _load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            logger.warn(f'Load pre-trained model for '
                        f'{self.__class__.__name__} from original repo')
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            elif 'model' in checkpoint:
                state_dict = checkpoint['model']
            else:
                state_dict = checkpoint
            if self.convert_weights:
                # Because pvt backbones are not supported by mmcls, we need
                # to convert the pre-trained weights to match this
                # implementation.
                state_dict = pvt_convert(state_dict)
            load_state_dict(self, state_dict, strict=False, logger=logger)

    def forward(self, x):
        outs = []

        for i, layer in enumerate(self.layers):
            x, hw_shape = layer[0](x)

            for block in layer[1]:
                x = block(x, hw_shape)
            x = layer[2](x)
            x = nlc_to_nchw(x, hw_shape)
            if i in self.out_indices:
                outs.append(x)

        return outs
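
# Output sketch (assuming the default arguments of this class and a 224x224
# input): `forward` returns a list with one NCHW feature map per stage listed
# in out_indices, e.g. (B, 64, 56, 56), (B, 128, 28, 28), (B, 320, 14, 14) and
# (B, 512, 7, 7), i.e. strides 4/8/16/32 with embed_dims * num_heads[i]
# channels per stage.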

@BACKBONES.register_module()
class PyramidVisionTransformerV2(PyramidVisionTransformer):
    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
    Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_."""

    def __init__(self, **kwargs):
        super(PyramidVisionTransformerV2, self).__init__(
            patch_sizes=[7, 3, 3, 3],
            paddings=[3, 1, 1, 1],
            use_abs_pos_embed=False,
            norm_after_stage=True,
            use_conv_ffn=True,
            **kwargs)
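
# Usage sketch (illustrative, not part of the upstream file): the snippet below
# shows how this backbone is typically instantiated and queried for multi-scale
# features, assuming mmcv-full and mmdet are installed and the class is
# exported from mmdet.models.backbones as in the upstream repo.
#
#     import torch
#     from mmdet.models.backbones import PyramidVisionTransformerV2
#
#     # The default arguments above give a roughly PVTv2-b2-like model.
#     model = PyramidVisionTransformerV2()
#     model.init_weights()
#     feats = model(torch.randn(1, 3, 224, 224))
#     for feat in feats:
#         print(feat.shape)
#     # Expected with the defaults:
#     #   torch.Size([1, 64, 56, 56])    stride 4
#     #   torch.Size([1, 128, 28, 28])   stride 8
#     #   torch.Size([1, 320, 14, 14])   stride 16
#     #   torch.Size([1, 512, 7, 7])     stride 32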
