|
|
|
@@ -81,7 +81,7 @@ def run_pretrain(): |
|
|
|
args_opt = parser.parse_args() |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
context.set_context(reserve_class_name_in_scope=False) |
|
|
|
|
|
|
|
context.set_context(variable_memory_max_size="30GB") |
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path |
|
|
|
if args_opt.distribute == "true": |
|
|
|
if args_opt.device_target == 'Ascend': |
|
|
|
|