|
|
|
@@ -112,7 +112,12 @@ def run_predistill(): |
|
|
|
run predistill |
|
|
|
""" |
|
|
|
cfg = phase1_cfg |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
if args_opt.device_target == "Ascend": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
elif args_opt.device_target == "GPU": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) |
|
|
|
else: |
|
|
|
raise Exception("Target error, GPU or Ascend is supported.") |
|
|
|
context.set_context(reserve_class_name_in_scope=False) |
|
|
|
load_teacher_checkpoint_path = args_opt.load_teacher_ckpt_path |
|
|
|
load_student_checkpoint_path = args_opt.load_gd_ckpt_path |
|
|
|
@@ -265,7 +270,12 @@ def do_eval_standalone(): |
|
|
|
ckpt_file = args_opt.load_td1_ckpt_path |
|
|
|
if ckpt_file == '': |
|
|
|
raise ValueError("Student ckpt file should not be None") |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
if args_opt.device_target == "Ascend": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
elif args_opt.device_target == "GPU": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) |
|
|
|
else: |
|
|
|
raise Exception("Target error, GPU or Ascend is supported.") |
|
|
|
eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student") |
|
|
|
param_dict = load_checkpoint(ckpt_file) |
|
|
|
new_param_dict = {} |
|
|
|
|