|
|
|
@@ -43,6 +43,8 @@ args_opt = parser.parse_args() |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
target = args_opt.device_target |
|
|
|
ckpt_save_dir = config.save_checkpoint_path |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) |
|
|
|
if not args_opt.do_eval and args_opt.run_distribute: |
|
|
|
if target == "Ascend": |
|
|
|
device_id = int(os.getenv('DEVICE_ID')) |
|
|
|
@@ -80,13 +82,13 @@ if __name__ == '__main__': |
|
|
|
else: |
|
|
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, |
|
|
|
amp_level="O2", keep_batchnorm_fp32=True) |
|
|
|
amp_level="O2", keep_batchnorm_fp32=False) |
|
|
|
|
|
|
|
time_cb = TimeMonitor(data_size=step_size) |
|
|
|
loss_cb = LossMonitor() |
|
|
|
cb = [time_cb, loss_cb] |
|
|
|
if config.save_checkpoint: |
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, |
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size, |
|
|
|
keep_checkpoint_max=config.keep_checkpoint_max) |
|
|
|
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) |
|
|
|
cb += [ckpt_cb] |
|
|
|
|