From 7d54d627f1ad8a841ee5daa7ab5ccaa6ad03eba4 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Wed, 24 Feb 2021 15:14:08 +0800 Subject: [PATCH] 910 parallel inference Signed-off-by: zhoufeng --- .../cxx_api/graph/ascend/ascend_graph_impl.cc | 53 +++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc index f928e0184c..a5ba26997d 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -25,10 +25,15 @@ #include "backend/session/executor_manager.h" #include "runtime/device/kernel_runtime_manager.h" #include "runtime/dev.h" +#include "pipeline/jit/pipeline.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); +static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL"; +static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE"; + AscendGraphImpl::AscendGraphImpl() : session_impl_(nullptr), graph_id_(0), @@ -209,11 +214,11 @@ Status AscendGraphImpl::Load() { } session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_); session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_); - if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) { + if (inputs_info_.size() != input_names_.size()) { MS_LOG_ERROR << "Get model inputs info failed"; return kMCInvalidInput; } - if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) { + if (outputs_info_.size() != output_names_.size()) { MS_LOG_ERROR << "Get model outputs info failed"; return kMCInvalidInput; } @@ -287,12 +292,34 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) { return; } + auto env_hccl_mode = common::GetEnv(kHcclEnable); + if (!env_hccl_mode.empty() && env_hccl_mode != std::to_string(0)) { + MS_LOG(INFO) << "Enable hccl parallel mode."; + ms_context->set_param(MS_CTX_ENABLE_HCCL, true); + } + ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); ms_context->set_param(MS_CTX_DEVICE_ID, device_id_); ms_context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); - auto ret = rtSetDevice(device_id_); - if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast(ret) << "]"; + + if (ms_context->get_param(MS_CTX_ENABLE_HCCL)) { + pipeline::InitHccl(); + auto para_group_file = common::GetEnv(kHcclGroupFile); + if (para_group_file.empty()) { + MS_LOG(INFO) << "Cannot get Env " << kHcclGroupFile << ", skip."; + } else { + MS_LOG(INFO) << "Get env " << kHcclGroupFile << " success: " << para_group_file; + if (!parallel::CreateGroupsByCkptFile(para_group_file)) { + MS_LOG(ERROR) << "CreateGroupsByCkptFile failed."; + errno_ = kMCFailed; + return; + } + } + } else { + auto ret = rtSetDevice(device_id_); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast(ret) << "]"; + } } MS_LOG(INFO) << "Device " << device_id << " init env success."; @@ -310,10 +337,18 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { return; } - auto ret = rtDeviceReset(device_id_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; - return; + if (ms_context->get_param(MS_CTX_ENABLE_HCCL)) { + PythonEnvGuard guard; + if (!context::CloseTsd(ms_context)) { + MS_LOG(ERROR) << "CloseTsd failed!"; + return; + } + } else { + auto ret = rtDeviceReset(device_id_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; + return; + } } MS_LOG(INFO) << "End finalize device " << device_id_;