|
|
|
@@ -276,7 +276,7 @@ Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MST |
|
|
|
} |
|
|
|
|
|
|
|
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) { |
|
|
|
MS_LOG(INFO) << "Start to init env."; |
|
|
|
MS_LOG(INFO) << "Start to init device " << device_id; |
|
|
|
device_id_ = device_id; |
|
|
|
RegAllOp(); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
@@ -294,49 +294,54 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) { |
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "InitEnv success."; |
|
|
|
MS_LOG(INFO) << "Device " << device_id << " init env success."; |
|
|
|
errno_ = kSuccess; |
|
|
|
} |
|
|
|
|
|
|
|
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { |
|
|
|
MS_LOG(INFO) << "Start finalize env"; |
|
|
|
MS_LOG(INFO) << "Start finalize device " << device_id_; |
|
|
|
session::ExecutorManager::Instance().Clear(); |
|
|
|
device::KernelRuntimeManager::Instance().ClearRuntimeResource(); |
|
|
|
|
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
if (ms_context == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Get Context failed!"; |
|
|
|
errno_ = kMCFailed; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto ret = rtDeviceReset(device_id_); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
errno_ = kSuccess; |
|
|
|
MS_LOG(INFO) << "End finalize env"; |
|
|
|
MS_LOG(INFO) << "End finalize device " << device_id_; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) { |
|
|
|
std::shared_ptr<MsEnvGuard> acl_env; |
|
|
|
std::lock_guard<std::mutex> lock(global_ms_env_mutex_); |
|
|
|
acl_env = global_ms_env_.lock(); |
|
|
|
auto iter = global_ms_env_.find(device_id); |
|
|
|
if (iter != global_ms_env_.end()) { |
|
|
|
acl_env = iter->second.lock(); |
|
|
|
} |
|
|
|
|
|
|
|
if (acl_env != nullptr) { |
|
|
|
MS_LOG(INFO) << "Env has been initialized, skip."; |
|
|
|
} else { |
|
|
|
acl_env = std::make_shared<MsEnvGuard>(device_id); |
|
|
|
if (acl_env->GetErrno() != kSuccess) { |
|
|
|
MS_LOG(ERROR) << "Execute aclInit Failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
global_ms_env_ = acl_env; |
|
|
|
MS_LOG(INFO) << "Env init success"; |
|
|
|
return acl_env; |
|
|
|
} |
|
|
|
|
|
|
|
acl_env = std::make_shared<MsEnvGuard>(device_id); |
|
|
|
if (acl_env->GetErrno() != kSuccess) { |
|
|
|
MS_LOG(ERROR) << "Init ascend env Failed"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
global_ms_env_.emplace(device_id, acl_env); |
|
|
|
MS_LOG(INFO) << "Env init success"; |
|
|
|
return acl_env; |
|
|
|
} |
|
|
|
|
|
|
|
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_; |
|
|
|
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_; |
|
|
|
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; |
|
|
|
} // namespace mindspore |