|
|
|
@@ -57,7 +57,9 @@ if __name__ == '__main__': |
|
|
|
device_id = int(os.getenv('DEVICE_ID')) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) |
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
|
gradients_mean=True, |
|
|
|
all_reduce_fusion_config=[9, 11]) |
|
|
|
init() |
|
|
|
rank_id = int(os.environ.get('RANK_ID')) |
|
|
|
elif args_opt.device_target == "GPU": |
|
|
|
|