You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

c2_model_loading.py 15 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import copy
  3. import logging
  4. import re
  5. import torch
  6. from fvcore.common.checkpoint import (
  7. get_missing_parameters_message,
  8. get_unexpected_parameters_message,
  9. )
  10. def convert_basic_c2_names(original_keys):
  11. """
  12. Apply some basic name conversion to names in C2 weights.
  13. It only deals with typical backbone models.
  14. Args:
  15. original_keys (list[str]):
  16. Returns:
  17. list[str]: The same number of strings matching those in original_keys.
  18. """
  19. layer_keys = copy.deepcopy(original_keys)
  20. layer_keys = [
  21. {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
  22. ] # some hard-coded mappings
  23. layer_keys = [k.replace("_", ".") for k in layer_keys]
  24. layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
  25. layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
  26. # Uniform both bn and gn names to "norm"
  27. layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
  28. layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
  29. layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
  30. layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
  31. layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
  32. layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
  33. layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
  34. layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
  35. layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
  36. layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
  37. # stem
  38. layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
  39. # to avoid mis-matching with "conv1" in other components (e.g. detection head)
  40. layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
  41. # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
  42. # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
  43. # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
  44. # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
  45. # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
  46. # blocks
  47. layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
  48. layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
  49. layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
  50. layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
  51. # DensePose substitutions
  52. layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
  53. layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
  54. layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
  55. layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
  56. layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
  57. return layer_keys
  58. def convert_c2_detectron_names(weights):
  59. """
  60. Map Caffe2 Detectron weight names to Detectron2 names.
  61. Args:
  62. weights (dict): name -> tensor
  63. Returns:
  64. dict: detectron2 names -> tensor
  65. dict: detectron2 names -> C2 names
  66. """
  67. logger = logging.getLogger(__name__)
  68. logger.info("Remapping C2 weights ......")
  69. original_keys = sorted(weights.keys())
  70. layer_keys = copy.deepcopy(original_keys)
  71. layer_keys = convert_basic_c2_names(layer_keys)
  72. # --------------------------------------------------------------------------
  73. # RPN hidden representation conv
  74. # --------------------------------------------------------------------------
  75. # FPN case
  76. # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
  77. # shared for all other levels, hence the appearance of "fpn2"
  78. layer_keys = [
  79. k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
  80. ]
  81. # Non-FPN case
  82. layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
  83. # --------------------------------------------------------------------------
  84. # RPN box transformation conv
  85. # --------------------------------------------------------------------------
  86. # FPN case (see note above about "fpn2")
  87. layer_keys = [
  88. k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
  89. for k in layer_keys
  90. ]
  91. layer_keys = [
  92. k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
  93. for k in layer_keys
  94. ]
  95. # Non-FPN case
  96. layer_keys = [
  97. k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
  98. ]
  99. layer_keys = [
  100. k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
  101. for k in layer_keys
  102. ]
  103. # --------------------------------------------------------------------------
  104. # Fast R-CNN box head
  105. # --------------------------------------------------------------------------
  106. layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
  107. layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
  108. layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
  109. layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
  110. # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
  111. layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
  112. # --------------------------------------------------------------------------
  113. # FPN lateral and output convolutions
  114. # --------------------------------------------------------------------------
  115. def fpn_map(name):
  116. """
  117. Look for keys with the following patterns:
  118. 1) Starts with "fpn.inner."
  119. Example: "fpn.inner.res2.2.sum.lateral.weight"
  120. Meaning: These are lateral pathway convolutions
  121. 2) Starts with "fpn.res"
  122. Example: "fpn.res2.2.sum.weight"
  123. Meaning: These are FPN output convolutions
  124. """
  125. splits = name.split(".")
  126. norm = ".norm" if "norm" in splits else ""
  127. if name.startswith("fpn.inner."):
  128. # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
  129. stage = int(splits[2][len("res") :])
  130. return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
  131. elif name.startswith("fpn.res"):
  132. # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
  133. stage = int(splits[1][len("res") :])
  134. return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
  135. return name
  136. layer_keys = [fpn_map(k) for k in layer_keys]
  137. # --------------------------------------------------------------------------
  138. # Mask R-CNN mask head
  139. # --------------------------------------------------------------------------
  140. # roi_heads.StandardROIHeads case
  141. layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
  142. layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
  143. layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
  144. # roi_heads.Res5ROIHeads case
  145. layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
  146. # --------------------------------------------------------------------------
  147. # Keypoint R-CNN head
  148. # --------------------------------------------------------------------------
  149. # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
  150. layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
  151. layer_keys = [
  152. k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
  153. ]
  154. layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
  155. # --------------------------------------------------------------------------
  156. # Done with replacements
  157. # --------------------------------------------------------------------------
  158. assert len(set(layer_keys)) == len(layer_keys)
  159. assert len(original_keys) == len(layer_keys)
  160. new_weights = {}
  161. new_keys_to_original_keys = {}
  162. for orig, renamed in zip(original_keys, layer_keys):
  163. new_keys_to_original_keys[renamed] = orig
  164. if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
  165. # remove the meaningless prediction weight for background class
  166. new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
  167. new_weights[renamed] = weights[orig][new_start_idx:]
  168. logger.info(
  169. "Remove prediction weight for background class in {}. The shape changes from "
  170. "{} to {}.".format(
  171. renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
  172. )
  173. )
  174. elif renamed.startswith("cls_score."):
  175. # move weights of bg class from original index 0 to last index
  176. logger.info(
  177. "Move classification weights for background class in {} from index 0 to "
  178. "index {}.".format(renamed, weights[orig].shape[0] - 1)
  179. )
  180. new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
  181. else:
  182. new_weights[renamed] = weights[orig]
  183. return new_weights, new_keys_to_original_keys
  184. # Note the current matching is not symmetric.
  185. # it assumes model_state_dict will have longer names.
  186. def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
  187. """
  188. Match names between the two state-dict, and update the values of model_state_dict in-place with
  189. copies of the matched tensor in ckpt_state_dict.
  190. If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
  191. model and will be renamed at first.
  192. Strategy: suppose that the models that we will create will have prefixes appended
  193. to each of its keys, for example due to an extra level of nesting that the original
  194. pre-trained weights from ImageNet won't contain. For example, model.state_dict()
  195. might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
  196. res2.conv1.weight. We thus want to match both parameters together.
  197. For that, we look for each model weight, look among all loaded keys if there is one
  198. that is a suffix of the current weight name, and use it if that's the case.
  199. If multiple matches exist, take the one with longest size
  200. of the corresponding name. For example, for the same model as before, the pretrained
  201. weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
  202. we want to match backbone[0].body.conv1.weight to conv1.weight, and
  203. backbone[0].body.res2.conv1.weight to res2.conv1.weight.
  204. """
  205. model_keys = sorted(list(model_state_dict.keys()))
  206. if c2_conversion:
  207. ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
  208. # original_keys: the name in the original dict (before renaming)
  209. else:
  210. original_keys = {x: x for x in ckpt_state_dict.keys()}
  211. ckpt_keys = sorted(list(ckpt_state_dict.keys()))
  212. def match(a, b):
  213. # Matched ckpt_key should be a complete (starts with '.') suffix.
  214. # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
  215. # but matches whatever_conv1 or mesh_head.whatever_conv1.
  216. return a == b or a.endswith("." + b)
  217. # get a matrix of string matches, where each (i, j) entry correspond to the size of the
  218. # ckpt_key string, if it matches
  219. match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
  220. match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
  221. # use the matched one with longest size in case of multiple matches
  222. max_match_size, idxs = match_matrix.max(1)
  223. # remove indices that correspond to no-match
  224. idxs[max_match_size == 0] = -1
  225. # used for logging
  226. max_len_model = max(len(key) for key in model_keys) if model_keys else 1
  227. max_len_ckpt = max(len(key) for key in ckpt_keys) if ckpt_keys else 1
  228. log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
  229. logger = logging.getLogger(__name__)
  230. # matched_pairs (matched checkpoint key --> matched model key)
  231. matched_keys = {}
  232. for idx_model, idx_ckpt in enumerate(idxs.tolist()):
  233. if idx_ckpt == -1:
  234. continue
  235. key_model = model_keys[idx_model]
  236. key_ckpt = ckpt_keys[idx_ckpt]
  237. value_ckpt = ckpt_state_dict[key_ckpt]
  238. shape_in_model = model_state_dict[key_model].shape
  239. if shape_in_model != value_ckpt.shape:
  240. logger.warning(
  241. "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
  242. key_ckpt, value_ckpt.shape, key_model, shape_in_model
  243. )
  244. )
  245. logger.warning(
  246. "{} will not be loaded. Please double check and see if this is desired.".format(
  247. key_ckpt
  248. )
  249. )
  250. continue
  251. model_state_dict[key_model] = value_ckpt.clone()
  252. if key_ckpt in matched_keys: # already added to matched_keys
  253. logger.error(
  254. "Ambiguity found for {} in checkpoint!"
  255. "It matches at least two keys in the model ({} and {}).".format(
  256. key_ckpt, key_model, matched_keys[key_ckpt]
  257. )
  258. )
  259. raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
  260. matched_keys[key_ckpt] = key_model
  261. logger.info(
  262. log_str_template.format(
  263. key_model,
  264. max_len_model,
  265. original_keys[key_ckpt],
  266. max_len_ckpt,
  267. tuple(shape_in_model),
  268. )
  269. )
  270. matched_model_keys = matched_keys.values()
  271. matched_ckpt_keys = matched_keys.keys()
  272. # print warnings about unmatched keys on both side
  273. unmatched_model_keys = [k for k in model_keys if k not in matched_model_keys]
  274. if len(unmatched_model_keys):
  275. logger.info(get_missing_parameters_message(unmatched_model_keys))
  276. unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_ckpt_keys]
  277. if len(unmatched_ckpt_keys):
  278. logger.info(
  279. get_unexpected_parameters_message(original_keys[x] for x in unmatched_ckpt_keys)
  280. )

No Description