Browse Source

Enable acceleration by Graph Kernel for Wide&Deep base model

pull/16046/head
lishanni513 4 years ago
parent
commit
e17ec96df6
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      model_zoo/official/recommend/wide_and_deep/train.py
  2. +1
    -1
      model_zoo/official/recommend/wide_and_deep/train_and_eval.py

+ 1
- 1
model_zoo/official/recommend/wide_and_deep/train.py View File

@@ -89,5 +89,5 @@ if __name__ == "__main__":
config = WideDeepConfig() config = WideDeepConfig()
config.argparse_init() config.argparse_init()


context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=config.device_target)
test_train(config) test_train(config)

+ 1
- 1
model_zoo/official/recommend/wide_and_deep/train_and_eval.py View File

@@ -106,6 +106,6 @@ if __name__ == "__main__":
wide_deep_config = WideDeepConfig() wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init() wide_deep_config.argparse_init()


context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target=wide_deep_config.device_target)
context.set_context(enable_sparse=wide_deep_config.sparse) context.set_context(enable_sparse=wide_deep_config.sparse)
test_train_eval(wide_deep_config) test_train_eval(wide_deep_config)

Loading…
Cancel
Save