Merge pull request !1717 from gengdongjie/mastertags/v0.5.0-beta
| @@ -29,7 +29,7 @@ config = ed({ | |||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_epochs": 1, | |||||
| "save_checkpoint_epochs": 5, | |||||
| "keep_checkpoint_max": 10, | "keep_checkpoint_max": 10, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "warmup_epochs": 0, | "warmup_epochs": 0, | ||||
| @@ -28,7 +28,7 @@ config = ed({ | |||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_steps": 1950, | |||||
| "save_checkpoint_epochs": 5, | |||||
| "keep_checkpoint_max": 10, | "keep_checkpoint_max": 10, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "warmup_epochs": 5, | "warmup_epochs": 5, | ||||
| @@ -43,6 +43,8 @@ args_opt = parser.parse_args() | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| target = args_opt.device_target | target = args_opt.device_target | ||||
| ckpt_save_dir = config.save_checkpoint_path | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| if target == "Ascend": | if target == "Ascend": | ||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||
| @@ -80,13 +82,13 @@ if __name__ == '__main__': | |||||
| else: | else: | ||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||||
| model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, | model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, | ||||
| amp_level="O2", keep_batchnorm_fp32=True) | |||||
| amp_level="O2", keep_batchnorm_fp32=False) | |||||
| time_cb = TimeMonitor(data_size=step_size) | time_cb = TimeMonitor(data_size=step_size) | ||||
| loss_cb = LossMonitor() | loss_cb = LossMonitor() | ||||
| cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
| if config.save_checkpoint: | if config.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size, | |||||
| keep_checkpoint_max=config.keep_checkpoint_max) | keep_checkpoint_max=config.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) | ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) | ||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||
| @@ -29,7 +29,7 @@ config = ed({ | |||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_epochs": 1, | |||||
| "save_checkpoint_epochs": 5, | |||||
| "keep_checkpoint_max": 10, | "keep_checkpoint_max": 10, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "warmup_epochs": 0, | "warmup_epochs": 0, | ||||
| @@ -46,6 +46,8 @@ args_opt = parser.parse_args() | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| target = args_opt.device_target | target = args_opt.device_target | ||||
| ckpt_save_dir = config.save_checkpoint_path | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) | |||||
| if not args_opt.do_eval and args_opt.run_distribute: | if not args_opt.do_eval and args_opt.run_distribute: | ||||
| if target == "Ascend": | if target == "Ascend": | ||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||