diff --git a/model_zoo/official/recommend/wide_and_deep/train.py b/model_zoo/official/recommend/wide_and_deep/train.py index 4c4e384b6e..7366c41cfd 100644 --- a/model_zoo/official/recommend/wide_and_deep/train.py +++ b/model_zoo/official/recommend/wide_and_deep/train.py @@ -89,5 +89,5 @@ if __name__ == "__main__": config = WideDeepConfig() config.argparse_init() - context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target) test_train(config) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py index b88c72f6a8..59c797b30d 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval.py @@ -106,6 +106,6 @@ if __name__ == "__main__": wide_deep_config = WideDeepConfig() wide_deep_config.argparse_init() - context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) + context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=wide_deep_config.device_target) context.set_context(enable_sparse=wide_deep_config.sparse) test_train_eval(wide_deep_config)