|
|
|
@@ -124,6 +124,25 @@ if __name__ == '__main__': |
|
|
|
load_path = args_opt.pre_trained |
|
|
|
if load_path != "": |
|
|
|
param_dict = load_checkpoint(load_path) |
|
|
|
|
|
|
|
key_mapping = {'down_sample_layer.1.beta': 'bn_down_sample.beta', |
|
|
|
'down_sample_layer.1.gamma': 'bn_down_sample.gamma', |
|
|
|
'down_sample_layer.0.weight': 'conv_down_sample.weight', |
|
|
|
'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean', |
|
|
|
'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance', |
|
|
|
} |
|
|
|
for oldkey in list(param_dict.keys()): |
|
|
|
if not oldkey.startswith(('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')): |
|
|
|
data = param_dict.pop(oldkey) |
|
|
|
newkey = 'backbone.' + oldkey |
|
|
|
param_dict[newkey] = data |
|
|
|
oldkey = newkey |
|
|
|
for k, v in key_mapping.items(): |
|
|
|
if k in oldkey: |
|
|
|
newkey = oldkey.replace(k, v) |
|
|
|
param_dict[newkey] = param_dict.pop(oldkey) |
|
|
|
break |
|
|
|
|
|
|
|
for item in list(param_dict.keys()): |
|
|
|
if not item.startswith('backbone'): |
|
|
|
param_dict.pop(item) |
|
|
|
|