diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc
index e50fb3d4d3..e4a983c2fe 100644
--- a/mindspore/ccsrc/backend/session/executor.cc
+++ b/mindspore/ccsrc/backend/session/executor.cc
@@ -16,6 +16,7 @@
 #include "backend/session/executor.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/executor_manager.h"
+#include "utils/comm_manager.h"
 
 namespace mindspore {
 namespace session {
@@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
     }
   }
 }
-
-BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
-  if (utils::isa<VectorRef>(base_ref)) {
-    auto ref_list = utils::cast<VectorRef>(base_ref);
-    py::tuple output_tensors(ref_list.size());
-    for (size_t i = 0; i < ref_list.size(); ++i) {
-      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
-      if (utils::isa<tensor::TensorPtr>(output)) {
-        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
-        MS_EXCEPTION_IF_NULL(tensor_ptr);
-        output_tensors[i] = tensor_ptr;
-      } else if (utils::isa<PyObjectRef>(output)) {
-        py::object obj = utils::cast<PyObjectRef>(output).object_;
-        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
-        output_tensors[i] = tensor_tuple;
-      } else {
-        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
-      }
-    }
-    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
-  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
-    return base_ref;
-  } else {
-    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
-  }
-}
 }  // namespace
 void CompileNodesTask::Run() {
   MS_EXCEPTION_IF_NULL(session_);
@@ -104,6 +79,10 @@ void RunOpTask::Run() {
   session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
 }
 
+void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
+
+void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
+
 Executor::Executor(const std::string &device_name, uint32_t device_id) {
   device_name_ = device_name;
   device_id_ = device_id;
@@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
     } catch (const std::exception &e) {
      exception_ptr_ = std::current_exception();
     }
-
-    auto task_type = task->type_;
     task = nullptr;
-    if (task_type == kCompileNodes) {
-      compile_cond_var_.notify_all();
-    } else if (task_type == kCompileGraph) {
-      compile_cond_var_.notify_all();
-    } else if (task_type == kBuildGraph) {
-      build_cond_var_.notify_all();
-    } else if (task_type == kRunGraph) {
-      run_cond_var_.notify_all();
-    } else if (task_type == kBuildOp) {
-      build_op_cond_var_.notify_all();
-    } else if (task_type == kRunOp) {
-      run_op_cond_var_.notify_all();
-    }
+    sync_cond_var_.notify_all();
   }
 }
 
@@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
   task->output_nodes_ = outputs;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  compile_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
   return task->graph_id_;
 }
@@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
   task->func_graph_ = func_graph;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  compile_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
   return task->graph_id_;
 }
@@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
   task->graph_id_ = graphId;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  build_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
 
@@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
   py::gil_scoped_release release;
-  run_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
 
@@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
   task->tensors_mask_ = tensors_mask;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  build_op_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
 
-py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                               const std::vector<tensor::TensorPtr> &input_tensors) {
+void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                          const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
   CheckException();
   std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<RunOpTask>();
@@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
   task->input_tensors_ = input_tensors;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  run_op_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
+  *outputs = task->outputs_;
+}
 
-  // Trans output to tuple
-  auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
-  if (!utils::isa<PyObjectRef>(output_tensors) ||
-      !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
-    MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
-  }
-  py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
-  py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
-  return tuple_tensors;
+bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
+  std::unique_lock<std::mutex> lock(task_mutex_);
+  auto task = std::make_shared<CreateCommGroupTask>();
+  task->group_name_ = group_name;
+  task->ranks_ = ranks;
+  ready_tasks_.push(task);
+  task_cond_var_.notify_all();
+  sync_cond_var_.wait(lock);
+  return task->result_;
+}
+
+bool Executor::DestroyCommGroup(const std::string &group_name) {
+  std::unique_lock<std::mutex> lock(task_mutex_);
+  auto task = std::make_shared<DestroyCommGroupTask>();
+  task->group_name_ = group_name;
+  ready_tasks_.push(task);
+  task_cond_var_.notify_all();
+  sync_cond_var_.wait(lock);
+  return task->result_;
 }
 
 void Executor::StopWorker() {
diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h
index 467a61ce2e..df4bf03be3 100644
--- a/mindspore/ccsrc/backend/session/executor.h
+++ b/mindspore/ccsrc/backend/session/executor.h
@@ -32,10 +32,22 @@
 #include "ir/tensor.h"
 #include "utils/any.h"
 #include "utils/contract.h"
+#include "utils/comm_manager.h"
 
 namespace mindspore {
 namespace session {
-enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp };
+enum TaskType {
+  kUnKnown,
+  kExit,
+  kCompileNodes,
+  kCompileGraph,
+  kBuildGraph,
+  kBuildOp,
+  kRunGraph,
+  kRunOp,
+  kCreateCommGroup,
+  kDestroyCommGroup
+};
 
 class Task {
  public:
@@ -106,6 +118,25 @@ class RunOpTask : public Task {
   VectorRef outputs_;
 };
 
+class CreateCommGroupTask : public Task {
+ public:
+  CreateCommGroupTask() { type_ = kCreateCommGroup; }
+  ~CreateCommGroupTask() override = default;
+  void Run() override;
+  std::string group_name_;
+  std::vector<uint32_t> ranks_;
+  bool result_;
+};
+
+class DestroyCommGroupTask : public Task {
+ public:
+  DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
+  ~DestroyCommGroupTask() override = default;
+  void Run() override;
+  std::string group_name_;
+  bool result_;
+};
+
 class ExitTask : public Task {
  public:
   ExitTask() { type_ = kExit; }
@@ -125,9 +156,11 @@ class Executor {
                      VectorRef *outputs);
   void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
-  py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                       const std::vector<tensor::TensorPtr> &input_tensors);
+  void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                  const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
   void OnRunGraphFinished();
+  bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
+  bool DestroyCommGroup(const std::string &group_name);
 
  private:
   void UpdateOutputTensors(VectorRef *outputs,
@@ -143,11 +176,7 @@ class Executor {
   std::mutex task_mutex_;
   std::mutex pending_task_mutex_;
   std::condition_variable task_cond_var_;
-  std::condition_variable compile_cond_var_;
-  std::condition_variable build_cond_var_;
-  std::condition_variable run_cond_var_;
-  std::condition_variable build_op_cond_var_;
-  std::condition_variable run_op_cond_var_;
+  std::condition_variable sync_cond_var_;
   std::queue<std::shared_ptr<Task>> ready_tasks_;
   std::list<std::shared_ptr<Task>> pending_tasks_;
   std::shared_ptr<std::thread> worker_;
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 845bc3e7f8..c4f04c1bfe 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i
   executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
 }
 
-py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                                   const std::vector<tensor::TensorPtr> &input_tensors) {
+void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                              const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors);
+  executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
 }
 
 void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index d8483f13b7..7e1f518dca 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
   void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
                     const std::vector<int> &tensors_mask);
-  py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors);
+  void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
+                  VectorRef *outputs);
 
   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc
index 8c707d7fbd..e6ec6ff68a 100644
--- a/mindspore/ccsrc/frontend/parallel/group_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc
@@ -15,12 +15,12 @@
  */
 
 #include "frontend/parallel/group_manager.h"
-
 #include <algorithm>
 #include <vector>
-
 #include "frontend/parallel/device_manager.h"
+#include "backend/session/executor_manager.h"
 #include "utils/comm_manager.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace parallel {
@@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
     vector<uint32_t> ranks;
     (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
                          [](const Device dev) { return (uint32_t)dev.rank(); });
-    // Create group through the CommManager interface
-    bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
+    // Create group through the executor
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+    uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+    auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
+    MS_EXCEPTION_IF_NULL(executor);
+    bool ret = executor->CreateCommGroup(group_name, ranks);
     if (!ret) {
       MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
       return Status::FAILED;
@@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
   }
 }
 
+Status GroupManager::DestroyGroup(const std::string &group_name) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
+  MS_EXCEPTION_IF_NULL(executor);
+  bool ret = executor->DestroyCommGroup(group_name);
+  if (!ret) {
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
   std::string name = (*group).name();
   auto it = groups_.find(name);
@@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
     return Status::FAILED;
   }
   (void)groups_.erase(it);
-  bool ret = CommManager::GetInstance().DestroyGroup(name);
-  if (!ret) {
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
+  return DestroyGroup(name);
 }
 
 Status GroupManager::DestroyAllGroups() {
   for (auto &it : groups_) {
     std::string name = it.first;
-    bool ret = CommManager::GetInstance().DestroyGroup(name);
-    if (!ret) {
+    auto ret = DestroyGroup(name);
+    if (ret != Status::SUCCESS) {
       return Status::FAILED;
     }
   }
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.h b/mindspore/ccsrc/frontend/parallel/group_manager.h
index 5d4eaef815..2ef6e40f15 100644
--- a/mindspore/ccsrc/frontend/parallel/group_manager.h
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.h
@@ -65,6 +65,7 @@ class GroupManager {
   void Clear();
 
  private:
+  Status DestroyGroup(const std::string &group_name);
   // the key is group name (name_)
   std::map<std::string, Group> groups_;
   std::string world_group_;
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index c9036dbf4a..63f9890ebe 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -19,18 +19,22 @@
 #include <...>
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 #include <...>
 
 #include "debug/trace.h"
 #include "pybind_api/ir/tensor_py.h"
 #include "ir/param_info.h"
+#include "ir/anf.h"
+#include "ir/tensor.h"
 #include "utils/any.h"
 #include "utils/utils.h"
 #include "utils/ms_context.h"
 #include "utils/context/context_extends.h"
 #include "utils/config_manager.h"
 #include "utils/convert_utils_py.h"
+#include "utils/base_ref_extends.h"
 #include "frontend/operator/ops.h"
 #include "frontend/operator/composite/composite.h"
 #include "frontend/operator/composite/do_signature.h"
@@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens
   }
 }
 
+BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
+  if (utils::isa<VectorRef>(base_ref)) {
+    auto ref_list = utils::cast<VectorRef>(base_ref);
+    py::tuple output_tensors(ref_list.size());
+    for (size_t i = 0; i < ref_list.size(); ++i) {
+      auto output = TransformBaseRefListToTuple(ref_list[i]);
+      if (utils::isa<tensor::TensorPtr>(output)) {
+        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
+        MS_EXCEPTION_IF_NULL(tensor_ptr);
+        output_tensors[i] = tensor_ptr;
+      } else if (utils::isa<PyObjectRef>(output)) {
+        py::object obj = utils::cast<PyObjectRef>(output).object_;
+        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
+        output_tensors[i] = tensor_tuple;
+      } else {
+        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+      }
+    }
+    return std::make_shared<PyObjectRef>(output_tensors);
+  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
+    return base_ref;
+  } else {
+    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+  }
+}
+
 py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
@@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
   session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, &input_tensors);
-  py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
+
+  VectorRef outputs;
+  session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs);
+
+  // Trans output to tuple
+  auto output_tensors = TransformBaseRefListToTuple(outputs);
+  if (!utils::isa<PyObjectRef>(output_tensors) ||
+      !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
+    MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
+  }
+  py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
+  py::tuple result = py::cast<py::tuple>(tuple_obj);
+
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
   *status = PYNATIVE_SUCCESS;
   MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";