Browse Source

fix bug for convert scripts

tags/v1.2.0-rc1
qujianwei 4 years ago
parent
commit
8b7d6b3b67
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      model_zoo/official/cv/faster_rcnn/src/convert_checkpoint.py
  2. +1
    -1
      model_zoo/official/cv/maskrcnn/src/convert_checkpoint.py

+ 2
- 2
model_zoo/official/cv/faster_rcnn/src/convert_checkpoint.py View File

@@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
param_name = msname
if "down_sample_layer.0" in param_name:
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
if "down_sample-layer.1" in param_name:
if "down_sample_layer.1" in param_name:
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
weights[param_name] = ms_ckpt[msname].data.asnumpy()
if use_fp16_weight:
@@ -60,5 +60,5 @@ def load_weights(model_path, use_fp16_weight):
return param_list

if __name__ == "__main__":
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=True)
parameter_list = load_weights(args_opt.ckpt_file, use_fp16_weight=False)
save_checkpoint(parameter_list, "resnet50_backbone.ckpt")

+ 1
- 1
model_zoo/official/cv/maskrcnn/src/convert_checkpoint.py View File

@@ -44,7 +44,7 @@ def load_weights(model_path, use_fp16_weight):
param_name = msname
if "down_sample_layer.0" in param_name:
param_name = param_name.replace("down_sample_layer.0", "conv_down_sample")
if "down_sample-layer.1" in param_name:
if "down_sample_layer.1" in param_name:
param_name = param_name.replace("down_sample_layer.1", "bn_down_sample")
weights[param_name] = ms_ckpt[msname].data.asnumpy()
if use_fp16_weight:


Loading…
Cancel
Save