From 8fa83cca873a730e0df52285bb61bcaeecdfdcbf Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Fri, 9 Oct 2020 10:54:53 +0800 Subject: [PATCH] delete device id for gpu --- model_zoo/official/nlp/tinybert/run_general_distill.py | 8 +++++++- model_zoo/official/nlp/tinybert/run_task_distill.py | 9 ++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py index 1c10313dde..ceb9b82ae7 100644 --- a/model_zoo/official/nlp/tinybert/run_general_distill.py +++ b/model_zoo/official/nlp/tinybert/run_general_distill.py @@ -62,7 +62,13 @@ def run_general_distill(): help="dataset type tfrecord/mindrecord, default is tfrecord") args_opt = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + if args_opt.device_target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + elif args_opt.device_target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + else: + raise Exception("Target error, GPU or Ascend is supported.") + context.set_context(reserve_class_name_in_scope=False) context.set_context(variable_memory_max_size="30GB") diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py index 15948449f5..459fd52900 100644 --- a/model_zoo/official/nlp/tinybert/run_task_distill.py +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -184,7 +184,14 @@ def run_task_distill(ckpt_file): if ckpt_file == '': raise ValueError("Student ckpt file should not be None") cfg = phase2_cfg - context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + + if args_opt.device_target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) + elif args_opt.device_target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) + else: + raise Exception("Target error, GPU or Ascend is supported.") + load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path load_student_checkpoint_path = ckpt_file netwithloss = BertNetworkWithLoss_td(teacher_config=td_teacher_net_cfg, teacher_ckpt=load_teacher_checkpoint_path,