|
|
|
@@ -51,11 +51,16 @@ if __name__ == '__main__': |
|
|
|
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") |
|
|
|
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], |
|
|
|
help="Run distribute, default is false.") |
|
|
|
parser.add_argument("--enable_graph_kernel", type=str, default="true", choices=["true", "false"], |
|
|
|
help="Accelerate by graph kernel, default is true.") |
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
_enable_graph_kernel = args.enable_graph_kernel == "true" and args.device_target == "GPU" |
|
|
|
context.set_context( |
|
|
|
mode=context.GRAPH_MODE, |
|
|
|
save_graphs=False, |
|
|
|
enable_graph_kernel=_enable_graph_kernel, |
|
|
|
device_target=args.device_target) |
|
|
|
|
|
|
|
rank = 0 |
|
|
|
|