| @@ -16,6 +16,7 @@ | |||||
| #include "backend/session/executor.h" | #include "backend/session/executor.h" | ||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "backend/session/executor_manager.h" | #include "backend/session/executor_manager.h" | ||||
| #include "utils/comm_manager.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) { | |||||
| if (utils::isa<VectorRef>(base_ref)) { | |||||
| auto ref_list = utils::cast<VectorRef>(base_ref); | |||||
| py::tuple output_tensors(ref_list.size()); | |||||
| for (size_t i = 0; i < ref_list.size(); ++i) { | |||||
| auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef | |||||
| if (utils::isa<tensor::TensorPtr>(output)) { | |||||
| auto tensor_ptr = utils::cast<tensor::TensorPtr>(output); | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| output_tensors[i] = tensor_ptr; | |||||
| } else if (utils::isa<PyObjectRef>(output)) { | |||||
| py::object obj = utils::cast<PyObjectRef>(output).object_; | |||||
| py::tuple tensor_tuple = py::cast<py::tuple>(obj); | |||||
| output_tensors[i] = tensor_tuple; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; | |||||
| } | |||||
| } | |||||
| return output_tensors; // turn tuple to py::object and store in PyObjectRef | |||||
| } else if (utils::isa<tensor::TensorPtr>(base_ref)) { | |||||
| return base_ref; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void CompileNodesTask::Run() { | void CompileNodesTask::Run() { | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| @@ -104,6 +79,10 @@ void RunOpTask::Run() { | |||||
| session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_); | session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_); | ||||
| } | } | ||||
| void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } | |||||
| void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); } | |||||
| Executor::Executor(const std::string &device_name, uint32_t device_id) { | Executor::Executor(const std::string &device_name, uint32_t device_id) { | ||||
| device_name_ = device_name; | device_name_ = device_name; | ||||
| device_id_ = device_id; | device_id_ = device_id; | ||||
| @@ -141,22 +120,8 @@ void Executor::WorkerLoop() { | |||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| exception_ptr_ = std::current_exception(); | exception_ptr_ = std::current_exception(); | ||||
| } | } | ||||
| auto task_type = task->type_; | |||||
| task = nullptr; | task = nullptr; | ||||
| if (task_type == kCompileNodes) { | |||||
| compile_cond_var_.notify_all(); | |||||
| } else if (task_type == kCompileGraph) { | |||||
| compile_cond_var_.notify_all(); | |||||
| } else if (task_type == kBuildGraph) { | |||||
| build_cond_var_.notify_all(); | |||||
| } else if (task_type == kRunGraph) { | |||||
| run_cond_var_.notify_all(); | |||||
| } else if (task_type == kBuildOp) { | |||||
| build_op_cond_var_.notify_all(); | |||||
| } else if (task_type == kRunOp) { | |||||
| run_op_cond_var_.notify_all(); | |||||
| } | |||||
| sync_cond_var_.notify_all(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL | |||||
| task->output_nodes_ = outputs; | task->output_nodes_ = outputs; | ||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| compile_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| return task->graph_id_; | return task->graph_id_; | ||||
| } | } | ||||
| @@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph | |||||
| task->func_graph_ = func_graph; | task->func_graph_ = func_graph; | ||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| compile_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| return task->graph_id_; | return task->graph_id_; | ||||
| } | } | ||||
| @@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) { | |||||
| task->graph_id_ = graphId; | task->graph_id_ = graphId; | ||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| build_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| } | } | ||||
| @@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, | |||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| py::gil_scoped_release release; | py::gil_scoped_release release; | ||||
| run_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| } | } | ||||
| @@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c | |||||
| task->tensors_mask_ = tensors_mask; | task->tensors_mask_ = tensors_mask; | ||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| build_op_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| } | } | ||||
| py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors) { | |||||
| void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||||
| CheckException(); | CheckException(); | ||||
| std::unique_lock<std::mutex> lock(task_mutex_); | std::unique_lock<std::mutex> lock(task_mutex_); | ||||
| auto task = std::make_shared<RunOpTask>(); | auto task = std::make_shared<RunOpTask>(); | ||||
| @@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info | |||||
| task->input_tensors_ = input_tensors; | task->input_tensors_ = input_tensors; | ||||
| ready_tasks_.push(task); | ready_tasks_.push(task); | ||||
| task_cond_var_.notify_all(); | task_cond_var_.notify_all(); | ||||
| run_op_cond_var_.wait(lock); | |||||
| sync_cond_var_.wait(lock); | |||||
| CheckException(); | CheckException(); | ||||
| *outputs = task->outputs_; | |||||
| } | |||||
| // Trans output to tuple | |||||
| auto output_tensors = TransformBaseRefListToTuple(task->outputs_); | |||||
| if (!utils::isa<PyObjectRef>(output_tensors) || | |||||
| !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) { | |||||
| MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; | |||||
| } | |||||
| py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; | |||||
| py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj); | |||||
| return tuple_tensors; | |||||
| bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) { | |||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| auto task = std::make_shared<CreateCommGroupTask>(); | |||||
| task->group_name_ = group_name; | |||||
| task->ranks_ = ranks; | |||||
| ready_tasks_.push(task); | |||||
| task_cond_var_.notify_all(); | |||||
| sync_cond_var_.wait(lock); | |||||
| return task->result_; | |||||
| } | |||||
| bool Executor::DestroyCommGroup(const std::string &group_name) { | |||||
| std::unique_lock<std::mutex> lock(task_mutex_); | |||||
| auto task = std::make_shared<DestroyCommGroupTask>(); | |||||
| task->group_name_ = group_name; | |||||
| ready_tasks_.push(task); | |||||
| task_cond_var_.notify_all(); | |||||
| sync_cond_var_.wait(lock); | |||||
| return task->result_; | |||||
| } | } | ||||
| void Executor::StopWorker() { | void Executor::StopWorker() { | ||||
| @@ -32,10 +32,22 @@ | |||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "utils/contract.h" | #include "utils/contract.h" | ||||
| #include "utils/comm_manager.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp }; | |||||
| enum TaskType { | |||||
| kUnKnown, | |||||
| kExit, | |||||
| kCompileNodes, | |||||
| kCompileGraph, | |||||
| kBuildGraph, | |||||
| kBuildOp, | |||||
| kRunGraph, | |||||
| kRunOp, | |||||
| kCreateCommGroup, | |||||
| kDestroyCommGroup | |||||
| }; | |||||
| class Task { | class Task { | ||||
| public: | public: | ||||
| @@ -106,6 +118,25 @@ class RunOpTask : public Task { | |||||
| VectorRef outputs_; | VectorRef outputs_; | ||||
| }; | }; | ||||
| class CreateCommGroupTask : public Task { | |||||
| public: | |||||
| CreateCommGroupTask() { type_ = kCreateCommGroup; } | |||||
| ~CreateCommGroupTask() override = default; | |||||
| void Run() override; | |||||
| std::string group_name_; | |||||
| std::vector<uint32_t> ranks_; | |||||
| bool result_; | |||||
| }; | |||||
| class DestroyCommGroupTask : public Task { | |||||
| public: | |||||
| DestroyCommGroupTask() { type_ = kDestroyCommGroup; } | |||||
| ~DestroyCommGroupTask() override = default; | |||||
| void Run() override; | |||||
| std::string group_name_; | |||||
| bool result_; | |||||
| }; | |||||
| class ExitTask : public Task { | class ExitTask : public Task { | ||||
| public: | public: | ||||
| ExitTask() { type_ = kExit; } | ExitTask() { type_ = kExit; } | ||||
| @@ -125,9 +156,11 @@ class Executor { | |||||
| VectorRef *outputs); | VectorRef *outputs); | ||||
| void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask); | const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask); | ||||
| py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors); | |||||
| void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs); | |||||
| void OnRunGraphFinished(); | void OnRunGraphFinished(); | ||||
| bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks); | |||||
| bool DestroyCommGroup(const std::string &group_name); | |||||
| private: | private: | ||||
| void UpdateOutputTensors(VectorRef *outputs, | void UpdateOutputTensors(VectorRef *outputs, | ||||
| @@ -143,11 +176,7 @@ class Executor { | |||||
| std::mutex task_mutex_; | std::mutex task_mutex_; | ||||
| std::mutex pending_task_mutex_; | std::mutex pending_task_mutex_; | ||||
| std::condition_variable task_cond_var_; | std::condition_variable task_cond_var_; | ||||
| std::condition_variable compile_cond_var_; | |||||
| std::condition_variable build_cond_var_; | |||||
| std::condition_variable run_cond_var_; | |||||
| std::condition_variable build_op_cond_var_; | |||||
| std::condition_variable run_op_cond_var_; | |||||
| std::condition_variable sync_cond_var_; | |||||
| std::queue<std::shared_ptr<Task>> ready_tasks_; | std::queue<std::shared_ptr<Task>> ready_tasks_; | ||||
| std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; | std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; | ||||
| std::shared_ptr<std::thread> worker_; | std::shared_ptr<std::thread> worker_; | ||||
| @@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i | |||||
| executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); | executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask); | ||||
| } | } | ||||
| py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors) { | |||||
| void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) { | |||||
| MS_EXCEPTION_IF_NULL(executor_); | MS_EXCEPTION_IF_NULL(executor_); | ||||
| return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors); | |||||
| executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs); | |||||
| } | } | ||||
| void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| @@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); | ||||
| void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int> &tensors_mask); | const std::vector<int> &tensors_mask); | ||||
| py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors); | |||||
| void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| VectorRef *outputs); | |||||
| virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); | virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); | ||||
| @@ -15,12 +15,12 @@ | |||||
| */ | */ | ||||
| #include "frontend/parallel/group_manager.h" | #include "frontend/parallel/group_manager.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <vector> | #include <vector> | ||||
| #include "frontend/parallel/device_manager.h" | #include "frontend/parallel/device_manager.h" | ||||
| #include "backend/session/executor_manager.h" | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto | |||||
| vector<uint32_t> ranks; | vector<uint32_t> ranks; | ||||
| (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks), | (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks), | ||||
| [](const Device dev) { return (uint32_t)dev.rank(); }); | [](const Device dev) { return (uint32_t)dev.rank(); }); | ||||
| // Create group through the CommManager interface | |||||
| bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks); | |||||
| // Create group through the executor | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||||
| auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); | |||||
| MS_EXCEPTION_IF_NULL(executor); | |||||
| bool ret = executor->CreateCommGroup(group_name, ranks); | |||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(ERROR) << "Create group failed, group name is " << group_name; | MS_LOG(ERROR) << "Create group failed, group name is " << group_name; | ||||
| return Status::FAILED; | return Status::FAILED; | ||||
| @@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto | |||||
| } | } | ||||
| } | } | ||||
| Status GroupManager::DestroyGroup(const std::string &group_name) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||||
| auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id); | |||||
| MS_EXCEPTION_IF_NULL(executor); | |||||
| bool ret = executor->DestroyCommGroup(group_name); | |||||
| if (!ret) { | |||||
| return Status::FAILED; | |||||
| } | |||||
| return Status::SUCCESS; | |||||
| } | |||||
| Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { | Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { | ||||
| std::string name = (*group).name(); | std::string name = (*group).name(); | ||||
| auto it = groups_.find(name); | auto it = groups_.find(name); | ||||
| @@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { | |||||
| return Status::FAILED; | return Status::FAILED; | ||||
| } | } | ||||
| (void)groups_.erase(it); | (void)groups_.erase(it); | ||||
| bool ret = CommManager::GetInstance().DestroyGroup(name); | |||||
| if (!ret) { | |||||
| return Status::FAILED; | |||||
| } | |||||
| return Status::SUCCESS; | |||||
| return DestroyGroup(name); | |||||
| } | } | ||||
| Status GroupManager::DestroyAllGroups() { | Status GroupManager::DestroyAllGroups() { | ||||
| for (auto &it : groups_) { | for (auto &it : groups_) { | ||||
| std::string name = it.first; | std::string name = it.first; | ||||
| bool ret = CommManager::GetInstance().DestroyGroup(name); | |||||
| if (!ret) { | |||||
| auto ret = DestroyGroup(name); | |||||
| if (ret != Status::SUCCESS) { | |||||
| return Status::FAILED; | return Status::FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -65,6 +65,7 @@ class GroupManager { | |||||
| void Clear(); | void Clear(); | ||||
| private: | private: | ||||
| Status DestroyGroup(const std::string &group_name); | |||||
| // the key is group name (name_) | // the key is group name (name_) | ||||
| std::map<std::string, Group> groups_; | std::map<std::string, Group> groups_; | ||||
| std::string world_group_; | std::string world_group_; | ||||
| @@ -19,18 +19,22 @@ | |||||
| #include <typeinfo> | #include <typeinfo> | ||||
| #include <map> | #include <map> | ||||
| #include <set> | #include <set> | ||||
| #include <memory> | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "pybind_api/ir/tensor_py.h" | #include "pybind_api/ir/tensor_py.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/anf.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "utils/any.h" | #include "utils/any.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/context/context_extends.h" | #include "utils/context/context_extends.h" | ||||
| #include "utils/config_manager.h" | #include "utils/config_manager.h" | ||||
| #include "utils/convert_utils_py.h" | #include "utils/convert_utils_py.h" | ||||
| #include "utils/base_ref_extends.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/operator/composite/composite.h" | #include "frontend/operator/composite/composite.h" | ||||
| #include "frontend/operator/composite/do_signature.h" | #include "frontend/operator/composite/do_signature.h" | ||||
| @@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens | |||||
| *input_tensors = new_input_tensors; | *input_tensors = new_input_tensors; | ||||
| } | } | ||||
| BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) { | |||||
| if (utils::isa<VectorRef>(base_ref)) { | |||||
| auto ref_list = utils::cast<VectorRef>(base_ref); | |||||
| py::tuple output_tensors(ref_list.size()); | |||||
| for (size_t i = 0; i < ref_list.size(); ++i) { | |||||
| auto output = TransformBaseRefListToTuple(ref_list[i]); | |||||
| if (utils::isa<tensor::TensorPtr>(output)) { | |||||
| auto tensor_ptr = utils::cast<tensor::TensorPtr>(output); | |||||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||||
| output_tensors[i] = tensor_ptr; | |||||
| } else if (utils::isa<PyObjectRef>(output)) { | |||||
| py::object obj = utils::cast<PyObjectRef>(output).object_; | |||||
| py::tuple tensor_tuple = py::cast<py::tuple>(obj); | |||||
| output_tensors[i] = tensor_tuple; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; | |||||
| } | |||||
| } | |||||
| return std::make_shared<PyObjectRef>(output_tensors); | |||||
| } else if (utils::isa<tensor::TensorPtr>(base_ref)) { | |||||
| return base_ref; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; | |||||
| } | |||||
| } | |||||
| py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { | py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { | ||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | MS_EXCEPTION_IF_NULL(op_exec_info); | ||||
| MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; | MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; | ||||
| @@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); | std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); | ||||
| session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); | session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask); | ||||
| EraseValueNodeTensor(tensors_mask, &input_tensors); | EraseValueNodeTensor(tensors_mask, &input_tensors); | ||||
| py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors); | |||||
| VectorRef outputs; | |||||
| session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs); | |||||
| // Trans output to tuple | |||||
| auto output_tensors = TransformBaseRefListToTuple(outputs); | |||||
| if (!utils::isa<PyObjectRef>(output_tensors) || | |||||
| !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) { | |||||
| MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; | |||||
| } | |||||
| py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; | |||||
| py::tuple result = py::cast<py::tuple>(tuple_obj); | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | ||||
| *status = PYNATIVE_SUCCESS; | *status = PYNATIVE_SUCCESS; | ||||
| MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; | MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms"; | ||||