From: @huangxinjing Reviewed-by: @yao_yf,@stsuteng,@zhunaipan Signed-off-by: @stsuteng tags/v1.1.0
| @@ -40,7 +40,7 @@ def argparse_init(): | |||||
| parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout") | parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout") | ||||
| parser.add_argument("--output_path", type=str, default="./output/") | parser.add_argument("--output_path", type=str, default="./output/") | ||||
| parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.") | parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.") | ||||
| parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt", | |||||
| parser.add_argument("--stra_ckpt", type=str, default="./checkpoints", | |||||
| help="The strategy checkpoint file.") | help="The strategy checkpoint file.") | ||||
| parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.") | parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.") | ||||
| parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.") | parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.") | ||||
| @@ -124,11 +124,15 @@ def train_and_eval(config): | |||||
| eval_callback = EvalCallBack( | eval_callback = EvalCallBack( | ||||
| model, ds_eval, auc_metric, config) | model, ds_eval, auc_metric, config) | ||||
| # Save strategy ckpts according to the rank id, this must be done before initializing the callbacks. | |||||
| config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt") | |||||
| callback = LossCallBack(config=config, per_print_times=20) | callback = LossCallBack(config=config, per_print_times=20) | ||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, | ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, | ||||
| keep_checkpoint_max=5, integrated_save=False) | keep_checkpoint_max=5, integrated_save=False) | ||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ||||
| directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig) | |||||
| directory=os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank())), config=ckptconfig) | |||||
| context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) | context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) | ||||
| callback_list = [TimeMonitor( | callback_list = [TimeMonitor( | ||||
| ds_train.get_dataset_size()), eval_callback, callback] | ds_train.get_dataset_size()), eval_callback, callback] | ||||
| @@ -115,6 +115,10 @@ def train_and_eval(config): | |||||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | ||||
| if cache_enable: | |||||
| config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt") | |||||
| context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) | |||||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | ||||
| callback = LossCallBack(config=config) | callback = LossCallBack(config=config) | ||||
| @@ -129,9 +133,6 @@ def train_and_eval(config): | |||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ||||
| directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', | directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', | ||||
| config=ckptconfig) | config=ckptconfig) | ||||
| if cache_enable: | |||||
| config.stra_ckpt = './stra_ckpt_' + str(get_rank()) + '/strategy.ckpt' | |||||
| context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) | |||||
| callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] | callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] | ||||
| if get_rank() == 0: | if get_rank() == 0: | ||||
| callback_list.append(ckpoint_cb) | callback_list.append(ckpoint_cb) | ||||