import warnings
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from mmcv.utils import to_2tuple

from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils.ckpt_convert import swin_converter
from ..utils.transformer import PatchEmbed, PatchMerging


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):
        super().__init__()
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5
        self.init_cfg = init_cfg

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
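
# The helper below is an illustrative addition, not part of the original
# module: a minimal, shape-level smoke test of WindowMSA. The settings
# (embed_dims=96, num_heads=3, 7x7 window) are assumed Swin-T-style values;
# the input layout follows the forward() docstring above.
def _demo_window_msa():
    """Minimal usage sketch of ``WindowMSA`` (illustrative, shapes only)."""
    msa = WindowMSA(embed_dims=96, num_heads=3, window_size=to_2tuple(7))
    # Input shape is (num_windows*B, N, C) with N == Wh*Ww == 49.
    tokens = torch.randn(4, 7 * 7, 96)
    out = msa(tokens)
    assert out.shape == (4, 49, 96)
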

class ShiftWindowMSA(BaseModule):
    """Shifted Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults: 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults: 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults: dict(type='DropPath', drop_prob=0.).
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0,
                 proj_drop_rate=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 init_cfg=None):
        super().__init__(init_cfg)

        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(window_size),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            init_cfg=None)

        self.drop = build_dropout(dropout_layer)

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'
        query = query.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if self.shift_size > 0:
            shifted_query = torch.roll(
                query,
                shifts=(-self.shift_size, -self.shift_size),
                dims=(1, 2))

            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size,
                              -self.shift_size), slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = self.window_partition(img_mask)
            mask_windows = mask_windows.view(
                -1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                              float(-100.0)).masked_fill(
                                                  attn_mask == 0, float(0.0))
        else:
            shifted_query = query
            attn_mask = None

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(shifted_query)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, self.window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)
        return x

    def window_reverse(self, windows, H, W):
        """
        Args:
            windows: (num_windows*B, window_size, window_size, C)
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        window_size = self.window_size
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    def window_partition(self, x):
        """
        Args:
            x: (B, H, W, C)
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        window_size = self.window_size
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
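
# Illustrative addition, not part of the original module: a shape-level sketch
# of ShiftWindowMSA in shifted mode. The values (embed_dims=96, num_heads=3,
# 7x7 window, shift of window_size // 2) are assumed Swin-T-style settings;
# the 13x17 feature map deliberately exercises the padding path above.
def _demo_shift_window_msa():
    """Minimal usage sketch of ``ShiftWindowMSA`` (illustrative, shapes only)."""
    attn = ShiftWindowMSA(
        embed_dims=96, num_heads=3, window_size=7, shift_size=3)
    H, W = 13, 17
    query = torch.randn(2, H * W, 96)  # (B, L, C) with L == H * W
    out = attn(query, (H, W))
    assert out.shape == (2, H * W, 96)
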

class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        window_size (int, optional): The local window scale. Default: 7.
        shift (bool, optional): whether to shift window or not. Default False.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super(SwinBlock, self).__init__()

        self.init_cfg = init_cfg
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)

            x = x + identity

            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x
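
# Illustrative addition, not part of the original module: a shape-level sketch
# of a single SwinBlock in shifted mode. The values (embed_dims=96,
# num_heads=3, MLP ratio 4, 56x56 tokens) are assumed Swin-T stage-1 settings.
def _demo_swin_block():
    """Minimal usage sketch of ``SwinBlock`` (illustrative, shapes only)."""
    block = SwinBlock(
        embed_dims=96, num_heads=3, feedforward_channels=96 * 4, shift=True)
    H, W = 56, 56
    x = torch.randn(1, H * W, 96)
    out = block(x, (H, W))
    # The pre-norm residual structure keeps the token layout unchanged.
    assert out.shape == x.shape
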

class SwinBlockSequence(BaseModule):
    """Implements one stage in Swin Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        depth (int): The number of blocks in this stage.
        window_size (int, optional): The local window scale. Default: 7.
        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
        drop_path_rate (float | list[float], optional): Stochastic depth
            rate. Default: 0.
        downsample (BaseModule | None, optional): The downsample operation
            module. Default: None.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 depth,
                 window_size=7,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 downsample=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        else:
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]

        self.blocks = ModuleList()
        for i in range(depth):
            block = SwinBlock(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=feedforward_channels,
                window_size=window_size,
                shift=False if i % 2 == 0 else True,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.blocks.append(block)

        self.downsample = downsample

    def forward(self, x, hw_shape):
        for block in self.blocks:
            x = block(x, hw_shape)

        if self.downsample:
            x_down, down_hw_shape = self.downsample(x, hw_shape)
            return x_down, down_hw_shape, x, hw_shape
        else:
            return x, hw_shape, x, hw_shape
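
# Illustrative addition, not part of the original module: a shape-level sketch
# of one stage with a PatchMerging downsample. The PatchMerging kwargs mirror
# how SwinTransformer below constructs it (out_channels doubled, stride 2);
# the remaining values are assumed Swin-T stage-1 settings.
def _demo_swin_block_sequence():
    """Minimal usage sketch of ``SwinBlockSequence`` (illustrative, shapes only)."""
    downsample = PatchMerging(
        in_channels=96, out_channels=192, stride=2, init_cfg=None)
    stage = SwinBlockSequence(
        embed_dims=96,
        num_heads=3,
        feedforward_channels=96 * 4,
        depth=2,
        downsample=downsample)
    H, W = 56, 56
    x = torch.randn(1, H * W, 96)
    x_down, down_hw_shape, out, out_hw_shape = stage(x, (H, W))
    # Pre-downsample output keeps the stage resolution; the downsampled path
    # halves the spatial size and doubles the channels.
    assert out.shape == (1, 56 * 56, 96) and out_hw_shape == (56, 56)
    assert x_down.shape == (1, 28 * 28, 192)
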

@BACKBONES.register_module()
class SwinTransformer(BaseModule):
    """Swin Transformer.

    A PyTorch implementation of: `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/abs/2103.14030

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        pretrain_img_size (int | tuple[int]): The size of input image when
            pretrain. Default: 224.
        in_channels (int): The num of input channels.
            Default: 3.
        embed_dims (int): The feature dimension. Default: 96.
        patch_size (int | tuple[int]): Patch size. Default: 4.
        window_size (int): Window size. Default: 7.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
            Default: 4.
        depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
        num_heads (tuple[int]): Parallel attention heads of each Swin
            Transformer stage. Default: (3, 6, 12, 24).
        strides (tuple[int]): The patch merging or patch embedding stride of
            each Swin Transformer stage. (In swin, we set kernel size equal to
            stride.) Default: (4, 2, 2, 2).
        out_indices (tuple[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
            value. Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        patch_norm (bool): If add a norm layer for patch embed and patch
            merging. Default: True.
        drop_rate (float): Dropout rate. Default: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Default: False.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer at
            output of backbone. Default: dict(type='LN').
        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.
        pretrained (str, optional): model pretrained path. Default: None.
        convert_weights (bool): The flag indicates whether the
            pre-trained model is from the original repo. We may need
            to convert some keys to make it compatible.
            Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    def __init__(self,
                 pretrain_img_size=224,
                 in_channels=3,
                 embed_dims=96,
                 patch_size=4,
                 window_size=7,
                 mlp_ratio=4,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 strides=(4, 2, 2, 2),
                 out_indices=(0, 1, 2, 3),
                 qkv_bias=True,
                 qk_scale=None,
                 patch_norm=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 use_abs_pos_embed=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained=None,
                 convert_weights=False,
                 frozen_stages=-1,
                 init_cfg=None):
        self.convert_weights = convert_weights
        self.frozen_stages = frozen_stages
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = to_2tuple(pretrain_img_size)
        elif isinstance(pretrain_img_size, tuple):
            if len(pretrain_img_size) == 1:
                pretrain_img_size = to_2tuple(pretrain_img_size[0])
            assert len(pretrain_img_size) == 2, \
                f'The size of image should have length 1 or 2, ' \
                f'but got {len(pretrain_img_size)}'

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            self.init_cfg = init_cfg
        else:
            raise TypeError('pretrained must be a str or None')

        super(SwinTransformer, self).__init__(init_cfg=init_cfg)

        num_layers = len(depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed

        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'

        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=strides[0],
            norm_cfg=norm_cfg if patch_norm else None,
            init_cfg=None)

        if self.use_abs_pos_embed:
            patch_row = pretrain_img_size[0] // patch_size
            patch_col = pretrain_img_size[1] // patch_size
            num_patches = patch_row * patch_col
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros((1, num_patches, embed_dims)))

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # set stochastic depth decay rule
        total_depth = sum(depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]

        self.stages = ModuleList()
        in_channels = embed_dims
        for i in range(num_layers):
            if i < num_layers - 1:
                downsample = PatchMerging(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    stride=strides[i + 1],
                    norm_cfg=norm_cfg if patch_norm else None,
                    init_cfg=None)
            else:
                downsample = None

            stage = SwinBlockSequence(
                embed_dims=in_channels,
                num_heads=num_heads[i],
                feedforward_channels=mlp_ratio * in_channels,
                depth=depths[i],
                window_size=window_size,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                downsample=downsample,
                act_cfg=act_cfg,
                norm_cfg=norm_cfg,
                with_cp=with_cp,
                init_cfg=None)
            self.stages.append(stage)
            if downsample:
                in_channels = downsample.out_channels

        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
        # Add a norm layer for each output
        for i in out_indices:
            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            if self.use_abs_pos_embed:
                self.absolute_pos_embed.requires_grad = False
            self.drop_after_pos.eval()

        for i in range(1, self.frozen_stages + 1):

            if (i - 1) in self.out_indices:
                norm_layer = getattr(self, f'norm{i-1}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

            m = self.stages[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training start from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    # constant_init expects a module, not a tensor
                    constant_init(m, val=1.0, bias=0.)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = _load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt
            if self.convert_weights:
                # support loading weights from the original repo
                _state_dict = swin_converter(_state_dict)

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, False)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(out)
                out = out.view(-1, *out_hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)

        return outs
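
# Illustrative addition, not part of the original module: an end-to-end sketch
# of the backbone with its default (Swin-T style) configuration as defined
# above. For a 224x224 input it should yield four feature maps with strides
# 4, 8, 16 and 32.
def _demo_swin_transformer():
    """Minimal usage sketch of ``SwinTransformer`` (illustrative, shapes only)."""
    model = SwinTransformer()
    model.init_weights()  # random init; pass pretrained/init_cfg to load weights
    model.eval()
    with torch.no_grad():
        outs = model(torch.randn(1, 3, 224, 224))
    expected = [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14),
                (1, 768, 7, 7)]
    assert [tuple(o.shape) for o in outs] == expected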
