|
|
|
@@ -56,6 +56,9 @@ |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "utils/config_manager.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#if ENABLE_CPU && ENABLE_GPU |
|
|
|
#include "ps/util.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
@@ -255,9 +258,12 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr |
|
|
|
// Graph kernel fusion optimization |
|
|
|
GraphKernelOptimize(graph); |
|
|
|
|
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) |
|
|
|
// Assign parameter keys. |
|
|
|
AssignParamKey(graph); |
|
|
|
#if ENABLE_CPU && ENABLE_GPU |
|
|
|
if (ps::Util::IsParamServerMode()) { |
|
|
|
CheckPSModeConsistence(graph); |
|
|
|
// Assign parameter keys. |
|
|
|
AssignParamKey(graph); |
|
|
|
} |
|
|
|
#endif |
|
|
|
// Start gpu kernel runtime |
|
|
|
StartKernelRT(); |
|
|
|
@@ -299,7 +305,7 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: |
|
|
|
// Load input data from user input |
|
|
|
LoadInputData(kernel_graph, inputs); |
|
|
|
PreIterationDbg(kernel_graph); |
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) |
|
|
|
#if ENABLE_CPU && ENABLE_GPU |
|
|
|
// Initialize parameter server |
|
|
|
InitPSParamAndOptim(kernel_graph, inputs); |
|
|
|
#endif |
|
|
|
|