|
|
@@ -1104,7 +1104,7 @@ bool ExecutorPy::AddDFGraph(const py::dict& init_params, const std::string& phas |
|
|
} |
|
|
} |
|
|
std::string init_graph = "init_subgraph." + net_id; |
|
|
std::string init_graph = "init_subgraph." + net_id; |
|
|
std::string checkpoint_name = "save." + net_id; |
|
|
std::string checkpoint_name = "save." + net_id; |
|
|
if (phase == "train") { |
|
|
|
|
|
|
|
|
if (phase.find("train") != std::string::npos) { |
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); |
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); |
|
|
} else { |
|
|
} else { |
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); |
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); |
|
|
|