|
|
|
@@ -87,13 +87,13 @@ if __name__ == "__main__": |
|
|
|
keep_checkpoint_max=args_opt.save_checkpoint_num) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) |
|
|
|
callback.append(ckpoint_cb) |
|
|
|
net = deeplabv3_resnet50(crop_size.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size], |
|
|
|
infer_scale_sizes=crop_size.eval_scales, atrous_rates=crop_size.atrous_rates, |
|
|
|
decoder_output_stride=crop_size.decoder_output_stride, output_stride = crop_size.output_stride, |
|
|
|
fine_tune_batch_norm=crop_size.fine_tune_batch_norm, image_pyramid = crop_size.image_pyramid) |
|
|
|
net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size,3,args_opt.crop_size,args_opt.crop_size], |
|
|
|
infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, |
|
|
|
decoder_output_stride=config.decoder_output_stride, output_stride = config.output_stride, |
|
|
|
fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid = config.image_pyramid) |
|
|
|
net.set_train() |
|
|
|
model_fine_tune(args_opt, net, 'layer') |
|
|
|
loss = OhemLoss(crop_size.seg_num_classes, crop_size.ignore_label) |
|
|
|
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=args_opt.learning_rate, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay) |
|
|
|
loss = OhemLoss(config.seg_num_classes, config.ignore_label) |
|
|
|
opt = Momentum(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'depth' not in x.name and 'bias' not in x.name, net.trainable_params()), learning_rate=config.learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) |
|
|
|
model = Model(net, loss, opt) |
|
|
|
model.train(args_opt.epoch_size, train_dataset, callback) |