|
|
|
@@ -99,8 +99,7 @@ def _get_optimizer(args_opt, network): |
|
|
|
def _auto_enable_graph_kernel(device_target, graph_kernel_mode): |
|
|
|
"""Judge whether is suitable to enable graph kernel.""" |
|
|
|
return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \ |
|
|
|
cfg.bert_network == 'base' and (cfg.batch_size == 32 or cfg.batch_size == 64 or cfg.batch_size == 160) and \ |
|
|
|
cfg.optimizer == 'AdamWeightDecay' |
|
|
|
cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay' |
|
|
|
|
|
|
|
|
|
|
|
def run_pretrain(): |
|
|
|
|