|
|
|
@@ -19,12 +19,17 @@ |
|
|
|
#include <string> |
|
|
|
#include <memory> |
|
|
|
#include <thread> |
|
|
|
#include <atomic> |
|
|
|
|
|
|
|
#include "pybind11/pybind11.h" |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "utils/convert_utils_base.h" |
|
|
|
|
|
|
|
#ifndef NO_DLIB |
|
|
|
#include "acl/acl_tdt.h" |
|
|
|
#include "runtime/dev.h" |
|
|
|
#include "toolchain/plog.h" |
|
|
|
#endif |
|
|
|
#ifdef ENABLE_GE |
|
|
|
#include "transform/graph_ir/df_graph_manager.h" |
|
|
|
#endif |
|
|
|
namespace py = pybind11; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
@@ -70,18 +75,21 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { |
|
|
|
rank_size = IntToUint(rank_env); |
|
|
|
} |
|
|
|
|
|
|
|
int log_ret = DlogReportInitialize(); |
|
|
|
if (log_ret != 0) { |
|
|
|
MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; |
|
|
|
auto ret = rtSetDevice(device_id); |
|
|
|
if (ret != RT_ERROR_NONE) { |
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); |
|
|
|
#ifdef ENABLE_TDTQUE |
|
|
|
acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle(); |
|
|
|
if (acl_handle == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Get acltdt handle failed"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
ms_context_ptr->acl_tdt_print = std::thread(TensorPrint(acl_handle)); |
|
|
|
#endif |
|
|
|
@@ -121,6 +129,7 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { |
|
|
|
} |
|
|
|
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false); |
|
|
|
MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]"; |
|
|
|
(void)DlogReportFinalize(); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = " |
|
|
|
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << "."; |
|
|
|
|