diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.h b/mindspore/ccsrc/backend/session/kernel_build_client.h index 8e3c639b17..0dd6106dd3 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.h +++ b/mindspore/ccsrc/backend/session/kernel_build_client.h @@ -25,14 +25,28 @@ #include "common/duplex_pipe.h" #include "utils/log_adapter.h" +#include "utils/ms_context.h" namespace mindspore { namespace kernel { void ReplaceStr(std::string *dest, const std::string &replace, char new_char); constexpr inline static int kBufferSize = 4096; +constexpr inline static auto kEnv = "python"; // The TAG as prefix of real command from remote. constexpr inline static auto kTag = "[~]"; +static std::string GetPyExe() { + // get real python executable path + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + return kEnv; + } + auto env = ms_context->get_param(MS_CTX_PYTHON_EXE_PATH); + if (env.empty()) { + return kEnv; + } + return env; +} class KernelBuildClient { public: @@ -164,7 +178,6 @@ static std::string GetScriptFilePath(const std::string cmd_env, const std::strin class AscendKernelBuildClient : public KernelBuildClient { public: // Server configure - constexpr inline static auto kEnv = "python"; constexpr inline static auto kGetPathScript = "-c " "\"" @@ -196,9 +209,12 @@ class AscendKernelBuildClient : public KernelBuildClient { return instance; } - std::string GetEnv() override { return kEnv; } + std::string GetEnv() override { return GetPyExe(); } - std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + std::string GetScript() override { + auto env = GetPyExe(); + return GetScriptFilePath(env, kGetPathScript); + } // Before building. std::string SelectFormat(const std::string &json); @@ -229,7 +245,6 @@ class AscendKernelBuildClient : public KernelBuildClient { class GpuKernelBuildClient : public KernelBuildClient { public: // Server configure - constexpr inline static auto kEnv = "python"; constexpr inline static auto kGetPathScript = "-c " "\"" @@ -249,9 +264,12 @@ class GpuKernelBuildClient : public KernelBuildClient { return instance; } - std::string GetEnv() override { return kEnv; } + std::string GetEnv() override { return GetPyExe(); } - std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); } + std::string GetScript() override { + auto env = GetPyExe(); + return GetScriptFilePath(env, kGetPathScript); + } // Fetch pid(pid_t) from remote. int AkgGetPid(); diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 01b2af52e4..1b0bc210ce 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -91,7 +91,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"), py::arg("broadcast_params") = py::dict(), "Build data graph.") .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") - .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); + .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph.") + .def("set_py_exe_path", &ExecutorPy::PyExePath, py::arg("phase") = py::str(""), "set python executable path."); (void)py::class_>(m, "EnvInstance_").def(py::init()); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 08c465417b..4d5711895b 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -895,6 +895,15 @@ void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &ph #endif } +void ExecutorPy::PyExePath(const py::object &py_exe_path) { + if (!py::isinstance(py_exe_path)) { + MS_LOG(EXCEPTION) << "Failed, phase input is not a str"; + } + auto py_exe_path_s = py::cast(py_exe_path); + auto ms_context = MsContext::GetInstance(); + ms_context->set_param(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s); +} + bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, const std::string &phase, bool need_run) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h index 2e7502ec90..aad5dffc7f 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.h +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -90,6 +90,7 @@ class ExecutorPy : public std::enable_shared_from_this { void UpdataParamNodeDefaultInput(const std::string &phase, const std::unordered_map ¶ms); void RunInitGraph(const py::dict &init_params, const std::string &phase); + void PyExePath(const py::object &phase); py::dict GetParameterLayout(const std::string &phase); py::dict GetCNodeStrategy(const std::string &phase); void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy); diff --git a/mindspore/common/api.py b/mindspore/common/api.py index ea315d7658..2142be0191 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -16,6 +16,7 @@ # ============================================================================ """Providing interface methods.""" import types +import sys from collections import OrderedDict from functools import wraps @@ -340,6 +341,7 @@ class _Executor: self.is_init = False self._executor = Executor_.get_instance() self.compile_cache = {} + self._executor.set_py_exe_path(sys.executable) def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes, input_indexs, phase='dataset'): diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 6763e23c44..adeac0b9b6 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -34,6 +34,7 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke MsContext::MsContext(const std::string &policy, const std::string &target) { set_param(MS_CTX_SAVE_GRAPHS_FLAG, false); set_param(MS_CTX_SAVE_GRAPHS_PATH, "."); + set_param(MS_CTX_PYTHON_EXE_PATH, "python"); set_param(MS_CTX_ENABLE_DUMP, false); set_param(MS_CTX_SAVE_DUMP_PATH, "."); set_param(MS_CTX_TSD_REF, 0); diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 38eb2d4d32..79802e6f6b 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -102,6 +102,7 @@ enum MsCtxParam : unsigned { MS_CTX_SAVE_DUMP_PATH, MS_CTX_SAVE_GRAPHS_PATH, MS_CTX_VARIABLE_MEMORY_MAX_SIZE, + MS_CTX_PYTHON_EXE_PATH, MS_CTX_TYPE_STRING_END, // parameter numbers of each type