|
|
|
@@ -25,10 +25,15 @@ |
|
|
|
#include "backend/session/executor_manager.h" |
|
|
|
#include "runtime/device/kernel_runtime_manager.h" |
|
|
|
#include "runtime/dev.h" |
|
|
|
#include "pipeline/jit/pipeline.h" |
|
|
|
#include "frontend/parallel/step_parallel.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); |
|
|
|
|
|
|
|
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL"; |
|
|
|
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE"; |
|
|
|
|
|
|
|
AscendGraphImpl::AscendGraphImpl() |
|
|
|
: session_impl_(nullptr), |
|
|
|
graph_id_(0), |
|
|
|
@@ -209,11 +214,11 @@ Status AscendGraphImpl::Load() { |
|
|
|
} |
|
|
|
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_); |
|
|
|
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_); |
|
|
|
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) { |
|
|
|
if (inputs_info_.size() != input_names_.size()) { |
|
|
|
MS_LOG_ERROR << "Get model inputs info failed"; |
|
|
|
return kMCInvalidInput; |
|
|
|
} |
|
|
|
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) { |
|
|
|
if (outputs_info_.size() != output_names_.size()) { |
|
|
|
MS_LOG_ERROR << "Get model outputs info failed"; |
|
|
|
return kMCInvalidInput; |
|
|
|
} |
|
|
|
@@ -287,12 +292,34 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto env_hccl_mode = common::GetEnv(kHcclEnable); |
|
|
|
if (!env_hccl_mode.empty() && env_hccl_mode != std::to_string(0)) { |
|
|
|
MS_LOG(INFO) << "Enable hccl parallel mode."; |
|
|
|
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true); |
|
|
|
} |
|
|
|
|
|
|
|
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); |
|
|
|
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); |
|
|
|
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice); |
|
|
|
auto ret = rtSetDevice(device_id_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
|
|
|
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) { |
|
|
|
pipeline::InitHccl(); |
|
|
|
auto para_group_file = common::GetEnv(kHcclGroupFile); |
|
|
|
if (para_group_file.empty()) { |
|
|
|
MS_LOG(INFO) << "Cannot get Env " << kHcclGroupFile << ", skip."; |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Get env " << kHcclGroupFile << " success: " << para_group_file; |
|
|
|
if (!parallel::CreateGroupsByCkptFile(para_group_file)) { |
|
|
|
MS_LOG(ERROR) << "CreateGroupsByCkptFile failed."; |
|
|
|
errno_ = kMCFailed; |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto ret = rtSetDevice(device_id_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Device " << device_id << " init env success."; |
|
|
|
@@ -310,10 +337,18 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = rtDeviceReset(device_id_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
return; |
|
|
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) { |
|
|
|
PythonEnvGuard guard; |
|
|
|
if (!context::CloseTsd(ms_context)) { |
|
|
|
MS_LOG(ERROR) << "CloseTsd failed!"; |
|
|
|
return; |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto ret = rtDeviceReset(device_id_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "End finalize device " << device_id_; |
|
|
|
|