| @@ -124,6 +124,25 @@ if __name__ == '__main__': | |||||
| load_path = args_opt.pre_trained | load_path = args_opt.pre_trained | ||||
| if load_path != "": | if load_path != "": | ||||
| param_dict = load_checkpoint(load_path) | param_dict = load_checkpoint(load_path) | ||||
| key_mapping = {'down_sample_layer.1.beta': 'bn_down_sample.beta', | |||||
| 'down_sample_layer.1.gamma': 'bn_down_sample.gamma', | |||||
| 'down_sample_layer.0.weight': 'conv_down_sample.weight', | |||||
| 'down_sample_layer.1.moving_mean': 'bn_down_sample.moving_mean', | |||||
| 'down_sample_layer.1.moving_variance': 'bn_down_sample.moving_variance', | |||||
| } | |||||
| for oldkey in list(param_dict.keys()): | |||||
| if not oldkey.startswith(('backbone', 'end_point', 'global_step', 'learning_rate', 'moments', 'momentum')): | |||||
| data = param_dict.pop(oldkey) | |||||
| newkey = 'backbone.' + oldkey | |||||
| param_dict[newkey] = data | |||||
| oldkey = newkey | |||||
| for k, v in key_mapping.items(): | |||||
| if k in oldkey: | |||||
| newkey = oldkey.replace(k, v) | |||||
| param_dict[newkey] = param_dict.pop(oldkey) | |||||
| break | |||||
| for item in list(param_dict.keys()): | for item in list(param_dict.keys()): | ||||
| if not item.startswith('backbone'): | if not item.startswith('backbone'): | ||||
| param_dict.pop(item) | param_dict.pop(item) | ||||