|
|
|
@@ -46,6 +46,7 @@ def main(): |
|
|
|
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") |
|
|
|
parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.") |
|
|
|
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") |
|
|
|
parser.add_argument("--filter_weight", type=bool, default=False, help="Filter weight parameters, default is False.") |
|
|
|
args_opt = parser.parse_args() |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) |
|
|
|
@@ -117,7 +118,8 @@ def main(): |
|
|
|
if args_opt.pre_trained_epoch_size <= 0: |
|
|
|
raise KeyError("pre_trained_epoch_size must be greater than 0.") |
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained) |
|
|
|
filter_checkpoint_parameter(param_dict) |
|
|
|
if args_opt.filter_weight: |
|
|
|
filter_checkpoint_parameter(param_dict) |
|
|
|
load_param_into_net(net, param_dict) |
|
|
|
|
|
|
|
lr = Tensor(get_lr(global_step=config.global_step, |
|
|
|
|