Browse Source

!15707 Enable acceleration by graph kernel for LSTM model.

From: @baita
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/15707/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
2c753ace11
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      model_zoo/official/nlp/lstm/train.py

+ 5
- 0
model_zoo/official/nlp/lstm/train.py View File

@@ -51,11 +51,16 @@ if __name__ == '__main__':
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--enable_graph_kernel", type=str, default="true", choices=["true", "false"],
help="Accelerate by graph kernel, default is true.")

args = parser.parse_args()

_enable_graph_kernel = args.enable_graph_kernel == "true" and args.device_target == "GPU"
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
enable_graph_kernel=_enable_graph_kernel,
device_target=args.device_target)

rank = 0


Loading…
Cancel
Save