| @@ -0,0 +1,69 @@ | |||||
| # Deeplab-V3 Example | |||||
| ## Description | |||||
| - This is an example of training DeepLabv3 with PASCAL VOC 2012 dataset in MindSpore. | |||||
| - Paper: "Rethinking Atrous Convolution for Semantic Image Segmentation" | |||||
| Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam | |||||
| ## Requirements | |||||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||||
| - Download the VOC 2012 dataset for training. | |||||
| > Notes: | |||||
| If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file. | |||||
| ## Running the Example | |||||
| ### Training | |||||
| - Set options in config.py. | |||||
| - Run `run_standalone_train.sh` for non-distributed training. | |||||
| ``` bash | |||||
| sh scripts/run_standalone_train.sh DEVICE_ID EPOCH_SIZE DATA_DIR | |||||
| ``` | |||||
| - Run `run_distribute_train.sh` for distributed training. | |||||
| ``` bash | |||||
| sh scripts/run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATA_DIR MINDSPORE_HCCL_CONFIG_PATH | |||||
| ``` | |||||
| ### Evaluation | |||||
| Set options in evaluation_config.py. Make sure the 'data_file' and 'finetune_ckpt' are set to your own path. | |||||
| - Run run_eval.sh for evaluation. | |||||
| ``` bash | |||||
| sh scripts/run_eval.sh DEVICE_ID DATA_DIR | |||||
| ``` | |||||
| ## Options and Parameters | |||||
| It contains the parameters of the Deeplab-V3 model and options for training, which are set in the file config.py. | |||||
| ### Options: | |||||
| ``` | |||||
| config.py: | |||||
| learning_rate Learning rate, default is 0.0014. | |||||
| weight_decay Weight decay, default is 5e-5. | |||||
| momentum Momentum, default is 0.97. | |||||
| crop_size Image crop size [height, width] during training, default is 513. | |||||
| eval_scales The scales to resize images for evaluation, default is [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]. | |||||
| output_stride The ratio of input to output spatial resolution, default is 16. | |||||
| ignore_label Ignore label value, default is 255. | |||||
| seg_num_classes Number of semantic classes, including the background class (if exists). | |||||
| For PASCAL VOC 2012 this is 20 foreground classes + 1 background class, so the default is 21. | |||||
| fine_tune_batch_norm Fine tune the batch norm parameters or not, default is False. | |||||
| atrous_rates Atrous rates for atrous spatial pyramid pooling, default is None. | |||||
| decoder_output_stride The ratio of input to output spatial resolution when employing decoder | |||||
| to refine segmentation results, default is None. | |||||
| image_pyramid Input scales for multi-scale feature extraction, default is None. | |||||
| ``` | |||||
| ### Parameters: | |||||
| ``` | |||||
| Parameters for dataset and network: | |||||
| distribute Run distribute, default is false. | |||||
| epoch_size Epoch size, default is 6. | |||||
| batch_size batch size of input dataset: N, default is 2. | |||||
| data_url Train/Evaluation data url, required. | |||||
| checkpoint_url Checkpoint path, default is None. | |||||
| enable_save_ckpt Enable save checkpoint, default is true. | |||||
| save_checkpoint_steps Save checkpoint steps, default is 1000. | |||||
| save_checkpoint_num Save checkpoint numbers, default is 1. | |||||
| ``` | |||||
| @@ -28,7 +28,7 @@ parser = argparse.ArgumentParser(description="Deeplabv3 evaluation") | |||||
| parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.') | parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.') | ||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | ||||
| parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') | parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') | ||||
| parser.add_argument('--data_url', required=True, default=None, help='Train data url') | |||||
| parser.add_argument('--data_url', required=True, default=None, help='Evaluation data url') | |||||
| parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') | parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| @@ -15,8 +15,8 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| echo "Please run the script as: " | echo "Please run the script as: " | ||||
| echo "bash run_eval.sh DEVICE_ID EPOCH_SIZE DATA_DIR" | |||||
| echo "for example: bash run_eval.sh 0 /path/zh-wiki/ " | |||||
| echo "bash run_eval.sh DEVICE_ID DATA_DIR" | |||||
| echo "for example: bash run_eval.sh 0 /path/zh-wiki/ " | |||||
| echo "==============================================================================================================" | echo "==============================================================================================================" | ||||
| DEVICE_ID=$1 | DEVICE_ID=$1 | ||||
| @@ -27,13 +27,12 @@ from src.config import config | |||||
| parser = argparse.ArgumentParser(description="Deeplabv3 training") | parser = argparse.ArgumentParser(description="Deeplabv3 training") | ||||
| parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") | parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.") | ||||
| parser.add_argument('--epoch_size', type=int, default=2, help='Epoch size.') | |||||
| parser.add_argument('--epoch_size', type=int, default=6, help='Epoch size.') | |||||
| parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') | parser.add_argument('--batch_size', type=int, default=2, help='Batch size.') | ||||
| parser.add_argument('--data_url', required=True, default=None, help='Train data url') | parser.add_argument('--data_url', required=True, default=None, help='Train data url') | ||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | ||||
| parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') | parser.add_argument('--checkpoint_url', default=None, help='Checkpoint path') | ||||
| parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") | parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") | ||||
| parser.add_argument('--max_checkpoint_num', type=int, default=5, help='Max checkpoint number.') | |||||
| parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.") | parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.") | ||||
| parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") | parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| @@ -80,7 +79,7 @@ if __name__ == "__main__": | |||||
| keep_checkpoint_max=args_opt.save_checkpoint_num) | keep_checkpoint_max=args_opt.save_checkpoint_num) | ||||
| ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) | ckpoint_cb = ModelCheckpoint(prefix='checkpoint_deeplabv3', config=config_ck) | ||||
| callback.append(ckpoint_cb) | callback.append(ckpoint_cb) | ||||
| net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], | |||||
| net = deeplabv3_resnet50(config.seg_num_classes, [args_opt.batch_size, 3, args_opt.crop_size, args_opt.crop_size], | |||||
| infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, | infer_scale_sizes=config.eval_scales, atrous_rates=config.atrous_rates, | ||||
| decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, | decoder_output_stride=config.decoder_output_stride, output_stride=config.output_stride, | ||||
| fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) | fine_tune_batch_norm=config.fine_tune_batch_norm, image_pyramid=config.image_pyramid) | ||||