From 5dd26694d492db67994f7087a8e2a2228edd077f Mon Sep 17 00:00:00 2001 From: y00451588 Date: Mon, 26 Apr 2021 17:06:14 +0800 Subject: [PATCH] Enable acceleration by graph kernel for LSTM model on GPU device. --- model_zoo/official/nlp/lstm/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/model_zoo/official/nlp/lstm/train.py b/model_zoo/official/nlp/lstm/train.py index 87de21c7ef..fdf5002ee6 100644 --- a/model_zoo/official/nlp/lstm/train.py +++ b/model_zoo/official/nlp/lstm/train.py @@ -51,11 +51,16 @@ if __name__ == '__main__': parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"], help="Run distribute, default is false.") + parser.add_argument("--enable_graph_kernel", type=str, default="true", choices=["true", "false"], + help="Accelerate by graph kernel, default is true.") + args = parser.parse_args() + _enable_graph_kernel = args.enable_graph_kernel == "true" and args.device_target == "GPU" context.set_context( mode=context.GRAPH_MODE, save_graphs=False, + enable_graph_kernel=_enable_graph_kernel, device_target=args.device_target) rank = 0