|
|
|
@@ -126,6 +126,7 @@ def train_and_eval(config): |
|
|
|
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', |
|
|
|
config=ckptconfig) |
|
|
|
if cache_enable: |
|
|
|
config.stra_ckpt = './stra_ckpt_' + str(get_rank()) + '/strategy.ckpt' |
|
|
|
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) |
|
|
|
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] |
|
|
|
if get_rank() == 0: |
|
|
|
|