Browse Source

bugfix tinybert

tags/v1.1.0
yoonlee666 5 years ago
parent
commit
0e85c5c9a4
1 changed files with 12 additions and 2 deletions
  1. +12
    -2
      model_zoo/official/nlp/tinybert/run_task_distill.py

+ 12
- 2
model_zoo/official/nlp/tinybert/run_task_distill.py View File

@@ -112,7 +112,12 @@ def run_predistill():
run predistill
"""
cfg = phase1_cfg
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
context.set_context(reserve_class_name_in_scope=False)
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path
load_student_checkpoint_path = args_opt.load_gd_ckpt_path
@@ -265,7 +270,12 @@ def do_eval_standalone():
ckpt_file = args_opt.load_td1_ckpt_path
if ckpt_file == '':
raise ValueError("Student ckpt file should not be None")
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else:
raise Exception("Target error, GPU or Ascend is supported.")
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
param_dict = load_checkpoint(ckpt_file)
new_param_dict = {}


Loading…
Cancel
Save