| @@ -14,31 +14,37 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/tdt/tdt_handle.h" | |||
| namespace mindspore { | |||
| extern std::set<void **> acl_handle_set; | |||
| namespace dataset { | |||
| std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>(); | |||
| void TdtHandle::AddHandle(acltdtChannelHandle *handle) { | |||
| if (handle != nullptr) { | |||
| acl_handle.emplace_back(handle); | |||
| // Registers the address of an owner's channel-handle variable in acl_handle_set so that | |||
| // DestroyHandle() can later destroy the channel and reset the owner's variable through it. | |||
| // A null argument, or an address whose handle is still nullptr, is ignored. | |||
| void TdtHandle::AddHandle(acltdtChannelHandle **handle) { | |||
| // Check the outer pointer first: dereferencing a null `handle` would be undefined behavior. | |||
| if (handle != nullptr && *handle != nullptr) { | |||
| acl_handle_set.insert(reinterpret_cast<void **>(handle)); | |||
| } | |||
| } | |||
| // Removes a handle address from acl_handle_set. | |||
| // Erasing an address that is not in the set leaves the set unchanged. | |||
| void TdtHandle::DelHandle(acltdtChannelHandle **handle) { | |||
| acl_handle_set.erase(reinterpret_cast<void **>(handle)); | |||
| } | |||
| bool TdtHandle::DestroyHandle() { /* Stops and destroys every non-null channel reachable through acl_handle_set; returns true only if no acltdtDestroyChannel call failed. */ | |||
| bool destroy_all = true; | |||
| for (auto &handle : acl_handle) { /* NOTE(review): this row and the two rows below iterate TdtHandle::acl_handle and look like the pre-change side of the diff - confirm. */ | |||
| if (handle != nullptr) { | |||
| if (acltdtDestroyChannel(handle) != ACL_SUCCESS) { | |||
| for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) { | |||
| acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it); /* Each set entry is the address of an owner's handle variable, stored as void **. */ | |||
| if (*handle != nullptr) { | |||
| acltdtStopChannel(*handle); /* The status returned by acltdtStopChannel is not checked here. */ | |||
| if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) { | |||
| destroy_all = false; | |||
| } else { | |||
| handle = nullptr; /* NOTE(review): presumably the pre-change counterpart of the row below - confirm. */ | |||
| *handle = nullptr; /* Writes nullptr into the owner's variable through the stored address. */ | |||
| } | |||
| } | |||
| } | |||
| return destroy_all; | |||
| } | |||
| std::vector<acltdtChannelHandle *> TdtHandle::GetHandle() { return acl_handle; } /* Returns a copy of the static acl_handle vector (returned by value). */ | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -17,23 +17,21 @@ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_ | |||
| #include <iostream> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "acl/acl_tdt.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class TdtHandle { /* Holder of static helpers for acltdt channel handles; the constructor is private. */ | |||
| public: | |||
| static void AddHandle(acltdtChannelHandle *handle); /* NOTE(review): both AddHandle signatures appear in this diff view; presumably only one exists in the real header - confirm. */ | |||
| static void AddHandle(acltdtChannelHandle **handle); /* Takes the address of the owner's handle variable. */ | |||
| static bool DestroyHandle(); | |||
| static std::vector<acltdtChannelHandle *> GetHandle(); | |||
| static void DelHandle(acltdtChannelHandle **handle); /* Takes the same kind of address that AddHandle(acltdtChannelHandle **) takes. */ | |||
| private: | |||
| TdtHandle() {} /* Private constructor; every other member declared here is static. */ | |||
| static std::vector<acltdtChannelHandle *> acl_handle; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -29,15 +29,12 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) { | |||
| if (acl_handle_ == nullptr) { | |||
| MS_LOG(ERROR) << "Failed to create channel for tdt queue."; | |||
| } | |||
| TdtHandle::AddHandle(acl_handle_); | |||
| TdtHandle::AddHandle(&acl_handle_); | |||
| } | |||
| // Destroys this plugin's channel if it is still open, then unregisters the handle address | |||
| // that the constructor passed to TdtHandle::AddHandle(). | |||
| TdtPlugin::~TdtPlugin() { | |||
| if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; | |||
| } | |||
| // &acl_handle_ stops being a valid address once this object is gone, so it must leave the | |||
| // registry whether or not the destroy call succeeded; otherwise a later | |||
| // TdtHandle::DestroyHandle() would dereference a dangling pointer. | |||
| // Unregistering an address that was never added is a no-op. | |||
| TdtHandle::DelHandle(&acl_handle_); | |||
| acl_handle_ = nullptr; | |||
| } | |||
| @@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| } | |||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); | |||
| #ifdef ENABLE_TDTQUE | |||
| acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); | |||
| acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle(); | |||
| if (acl_handle == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Get acltdt handle failed"; | |||
| return false; | |||
| @@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||
| bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| if (ms_context_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "nullptr"; | |||
| MS_LOG(EXCEPTION) << "ms_context_ptr is nullptr"; | |||
| } | |||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { | |||
| return true; | |||
| @@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||
| ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); | |||
| #ifdef ENABLE_TDTQUE | |||
| acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); | |||
| aclError stopStatus = acltdtStopChannel(acl_handle); | |||
| if (stopStatus != ACL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Failed stop acl data channel for host queue "; | |||
| } else { | |||
| MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; | |||
| } | |||
| MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op "; | |||
| ms_context_ptr->DestroyAclTdtChannelHandle(); | |||
| py::gil_scoped_release gil_release; | |||
| aclError destrodStatus = acltdtDestroyChannel(acl_handle); | |||
| if (destrodStatus != ACL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op "; | |||
| } else { | |||
| MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op "; | |||
| } | |||
| try { | |||
| if (ms_context_ptr->acl_tdt_print.joinable()) { | |||
| MS_LOG(INFO) << "join acl tdt host receive process"; | |||
| @@ -17,6 +17,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| std::set<void **> acl_handle_set = std::set<void **>(); | |||
| // set default log level to WARNING for all sub modules | |||
| int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; | |||
| } // namespace mindspore | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <sstream> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <functional> | |||
| #include "utils/overload.h" | |||
| #include "./securec.h" | |||
| @@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept { | |||
| } | |||
| namespace mindspore { | |||
| extern std::set<void **> acl_handle_set __attribute__((visibility("default"))); | |||
| #define FILE_NAME \ | |||
| (sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \ | |||
| : static_cast<const char *>(__FILE__)) | |||
| @@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) { | |||
| return true; | |||
| } | |||
| #ifdef ENABLE_TDTQUE | |||
| namespace py = pybind11; | |||
| // Creates the "TF_RECEIVE__npu_log" acltdt channel on the device given by MS_CTX_DEVICE_ID. | |||
| // On success the handle is stored in acl_handle_ and the member's address is registered | |||
| // with TdtHandle. Returns the new handle, or nullptr when creation fails. | |||
| acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() { | |||
| uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| std::string kReceivePrefix = "TF_RECEIVE_"; | |||
| std::string channel_name = "_npu_log"; | |||
| const std::string full_channel_name = kReceivePrefix + channel_name; | |||
| acltdtChannelHandle *created = acltdtCreateChannel(device_id, full_channel_name.c_str()); | |||
| if (created == nullptr) { | |||
| return nullptr; | |||
| } | |||
| MS_LOG(INFO) << "Success to create acltdt handle."; | |||
| acl_handle_ = created; | |||
| TdtHandle::AddHandle(&acl_handle_); | |||
| return created; | |||
| } | |||
| // Stops and destroys the channel held in acl_handle_, unregisters the member's address from | |||
| // TdtHandle and resets the member to nullptr. Does nothing when acl_handle_ is already nullptr. | |||
| // If stopping or destroying fails, the error is logged and the function returns early, leaving | |||
| // acl_handle_ and its registration unchanged. | |||
| void MsContext::DestroyAclTdtChannelHandle() { | |||
| if (acl_handle_ == nullptr) { | |||
| MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr"; | |||
| return; | |||
| } | |||
| aclError stopStatus = acltdtStopChannel(acl_handle_); | |||
| if (stopStatus != ACL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; | |||
| aclError destroydStatus = acltdtDestroyChannel(acl_handle_); | |||
| if (destroydStatus != ACL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl; | |||
| return; | |||
| } | |||
| TdtHandle::DelHandle(&acl_handle_); | |||
| // The channel no longer exists. Clear the member so the guard at the top of this function | |||
| // sees it as destroyed and a repeated call cannot stop or destroy a dead handle. | |||
| acl_handle_ = nullptr; | |||
| MS_LOG(INFO) << "Succeed destroy acl channel"; | |||
| } | |||
| #endif | |||
| std::string MsContext::backend_policy() const { | |||
| auto res = std::find_if( | |||
| policy_map_.begin(), policy_map_.end(), | |||
| @@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const { | |||
| #endif | |||
| } | |||
| #ifdef ENABLE_TDTQUE | |||
| // Returns the cached acltdt channel handle, creating the "TF_RECEIVE__npu_log" channel on the | |||
| // device given by MS_CTX_DEVICE_ID the first time it is needed. Returns nullptr if creation fails. | |||
| acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() { | |||
| // Fast path: a channel was already created by an earlier call. | |||
| if (acl_handle != nullptr) { | |||
| return acl_handle; | |||
| } | |||
| std::string kReceivePrefix = "TF_RECEIVE_"; | |||
| std::string channel_name = "_npu_log"; | |||
| uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||
| const std::string full_channel_name = kReceivePrefix + channel_name; | |||
| acl_handle = acltdtCreateChannel(device_id, full_channel_name.c_str()); | |||
| if (acl_handle != nullptr) { | |||
| MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name; | |||
| return acl_handle; | |||
| } | |||
| MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name; | |||
| return nullptr; | |||
| } | |||
| #endif | |||
| } // namespace mindspore | |||
| @@ -25,9 +25,15 @@ | |||
| #include <utility> | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #ifdef ENABLE_TDTQUE | |||
| #include "pybind11/pybind11.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h" | |||
| using mindspore::dataset::TdtHandle; | |||
| #endif | |||
| #ifndef NO_DLIB | |||
| #include "acl/acl_tdt.h" | |||
| #endif | |||
| namespace mindspore { | |||
| enum MsBackendPolicy { | |||
| kMsBackendGeOnly = 0, | |||
| @@ -137,7 +143,8 @@ class MsContext { | |||
| std::string backend_policy() const; | |||
| bool set_backend_policy(const std::string &policy); | |||
| #ifdef ENABLE_TDTQUE | |||
| acltdtChannelHandle *get_acl_tdt_channel_handle(); | |||
| acltdtChannelHandle *CreateAclTdtChannelHandle(); | |||
| void DestroyAclTdtChannelHandle(); | |||
| #endif | |||
| static void device_seter(DeviceSeter device) { seter_ = device; } | |||
| static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } | |||
| @@ -175,10 +182,9 @@ class MsContext { | |||
| uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; | |||
| float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; | |||
| std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; | |||
| MsBackendPolicy backend_policy_; | |||
| #ifdef ENABLE_TDTQUE | |||
| acltdtChannelHandle *acl_handle = nullptr; | |||
| acltdtChannelHandle *acl_handle_ = nullptr; | |||
| #endif | |||
| }; | |||
| @@ -14,9 +14,6 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/ms_utils.h" | |||
| #include <string> | |||
| #include <vector> | |||
| #include <atomic> | |||
| namespace mindspore { | |||
| namespace common { | |||
| @@ -19,6 +19,8 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <atomic> | |||
| #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | |||
| ClassType(const ClassType &) = delete; \ | |||