|
|
|
@@ -106,6 +106,7 @@ if __name__ == '__main__': |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) |
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1)) |
|
|
|
|
|
|
|
rank = 0 |
|
|
|
if device_target == "Ascend": |
|
|
|
if args_opt.device_id is not None: |
|
|
|
context.set_context(device_id=args_opt.device_id) |
|
|
|
@@ -117,6 +118,7 @@ if __name__ == '__main__': |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True) |
|
|
|
init() |
|
|
|
rank = get_rank() |
|
|
|
elif device_target == "GPU": |
|
|
|
init() |
|
|
|
|
|
|
|
@@ -124,6 +126,7 @@ if __name__ == '__main__': |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True) |
|
|
|
rank = get_rank() |
|
|
|
else: |
|
|
|
raise ValueError("Unsupported platform.") |
|
|
|
|
|
|
|
@@ -200,14 +203,13 @@ if __name__ == '__main__': |
|
|
|
if device_target == "Ascend": |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, |
|
|
|
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager) |
|
|
|
ckpt_save_dir = "./" |
|
|
|
else: # GPU |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, |
|
|
|
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager) |
|
|
|
ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" |
|
|
|
|
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) |
|
|
|
time_cb = TimeMonitor(data_size=batch_num) |
|
|
|
ckpt_save_dir = "./ckpt_" + str(rank) + "/" |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir, |
|
|
|
config=config_ck) |
|
|
|
loss_cb = LossMonitor() |
|
|
|
|