|
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import copy
- import logging
- import re
- import torch
- from fvcore.common.checkpoint import (
- get_missing_parameters_message,
- get_unexpected_parameters_message,
- )
-
-
def convert_basic_c2_names(original_keys):
    """
    Translate Caffe2 blob names into detectron2-style parameter names.

    Only the naming conventions shared by typical backbones are handled here
    (plus a few DensePose head names); detector-specific heads are left to
    callers.

    Args:
        original_keys (list[str]): Caffe2 weight names.

    Returns:
        list[str]: a new list, one converted name per input name, in the same order.
    """
    # Whole-name substitutions applied before any rule below.
    hard_coded = {"pred_b": "linear_b", "pred_w": "linear_w"}

    def sub(pattern, repl):
        # Regex-based rule.
        return lambda name: re.sub(pattern, repl, name)

    def literal(old, new):
        # Plain substring rule (every occurrence).
        return lambda name: name.replace(old, new)

    # Rules run top to bottom on every name; each one sees the result of the
    # previous ones, so the order is significant.
    rules = [
        literal("_", "."),
        sub("\\.b$", ".bias"),
        sub("\\.w$", ".weight"),
        # BatchNorm / GroupNorm parameters all become "norm.*".
        sub("bn\\.s$", "norm.weight"),
        sub("bn\\.bias$", "norm.bias"),
        # NOTE: unlike its neighbours this pattern is not anchored with "$".
        sub("bn\\.rm", "norm.running_mean"),
        sub("bn\\.running.mean$", "norm.running_mean"),
        sub("bn\\.riv$", "norm.running_var"),
        sub("bn\\.running.var$", "norm.running_var"),
        sub("bn\\.gamma$", "norm.weight"),
        sub("bn\\.beta$", "norm.bias"),
        sub("gn\\.s$", "norm.weight"),
        sub("gn\\.bias$", "norm.bias"),
        # Stem: prefix with "stem." so it cannot collide with "conv1" in other
        # components such as detection heads. Stages keep the C2 "res2".."res5"
        # names rather than torchvision's "layer1".."layer4".
        sub("^res\\.conv1\\.norm\\.", "conv1.norm."),
        sub("^conv1\\.", "stem.conv1."),
        # Residual block branches.
        literal(".branch1.", ".shortcut."),
        literal(".branch2a.", ".conv1."),
        literal(".branch2b.", ".conv2."),
        literal(".branch2c.", ".conv3."),
        # DensePose heads.
        sub("^body.conv.fcn", "body_conv_fcn"),
        literal("AnnIndex.lowres", "ann_index_lowres"),
        literal("Index.UV.lowres", "index_uv_lowres"),
        literal("U.lowres", "u_lowres"),
        literal("V.lowres", "v_lowres"),
    ]

    renamed = []
    for name in original_keys:
        name = hard_coded.get(name, name)
        for rule in rules:
            name = rule(name)
        renamed.append(name)
    return renamed
-
-
def convert_c2_detectron_names(weights):
    """
    Map Caffe2 Detectron weight names to Detectron2 names.

    Besides renaming, the class-dependent prediction tensors are rearranged along
    their first dimension:

    * ``bbox_pred.*``: the first 4 entries (background class) are dropped.
    * ``mask_head.predictor.*``: the first entry (background class) is dropped.
    * ``cls_score.*``: the background entry is moved from index 0 to the last index.

    Args:
        weights (dict): name -> tensor

    Returns:
        dict: detectron2 names -> tensor
        dict: detectron2 names -> C2 names

    Raises:
        ValueError: if two different C2 names are converted to the same detectron2 name.
    """
    logger = logging.getLogger(__name__)
    logger.info("Remapping C2 weights ......")
    original_keys = sorted(weights.keys())

    # convert_basic_c2_names builds and returns a new list; original_keys is not modified.
    layer_keys = convert_basic_c2_names(original_keys)

    # --------------------------------------------------------------------------
    # RPN hidden representation conv
    # --------------------------------------------------------------------------
    # FPN case
    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
    # shared for all other levels, hence the appearance of "fpn2"
    layer_keys = [
        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # RPN box transformation conv
    # --------------------------------------------------------------------------
    # FPN case (see note above about "fpn2")
    layer_keys = [
        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
        for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [
        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]

    # --------------------------------------------------------------------------
    # Fast R-CNN box head
    # --------------------------------------------------------------------------
    layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
    layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
    layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
    layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
    layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # FPN lateral and output convolutions
    # --------------------------------------------------------------------------
    def fpn_map(name):
        """
        Look for keys with the following patterns:
        1) Starts with "fpn.inner."
           Example: "fpn.inner.res2.2.sum.lateral.weight"
           Meaning: These are lateral pathway convolutions
        2) Starts with "fpn.res"
           Example: "fpn.res2.2.sum.weight"
           Meaning: These are FPN output convolutions
        Any other name is returned unchanged.
        """
        splits = name.split(".")
        norm = ".norm" if "norm" in splits else ""
        if name.startswith("fpn.inner."):
            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
            stage = int(splits[2][len("res") :])
            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
        elif name.startswith("fpn.res"):
            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
            stage = int(splits[1][len("res") :])
            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
        return name

    layer_keys = [fpn_map(k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # Mask R-CNN mask head
    # --------------------------------------------------------------------------
    # roi_heads.StandardROIHeads case
    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
    layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
    # roi_heads.Res5ROIHeads case
    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Keypoint R-CNN head
    # --------------------------------------------------------------------------
    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
    layer_keys = [
        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
    ]
    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Done with replacements
    # --------------------------------------------------------------------------
    # Every step above is an element-wise list comprehension, so lengths always agree.
    assert len(original_keys) == len(layer_keys)

    # Two C2 names collapsing onto one detectron2 name would make one tensor silently
    # overwrite the other below. This is checked explicitly rather than with `assert`
    # so it still runs under `python -O`, and the colliding names are reported.
    renamed_to_originals = {}
    for orig, renamed in zip(original_keys, layer_keys):
        renamed_to_originals.setdefault(renamed, []).append(orig)
    collisions = {k: v for k, v in renamed_to_originals.items() if len(v) > 1}
    if collisions:
        raise ValueError(
            "Several C2 weight names are converted to the same detectron2 name: {}".format(
                collisions
            )
        )

    new_weights = {}
    new_keys_to_original_keys = {}
    for orig, renamed in zip(original_keys, layer_keys):
        new_keys_to_original_keys[renamed] = orig
        if renamed.startswith(("bbox_pred.", "mask_head.predictor.")):
            # remove the meaningless prediction weight for background class
            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
            new_weights[renamed] = weights[orig][new_start_idx:]
            logger.info(
                "Remove prediction weight for background class in {}. The shape changes from "
                "{} to {}.".format(
                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
                )
            )
        elif renamed.startswith("cls_score."):
            # move weights of bg class from original index 0 to last index
            logger.info(
                "Move classification weights for background class in {} from index 0 to "
                "index {}.".format(renamed, weights[orig].shape[0] - 1)
            )
            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
        else:
            new_weights[renamed] = weights[orig]

    return new_weights, new_keys_to_original_keys
-
-
# Note the current matching is not symmetric.
# It assumes model_state_dict will have longer names.
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
    Match names between the two state-dicts, and update the values of model_state_dict
    in-place with copies of the matched tensors in ckpt_state_dict.
    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
    model and will be renamed first.

    Strategy: suppose that the models that we will create will have prefixes appended
    to each of its keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look for each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with the longest name.
    For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.

    Matched checkpoint values whose shape differs from the model's are skipped with a
    warning. Unmatched keys on either side are reported via the logger.

    Args:
        model_state_dict (dict): name -> tensor of the model; updated in-place.
        ckpt_state_dict (dict): name -> tensor loaded from the checkpoint.
        c2_conversion (bool): whether to first rename the checkpoint keys with
            `convert_c2_detectron_names`.

    Raises:
        ValueError: if one checkpoint key is the best match for two model keys. The
            second model key is not overwritten in that case.
    """
    model_keys = sorted(model_state_dict.keys())
    if c2_conversion:
        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
        # original_keys: the name in the original dict (before renaming)
    else:
        original_keys = {x: x for x in ckpt_state_dict.keys()}
    ckpt_keys = sorted(ckpt_state_dict.keys())

    def match(a, b):
        # Matched ckpt_key should be a complete (starts with '.') suffix.
        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
        # but matches whatever_conv1 or mesh_head.whatever_conv1.
        return a == b or a.endswith("." + b)

    # For each model key, the index of the longest matching checkpoint key, or -1 if
    # nothing matches. This is done in plain Python rather than as an argmax over a
    # (num_model_keys, num_ckpt_keys) tensor, because that reduction raises when the
    # checkpoint has no keys. The strict ">" keeps the first of equally long matches, and
    # a zero-length key never counts as a match.
    idxs = []
    for key_model in model_keys:
        best_idx, best_len = -1, 0
        for idx_ckpt, key_ckpt in enumerate(ckpt_keys):
            if len(key_ckpt) > best_len and match(key_model, key_ckpt):
                best_idx, best_len = idx_ckpt, len(key_ckpt)
        idxs.append(best_idx)

    # used for logging
    max_len_model = max(len(key) for key in model_keys) if model_keys else 1
    max_len_ckpt = max(len(key) for key in ckpt_keys) if ckpt_keys else 1
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger(__name__)
    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
    for idx_model, idx_ckpt in enumerate(idxs):
        if idx_ckpt == -1:
            continue
        key_model = model_keys[idx_model]
        key_ckpt = ckpt_keys[idx_ckpt]
        value_ckpt = ckpt_state_dict[key_ckpt]
        shape_in_model = model_state_dict[key_model].shape

        if shape_in_model != value_ckpt.shape:
            logger.warning(
                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
                )
            )
            logger.warning(
                "{} will not be loaded. Please double check and see if this is desired.".format(
                    key_ckpt
                )
            )
            continue

        # Detect ambiguity before writing, so a failure does not leave the second
        # model key already overwritten.
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint! "
                "It matches at least two keys in the model ({} and {}).".format(
                    key_ckpt, key_model, matched_keys[key_ckpt]
                )
            )
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        model_state_dict[key_model] = value_ckpt.clone()
        matched_keys[key_ckpt] = key_model
        logger.info(
            log_str_template.format(
                key_model,
                max_len_model,
                original_keys[key_ckpt],
                max_len_ckpt,
                tuple(shape_in_model),
            )
        )

    # print warnings about unmatched keys on both side
    matched_model_keys = set(matched_keys.values())
    unmatched_model_keys = [k for k in model_keys if k not in matched_model_keys]
    if unmatched_model_keys:
        logger.info(get_missing_parameters_message(unmatched_model_keys))

    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_keys]
    if unmatched_ckpt_keys:
        logger.info(
            get_unexpected_parameters_message([original_keys[x] for x in unmatched_ckpt_keys])
        )
|