|
|
@@ -76,8 +76,8 @@ if __name__ == "__main__": |
|
|
gradients_mean=True) |
|
|
gradients_mean=True) |
|
|
init() |
|
|
init() |
|
|
elif device_target == "GPU": |
|
|
elif device_target == "GPU": |
|
|
init() |
|
|
|
|
|
if device_num > 1: |
|
|
if device_num > 1: |
|
|
|
|
|
init() |
|
|
context.reset_auto_parallel_context() |
|
|
context.reset_auto_parallel_context() |
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, |
|
|
gradients_mean=True) |
|
|
gradients_mean=True) |
|
|
|