Browse Source

!3992 add filter_weight parameter

Merge pull request !3992 from hanjun996/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3be7329389
2 changed files with 5 additions and 1 deletions
  1. +2
    -0
      model_zoo/official/cv/ssd/eval.py
  2. +3
    -1
      model_zoo/official/cv/ssd/train.py

+ 2
- 0
model_zoo/official/cv/ssd/eval.py View File

@@ -78,6 +78,8 @@ if __name__ == '__main__':
prefix = "ssd_eval.mindrecord"
mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if args_opt.dataset == "voc":
config.coco_root = config.voc_root
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)


+ 3
- 1
model_zoo/official/cv/ssd/train.py View File

@@ -46,6 +46,7 @@ def main():
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
parser.add_argument("--filter_weight", type=bool, default=False, help="Filter weight parameters, default is False.")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
@@ -117,7 +118,8 @@ def main():
if args_opt.pre_trained_epoch_size <= 0:
raise KeyError("pre_trained_epoch_size must be greater than 0.")
param_dict = load_checkpoint(args_opt.pre_trained)
filter_checkpoint_parameter(param_dict)
if args_opt.filter_weight:
filter_checkpoint_parameter(param_dict)
load_param_into_net(net, param_dict)

lr = Tensor(get_lr(global_step=config.global_step,


Loading…
Cancel
Save