You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

ckpt_convert.py 5.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # This script consists of several convert functions which
  3. # can modify the weights of model in original repo to be
  4. # pre-trained weights.
  5. from collections import OrderedDict
  6. import torch
  7. def pvt_convert(ckpt):
  8. new_ckpt = OrderedDict()
  9. # Process the concat between q linear weights and kv linear weights
  10. use_abs_pos_embed = False
  11. use_conv_ffn = False
  12. for k in ckpt.keys():
  13. if k.startswith('pos_embed'):
  14. use_abs_pos_embed = True
  15. if k.find('dwconv') >= 0:
  16. use_conv_ffn = True
  17. for k, v in ckpt.items():
  18. if k.startswith('head'):
  19. continue
  20. if k.startswith('norm.'):
  21. continue
  22. if k.startswith('cls_token'):
  23. continue
  24. if k.startswith('pos_embed'):
  25. stage_i = int(k.replace('pos_embed', ''))
  26. new_k = k.replace(f'pos_embed{stage_i}',
  27. f'layers.{stage_i - 1}.1.0.pos_embed')
  28. if stage_i == 4 and v.size(1) == 50: # 1 (cls token) + 7 * 7
  29. new_v = v[:, 1:, :] # remove cls token
  30. else:
  31. new_v = v
  32. elif k.startswith('patch_embed'):
  33. stage_i = int(k.split('.')[0].replace('patch_embed', ''))
  34. new_k = k.replace(f'patch_embed{stage_i}',
  35. f'layers.{stage_i - 1}.0')
  36. new_v = v
  37. if 'proj.' in new_k:
  38. new_k = new_k.replace('proj.', 'projection.')
  39. elif k.startswith('block'):
  40. stage_i = int(k.split('.')[0].replace('block', ''))
  41. layer_i = int(k.split('.')[1])
  42. new_layer_i = layer_i + use_abs_pos_embed
  43. new_k = k.replace(f'block{stage_i}.{layer_i}',
  44. f'layers.{stage_i - 1}.1.{new_layer_i}')
  45. new_v = v
  46. if 'attn.q.' in new_k:
  47. sub_item_k = k.replace('q.', 'kv.')
  48. new_k = new_k.replace('q.', 'attn.in_proj_')
  49. new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
  50. elif 'attn.kv.' in new_k:
  51. continue
  52. elif 'attn.proj.' in new_k:
  53. new_k = new_k.replace('proj.', 'attn.out_proj.')
  54. elif 'attn.sr.' in new_k:
  55. new_k = new_k.replace('sr.', 'sr.')
  56. elif 'mlp.' in new_k:
  57. string = f'{new_k}-'
  58. new_k = new_k.replace('mlp.', 'ffn.layers.')
  59. if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
  60. new_v = v.reshape((*v.shape, 1, 1))
  61. new_k = new_k.replace('fc1.', '0.')
  62. new_k = new_k.replace('dwconv.dwconv.', '1.')
  63. if use_conv_ffn:
  64. new_k = new_k.replace('fc2.', '4.')
  65. else:
  66. new_k = new_k.replace('fc2.', '3.')
  67. string += f'{new_k} {v.shape}-{new_v.shape}'
  68. elif k.startswith('norm'):
  69. stage_i = int(k[4])
  70. new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
  71. new_v = v
  72. else:
  73. new_k = k
  74. new_v = v
  75. new_ckpt[new_k] = new_v
  76. return new_ckpt
  77. def swin_converter(ckpt):
  78. new_ckpt = OrderedDict()
  79. def correct_unfold_reduction_order(x):
  80. out_channel, in_channel = x.shape
  81. x = x.reshape(out_channel, 4, in_channel // 4)
  82. x = x[:, [0, 2, 1, 3], :].transpose(1,
  83. 2).reshape(out_channel, in_channel)
  84. return x
  85. def correct_unfold_norm_order(x):
  86. in_channel = x.shape[0]
  87. x = x.reshape(4, in_channel // 4)
  88. x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
  89. return x
  90. for k, v in ckpt.items():
  91. if k.startswith('head'):
  92. continue
  93. elif k.startswith('layers'):
  94. new_v = v
  95. if 'attn.' in k:
  96. new_k = k.replace('attn.', 'attn.w_msa.')
  97. elif 'mlp.' in k:
  98. if 'mlp.fc1.' in k:
  99. new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
  100. elif 'mlp.fc2.' in k:
  101. new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
  102. else:
  103. new_k = k.replace('mlp.', 'ffn.')
  104. elif 'downsample' in k:
  105. new_k = k
  106. if 'reduction.' in k:
  107. new_v = correct_unfold_reduction_order(v)
  108. elif 'norm.' in k:
  109. new_v = correct_unfold_norm_order(v)
  110. else:
  111. new_k = k
  112. new_k = new_k.replace('layers', 'stages', 1)
  113. elif k.startswith('patch_embed'):
  114. new_v = v
  115. if 'proj' in k:
  116. new_k = k.replace('proj', 'projection')
  117. else:
  118. new_k = k
  119. else:
  120. new_v = v
  121. new_k = k
  122. new_ckpt['backbone.' + new_k] = new_v
  123. return new_ckpt

No Description

Contributors (1)