|
|
|
@@ -80,9 +80,9 @@ if __name__ == "__main__": |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) |
|
|
|
callback.append(ckpoint_cb) |
|
|
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], |
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, |
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, |
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) |
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, |
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, |
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) |
|
|
|
net.set_train() |
|
|
|
model_fine_tune(args_opt, net, 'layer') |
|
|
|
loss = OhemLoss(config.seg_num_classes, config.ignore_label) |
|
|
|
|