|
|
|
@@ -68,7 +68,8 @@ def run_pretrain(): |
|
|
|
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") |
|
|
|
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.") |
|
|
|
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.") |
|
|
|
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") |
|
|
|
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path") |
|
|
|
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path") |
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, " |
|
|
|
"default is 1000.") |
|
|
|
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, " |
|
|
|
@@ -81,7 +82,7 @@ def run_pretrain(): |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
|
context.set_context(reserve_class_name_in_scope=False) |
|
|
|
|
|
|
|
ckpt_save_dir = args_opt.checkpoint_path |
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path |
|
|
|
if args_opt.distribute == "true": |
|
|
|
if args_opt.device_target == 'Ascend': |
|
|
|
D.init('hccl') |
|
|
|
@@ -91,7 +92,7 @@ def run_pretrain(): |
|
|
|
D.init('nccl') |
|
|
|
device_num = D.get_group_size() |
|
|
|
rank = D.get_rank() |
|
|
|
ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/' |
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/' |
|
|
|
|
|
|
|
context.reset_auto_parallel_context() |
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, |
|
|
|
@@ -150,8 +151,8 @@ def run_pretrain(): |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck) |
|
|
|
callback.append(ckpoint_cb) |
|
|
|
|
|
|
|
if args_opt.checkpoint_path: |
|
|
|
param_dict = load_checkpoint(args_opt.checkpoint_path) |
|
|
|
if args_opt.load_checkpoint_path: |
|
|
|
param_dict = load_checkpoint(args_opt.load_checkpoint_path) |
|
|
|
load_param_into_net(netwithloss, param_dict) |
|
|
|
|
|
|
|
if args_opt.enable_lossscale == "true": |
|
|
|
|