| @@ -51,8 +51,8 @@ Parameters for both training and evaluating can be set in config.py. | |||||
| "image_height": 224, # image height | "image_height": 224, # image height | ||||
| "image_width": 224, # image width | "image_width": 224, # image width | ||||
| "save_checkpoint": True, # whether to save checkpoints or not | "save_checkpoint": True, # whether to save checkpoints or not | ||||
| "save_checkpoint_steps": 500, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step | |||||
| "keep_checkpoint_max": 40, # only keep the last keep_checkpoint_max checkpoints | |||||
| "save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch | |||||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoints | |||||
| "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | ||||
| "warmup_epochs": 0, # number of warmup epochs | "warmup_epochs": 0, # number of warmup epochs | ||||
| "lr_decay_mode": "cosine" # decay mode for generating learning rate | "lr_decay_mode": "cosine" # decay mode for generating learning rate | ||||
| @@ -28,8 +28,8 @@ config = ed({ | |||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_steps": 500, | |||||
| "keep_checkpoint_max": 40, | |||||
| "save_checkpoint_epochs": 1, | |||||
| "keep_checkpoint_max": 10, | |||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "warmup_epochs": 0, | "warmup_epochs": 0, | ||||
| "lr_decay_mode": "cosine", | "lr_decay_mode": "cosine", | ||||
| @@ -54,7 +54,7 @@ if __name__ == '__main__': | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| mirror_mean=True, parameter_broadcast=True) | mirror_mean=True, parameter_broadcast=True) | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([140]) | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) | |||||
| init() | init() | ||||
| epoch_size = config.epoch_size | epoch_size = config.epoch_size | ||||
| @@ -59,7 +59,7 @@ if __name__ == '__main__': | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| mirror_mean=True, parameter_broadcast=True) | mirror_mean=True, parameter_broadcast=True) | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([140]) | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) | |||||
| init() | init() | ||||
| epoch_size = config.epoch_size | epoch_size = config.epoch_size | ||||
| @@ -91,7 +91,7 @@ if __name__ == '__main__': | |||||
| loss_cb = LossMonitor() | loss_cb = LossMonitor() | ||||
| cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
| if config.save_checkpoint: | if config.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size, | |||||
| keep_checkpoint_max=config.keep_checkpoint_max) | keep_checkpoint_max=config.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) | ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck) | ||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||