| @@ -14,31 +14,37 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/tdt/tdt_handle.h" | #include "minddata/dataset/engine/tdt/tdt_handle.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| extern std::set<void **> acl_handle_set; | |||||
| namespace dataset { | namespace dataset { | ||||
| std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>(); | |||||
| void TdtHandle::AddHandle(acltdtChannelHandle *handle) { | |||||
| if (handle != nullptr) { | |||||
| acl_handle.emplace_back(handle); | |||||
| void TdtHandle::AddHandle(acltdtChannelHandle **handle) { | |||||
| if (*handle != nullptr) { | |||||
| acl_handle_set.insert(reinterpret_cast<void **>(handle)); | |||||
| } | } | ||||
| } | } | ||||
| void TdtHandle::DelHandle(acltdtChannelHandle **handle) { | |||||
| void **void_handle = reinterpret_cast<void **>(handle); | |||||
| acl_handle_set.erase(void_handle); | |||||
| } | |||||
| bool TdtHandle::DestroyHandle() { | bool TdtHandle::DestroyHandle() { | ||||
| bool destroy_all = true; | bool destroy_all = true; | ||||
| for (auto &handle : acl_handle) { | |||||
| if (handle != nullptr) { | |||||
| if (acltdtDestroyChannel(handle) != ACL_SUCCESS) { | |||||
| for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) { | |||||
| acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it); | |||||
| if (*handle != nullptr) { | |||||
| acltdtStopChannel(*handle); | |||||
| if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) { | |||||
| destroy_all = false; | destroy_all = false; | ||||
| } else { | } else { | ||||
| handle = nullptr; | |||||
| *handle = nullptr; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return destroy_all; | return destroy_all; | ||||
| } | } | ||||
| std::vector<acltdtChannelHandle *> TdtHandle::GetHandle() { return acl_handle; } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,23 +17,21 @@ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_ | #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_ | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <vector> | |||||
| #include <set> | |||||
| #include "acl/acl_tdt.h" | #include "acl/acl_tdt.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| class TdtHandle { | class TdtHandle { | ||||
| public: | public: | ||||
| static void AddHandle(acltdtChannelHandle *handle); | |||||
| static void AddHandle(acltdtChannelHandle **handle); | |||||
| static bool DestroyHandle(); | static bool DestroyHandle(); | ||||
| static std::vector<acltdtChannelHandle *> GetHandle(); | |||||
| static void DelHandle(acltdtChannelHandle **handle); | |||||
| private: | private: | ||||
| TdtHandle() {} | TdtHandle() {} | ||||
| static std::vector<acltdtChannelHandle *> acl_handle; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,15 +29,12 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) { | |||||
| if (acl_handle_ == nullptr) { | if (acl_handle_ == nullptr) { | ||||
| MS_LOG(ERROR) << "Failed to create channel for tdt queue."; | MS_LOG(ERROR) << "Failed to create channel for tdt queue."; | ||||
| } | } | ||||
| TdtHandle::AddHandle(acl_handle_); | |||||
| TdtHandle::AddHandle(&acl_handle_); | |||||
| } | } | ||||
| TdtPlugin::~TdtPlugin() { | TdtPlugin::~TdtPlugin() { | ||||
| std::vector<acltdtChannelHandle *> total_handle = TdtHandle::GetHandle(); | |||||
| if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) { | |||||
| if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; | |||||
| } | |||||
| if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed to destroy channel for tdt queue."; | |||||
| } | } | ||||
| } | } | ||||
| @@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||||
| } | } | ||||
| ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); | ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF); | ||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); | |||||
| acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle(); | |||||
| if (acl_handle == nullptr) { | if (acl_handle == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Get acltdt handle failed"; | MS_LOG(EXCEPTION) << "Get acltdt handle failed"; | ||||
| return false; | return false; | ||||
| @@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { | |||||
| bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | ||||
| if (ms_context_ptr == nullptr) { | if (ms_context_ptr == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "nullptr"; | |||||
| MS_LOG(EXCEPTION) << "ms_context_prt is nullptr"; | |||||
| } | } | ||||
| if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { | if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) { | ||||
| return true; | return true; | ||||
| @@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { | |||||
| ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); | ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0); | ||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle(); | |||||
| aclError stopStatus = acltdtStopChannel(acl_handle); | |||||
| if (stopStatus != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed stop acl data channel for host queue "; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; | |||||
| } | |||||
| MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op "; | |||||
| ms_context_ptr->DestroyAclTdtChannelHandle(); | |||||
| py::gil_scoped_release gil_release; | py::gil_scoped_release gil_release; | ||||
| aclError destrodStatus = acltdtDestroyChannel(acl_handle); | |||||
| if (destrodStatus != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op "; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op "; | |||||
| } | |||||
| try { | try { | ||||
| if (ms_context_ptr->acl_tdt_print.joinable()) { | if (ms_context_ptr->acl_tdt_print.joinable()) { | ||||
| MS_LOG(INFO) << "join acl tdt host receive process"; | MS_LOG(INFO) << "join acl tdt host receive process"; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| std::set<void **> acl_handle_set = std::set<void **>(); | |||||
| // set default log level to WARNING for all sub modules | // set default log level to WARNING for all sub modules | ||||
| int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; | int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | |||||
| #include <functional> | #include <functional> | ||||
| #include "utils/overload.h" | #include "utils/overload.h" | ||||
| #include "./securec.h" | #include "./securec.h" | ||||
| @@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept { | |||||
| } | } | ||||
| namespace mindspore { | namespace mindspore { | ||||
| extern std::set<void **> acl_handle_set __attribute__((visibility("default"))); | |||||
| #define FILE_NAME \ | #define FILE_NAME \ | ||||
| (sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \ | (sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \ | ||||
| : static_cast<const char *>(__FILE__)) | : static_cast<const char *>(__FILE__)) | ||||
| @@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| #ifdef ENABLE_TDTQUE | |||||
| namespace py = pybind11; | |||||
| acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() { | |||||
| uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||||
| std::string kReceivePrefix = "TF_RECEIVE_"; | |||||
| std::string channel_name = "_npu_log"; | |||||
| acltdtChannelHandle *acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str()); | |||||
| if (acl_handle != nullptr) { | |||||
| MS_LOG(INFO) << "Success to create acltdt handle."; | |||||
| acl_handle_ = acl_handle; | |||||
| TdtHandle::AddHandle(&acl_handle_); | |||||
| } | |||||
| return acl_handle; | |||||
| } | |||||
| void MsContext::DestroyAclTdtChannelHandle() { | |||||
| if (acl_handle_ == nullptr) { | |||||
| MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr"; | |||||
| return; | |||||
| } | |||||
| aclError stopStatus = acltdtStopChannel(acl_handle_); | |||||
| if (stopStatus != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Succeed stop acl data channel for host queue "; | |||||
| aclError destroydStatus = acltdtDestroyChannel(acl_handle_); | |||||
| if (destroydStatus != ACL_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl; | |||||
| return; | |||||
| } | |||||
| TdtHandle::DelHandle(&acl_handle_); | |||||
| MS_LOG(INFO) << "Succeed destroy acl channel"; | |||||
| } | |||||
| #endif | |||||
| std::string MsContext::backend_policy() const { | std::string MsContext::backend_policy() const { | ||||
| auto res = std::find_if( | auto res = std::find_if( | ||||
| policy_map_.begin(), policy_map_.end(), | policy_map_.begin(), policy_map_.end(), | ||||
| @@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const { | |||||
| #endif | #endif | ||||
| } | } | ||||
| #ifdef ENABLE_TDTQUE | |||||
| acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() { | |||||
| if (acl_handle == nullptr) { | |||||
| std::string kReceivePrefix = "TF_RECEIVE_"; | |||||
| std::string channel_name = "_npu_log"; | |||||
| uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||||
| acl_handle = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str()); | |||||
| if (acl_handle == nullptr) { | |||||
| MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name; | |||||
| return nullptr; | |||||
| } | |||||
| MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name; | |||||
| return acl_handle; | |||||
| } | |||||
| return acl_handle; | |||||
| } | |||||
| #endif | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,9 +25,15 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #ifdef ENABLE_TDTQUE | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h" | |||||
| using mindspore::dataset::TdtHandle; | |||||
| #endif | |||||
| #ifndef NO_DLIB | #ifndef NO_DLIB | ||||
| #include "acl/acl_tdt.h" | #include "acl/acl_tdt.h" | ||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| enum MsBackendPolicy { | enum MsBackendPolicy { | ||||
| kMsBackendGeOnly = 0, | kMsBackendGeOnly = 0, | ||||
| @@ -137,7 +143,8 @@ class MsContext { | |||||
| std::string backend_policy() const; | std::string backend_policy() const; | ||||
| bool set_backend_policy(const std::string &policy); | bool set_backend_policy(const std::string &policy); | ||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| acltdtChannelHandle *get_acl_tdt_channel_handle(); | |||||
| acltdtChannelHandle *CreateAclTdtChannelHandle(); | |||||
| void DestroyAclTdtChannelHandle(); | |||||
| #endif | #endif | ||||
| static void device_seter(DeviceSeter device) { seter_ = device; } | static void device_seter(DeviceSeter device) { seter_ = device; } | ||||
| static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } | static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } | ||||
| @@ -175,10 +182,9 @@ class MsContext { | |||||
| uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; | uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; | ||||
| float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; | float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; | ||||
| std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; | std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; | ||||
| MsBackendPolicy backend_policy_; | MsBackendPolicy backend_policy_; | ||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| acltdtChannelHandle *acl_handle = nullptr; | |||||
| acltdtChannelHandle *acl_handle_ = nullptr; | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -14,9 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <atomic> | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace common { | namespace common { | ||||
| @@ -19,6 +19,8 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <utility> | #include <utility> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <atomic> | |||||
| #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | #define DISABLE_COPY_AND_ASSIGN(ClassType) \ | ||||
| ClassType(const ClassType &) = delete; \ | ClassType(const ClassType &) = delete; \ | ||||