|
|
|
@@ -61,7 +61,7 @@ if __name__ == '__main__': |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) |
|
|
|
if args_opt.parameter_server: |
|
|
|
context.set_ps_context(enable_ps=True) |
|
|
|
device_id = int(os.getenv('DEVICE_ID'), '0') |
|
|
|
device_id = int(os.getenv('DEVICE_ID', '0')) |
|
|
|
if args_opt.run_distribute: |
|
|
|
if target == "Ascend": |
|
|
|
context.set_context(device_id=device_id, enable_auto_mixed_precision=True) |
|
|
|
|