|
|
@@ -91,6 +91,12 @@ def _get_optimizer(args_opt, network): |
|
|
return optimizer |
|
|
return optimizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _auto_enable_graph_kernel(device_target, graph_kernel_mode): |
|
|
|
|
|
"""Judge whether is suitable to enable graph kernel.""" |
|
|
|
|
|
return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \ |
|
|
|
|
|
cfg.bert_network == 'base' and cfg.batch_size == 32 and cfg.optimizer == 'AdamWeightDecay' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_pretrain(): |
|
|
def run_pretrain(): |
|
|
"""pre-train bert_clue""" |
|
|
"""pre-train bert_clue""" |
|
|
parser = argparse.ArgumentParser(description='bert pre_training') |
|
|
parser = argparse.ArgumentParser(description='bert pre_training') |
|
|
@@ -121,6 +127,8 @@ def run_pretrain(): |
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") |
|
|
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.") |
|
|
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") |
|
|
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path") |
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") |
|
|
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path") |
|
|
|
|
|
parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"], |
|
|
|
|
|
help="Accelerate by graph kernel, default is auto.") |
|
|
|
|
|
|
|
|
args_opt = parser.parse_args() |
|
|
args_opt = parser.parse_args() |
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id) |
|
|
@@ -145,10 +153,17 @@ def run_pretrain(): |
|
|
rank = 0 |
|
|
rank = 0 |
|
|
device_num = 1 |
|
|
device_num = 1 |
|
|
|
|
|
|
|
|
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32: |
|
|
|
|
|
|
|
|
is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel) |
|
|
|
|
|
|
|
|
|
|
|
if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel: |
|
|
|
|
|
context.set_context(enable_graph_kernel=True) |
|
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \ |
|
|
|
|
|
not is_auto_enable_graph_kernel: |
|
|
logger.warning('Gpu only support fp32 temporarily, run with fp32.') |
|
|
logger.warning('Gpu only support fp32 temporarily, run with fp32.') |
|
|
bert_net_cfg.compute_type = mstype.float32 |
|
|
bert_net_cfg.compute_type = mstype.float32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args_opt.accumulation_steps > 1: |
|
|
if args_opt.accumulation_steps > 1: |
|
|
logger.info("accumulation steps: {}".format(args_opt.accumulation_steps)) |
|
|
logger.info("accumulation steps: {}".format(args_opt.accumulation_steps)) |
|
|
logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps)) |
|
|
logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps)) |
|
|
|