|
|
|
@@ -121,12 +121,12 @@ def train_and_eval(config): |
|
|
|
model = Model(train_net, eval_network=eval_net, |
|
|
|
metrics={"auc": auc_metric}) |
|
|
|
|
|
|
|
eval_callback = EvalCallBack( |
|
|
|
model, ds_eval, auc_metric, config) |
|
|
|
|
|
|
|
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks. |
|
|
|
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt") |
|
|
|
|
|
|
|
eval_callback = EvalCallBack( |
|
|
|
model, ds_eval, auc_metric, config) |
|
|
|
|
|
|
|
callback = LossCallBack(config=config, per_print_times=20) |
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, |
|
|
|
keep_checkpoint_max=5, integrated_save=False) |
|
|
|
@@ -146,10 +146,11 @@ if __name__ == "__main__": |
|
|
|
wide_deep_config = WideDeepConfig() |
|
|
|
wide_deep_config.argparse_init() |
|
|
|
context.set_context(mode=context.GRAPH_MODE, |
|
|
|
device_target=wide_deep_config.device_target, save_graphs=True) |
|
|
|
device_target=wide_deep_config.device_target) |
|
|
|
context.set_context(variable_memory_max_size="24GB") |
|
|
|
context.set_context(enable_sparse=True) |
|
|
|
init() |
|
|
|
context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True) |
|
|
|
if wide_deep_config.sparse: |
|
|
|
context.set_auto_parallel_context( |
|
|
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True) |
|
|
|
|