|
|
|
@@ -23,9 +23,10 @@ |
|
|
|
#include "backend/session/session_factory.h" |
|
|
|
#include "base/base_ref_utils.h" |
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h" |
|
|
|
#include "utils/context/context_extends.h" |
|
|
|
|
|
|
|
#ifdef ENABLE_D |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
using std::string; |
|
|
|
@@ -233,7 +234,7 @@ Status MSInferSession::FinalizeEnv() { |
|
|
|
MS_LOG(ERROR) << "Get Context failed!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (!ms_context->CloseTsd()) { |
|
|
|
if (!context::CloseTsd(ms_context)) { |
|
|
|
MS_LOG(ERROR) << "Inference CloseTsd failed!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -340,6 +341,10 @@ string MSInferSession::AjustTargetName(const std::string &device) { |
|
|
|
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { |
|
|
|
RegAllOp(); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
if (ms_context == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Get Context failed!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
ms_context->set_execution_mode(kGraphMode); |
|
|
|
ms_context->set_device_id(device_id); |
|
|
|
auto ajust_device = AjustTargetName(device); |
|
|
|
@@ -353,11 +358,7 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
session_impl_->Init(device_id); |
|
|
|
if (ms_context == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Get Context failed!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (!ms_context->OpenTsd()) { |
|
|
|
if (!context::OpenTsd(ms_context)) { |
|
|
|
MS_LOG(ERROR) << "Session init OpenTsd failed!"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|