|
|
|
@@ -214,7 +214,7 @@ if __name__ == '__main__': |
|
|
|
loss_scale_manager=loss_scale) |
|
|
|
|
|
|
|
cb = [Monitor(lr_init=lr.asnumpy())] |
|
|
|
if args_opt.run_distribute: |
|
|
|
if args_opt.run_distribute and args_opt.device_target != "CPU": |
|
|
|
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" |
|
|
|
else: |
|
|
|
ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/" |
|
|
|
|