Browse Source

!14133 tensorprint_debug

From: @yepei6
Reviewed-by: 
Signed-off-by:
pull/14133/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
0ef2d78411
10 changed files with 76 additions and 61 deletions
  1. +16
    -10
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc
  2. +3
    -5
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h
  3. +3
    -6
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
  4. +3
    -17
      mindspore/ccsrc/utils/context/context_extends.cc
  5. +1
    -0
      mindspore/core/gvar/logging_level.cc
  6. +2
    -0
      mindspore/core/utils/log_adapter.h
  7. +37
    -17
      mindspore/core/utils/ms_context.cc
  8. +9
    -3
      mindspore/core/utils/ms_context.h
  9. +0
    -3
      mindspore/core/utils/ms_utils.cc
  10. +2
    -0
      mindspore/core/utils/ms_utils.h

+ 16
- 10
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.cc View File

@@ -14,31 +14,37 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/tdt/tdt_handle.h"

namespace mindspore {
extern std::set<void **> acl_handle_set;
namespace dataset {

std::vector<acltdtChannelHandle *> TdtHandle::acl_handle = std::vector<acltdtChannelHandle *>();

void TdtHandle::AddHandle(acltdtChannelHandle *handle) {
if (handle != nullptr) {
acl_handle.emplace_back(handle);
void TdtHandle::AddHandle(acltdtChannelHandle **handle) {
if (*handle != nullptr) {
acl_handle_set.insert(reinterpret_cast<void **>(handle));
}
}

// Removes a channel-handle slot from the global acl_handle_set.
//
// @param handle  Address of the acltdtChannelHandle* variable to unregister.
//                Only the address is used as the lookup key; the pointer is
//                never dereferenced here, and the channel itself is not
//                stopped or destroyed by this call.
void TdtHandle::DelHandle(acltdtChannelHandle **handle) {
  // The set stores slot addresses as void **, so convert to that key type.
  // Erasing a key that is not in the set is a no-op.
  acl_handle_set.erase(reinterpret_cast<void **>(handle));
}

bool TdtHandle::DestroyHandle() {
bool destroy_all = true;
for (auto &handle : acl_handle) {
if (handle != nullptr) {
if (acltdtDestroyChannel(handle) != ACL_SUCCESS) {
for (auto it = acl_handle_set.begin(); it != acl_handle_set.end(); it++) {
acltdtChannelHandle **handle = reinterpret_cast<acltdtChannelHandle **>(*it);
if (*handle != nullptr) {
acltdtStopChannel(*handle);
if (acltdtDestroyChannel(*handle) != ACL_SUCCESS) {
destroy_all = false;
} else {
handle = nullptr;
*handle = nullptr;
}
}
}
return destroy_all;
}

std::vector<acltdtChannelHandle *> TdtHandle::GetHandle() { return acl_handle; }
} // namespace dataset
} // namespace mindspore

+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h View File

@@ -17,23 +17,21 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_TDT_TDT_HANDLE_H_

#include <iostream>
#include <vector>
#include <set>
#include "acl/acl_tdt.h"

namespace mindspore {
namespace dataset {
class TdtHandle {
public:
static void AddHandle(acltdtChannelHandle *handle);
static void AddHandle(acltdtChannelHandle **handle);

static bool DestroyHandle();

static std::vector<acltdtChannelHandle *> GetHandle();
static void DelHandle(acltdtChannelHandle **handle);

private:
TdtHandle() {}

static std::vector<acltdtChannelHandle *> acl_handle;
};
} // namespace dataset
} // namespace mindspore


+ 3
- 6
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc View File

@@ -29,15 +29,12 @@ TdtPlugin::TdtPlugin(const std::string &channel_name, int32_t device_id) {
if (acl_handle_ == nullptr) {
MS_LOG(ERROR) << "Failed to create channel for tdt queue.";
}
TdtHandle::AddHandle(acl_handle_);
TdtHandle::AddHandle(&acl_handle_);
}

TdtPlugin::~TdtPlugin() {
std::vector<acltdtChannelHandle *> total_handle = TdtHandle::GetHandle();
if (std::find(total_handle.begin(), total_handle.end(), acl_handle_) != total_handle.end()) {
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue.";
}
if (acl_handle_ != nullptr && acltdtDestroyChannel(acl_handle_) != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed to destroy channel for tdt queue.";
}
}



+ 3
- 17
mindspore/ccsrc/utils/context/context_extends.cc View File

@@ -78,7 +78,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
}
ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle();
if (acl_handle == nullptr) {
MS_LOG(EXCEPTION) << "Get acltdt handle failed";
return false;
@@ -92,7 +92,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {

bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ms_context_ptr == nullptr) {
MS_LOG(EXCEPTION) << "nullptr";
MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
}
if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
return true;
@@ -102,22 +102,8 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);

#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = ms_context_ptr->get_acl_tdt_channel_handle();
aclError stopStatus = acltdtStopChannel(acl_handle);
if (stopStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed stop acl data channel for host queue ";
} else {
MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
}
MS_LOG(INFO) << "Succeed run cancellation callback of out-feed dequeue op ";

ms_context_ptr->DestroyAclTdtChannelHandle();
py::gil_scoped_release gil_release;
aclError destrodStatus = acltdtDestroyChannel(acl_handle);
if (destrodStatus != ACL_SUCCESS) {
MS_LOG(ERROR) << "Failed destroy acl channel for out-feed dequeue op ";
} else {
MS_LOG(INFO) << "Succeed destroy acl channel for out-feed dequeue op ";
}
try {
if (ms_context_ptr->acl_tdt_print.joinable()) {
MS_LOG(INFO) << "join acl tdt host receive process";


+ 1
- 0
mindspore/core/gvar/logging_level.cc View File

@@ -17,6 +17,7 @@
#include "utils/log_adapter.h"

namespace mindspore {
std::set<void **> acl_handle_set = std::set<void **>();
// set default log level to WARNING for all sub modules
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
} // namespace mindspore

+ 2
- 0
mindspore/core/utils/log_adapter.h View File

@@ -22,6 +22,7 @@
#include <string>
#include <sstream>
#include <memory>
#include <set>
#include <functional>
#include "utils/overload.h"
#include "./securec.h"
@@ -41,6 +42,7 @@ static constexpr size_t GetRelPathPos() noexcept {
}

namespace mindspore {
extern std::set<void **> acl_handle_set __attribute__((visibility("default")));
#define FILE_NAME \
(sizeof(__FILE__) > GetRelPathPos() ? static_cast<const char *>(__FILE__) + GetRelPathPos() \
: static_cast<const char *>(__FILE__))


+ 37
- 17
mindspore/core/utils/ms_context.cc View File

@@ -109,6 +109,43 @@ bool MsContext::set_backend_policy(const std::string &policy) {
return true;
}

#ifdef ENABLE_TDTQUE
namespace py = pybind11;
// Creates the ACL TDT channel "TF_RECEIVE__npu_log" on the device selected by
// MS_CTX_DEVICE_ID.
//
// On success the new handle is cached in acl_handle_ and the address of that
// member is registered through TdtHandle::AddHandle.
//
// @return The handle returned by acltdtCreateChannel, or nullptr when creation
//         failed (in which case nothing is cached or registered).
acltdtChannelHandle *MsContext::CreateAclTdtChannelHandle() {
  const uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
  const std::string receive_prefix = "TF_RECEIVE_";
  const std::string channel_suffix = "_npu_log";
  const std::string full_channel_name = receive_prefix + channel_suffix;

  acltdtChannelHandle *created = acltdtCreateChannel(device_id, full_channel_name.c_str());
  if (created == nullptr) {
    // Failure is reported to the caller through the null return value only.
    return nullptr;
  }

  MS_LOG(INFO) << "Success to create acltdt handle.";
  acl_handle_ = created;
  // Register the address of the member (not its value) with TdtHandle.
  TdtHandle::AddHandle(&acl_handle_);
  return created;
}

// Stops and destroys the ACL TDT channel cached in acl_handle_, then
// unregisters the member's address through TdtHandle::DelHandle.
//
// Behavior:
//  - If acl_handle_ is already nullptr, only an INFO message is logged.
//  - If stopping or destroying the channel fails, an ERROR is logged and the
//    function returns early; acl_handle_ and its TdtHandle registration are
//    left untouched in that case.
//  - On success acl_handle_ is reset to nullptr so that the guard above holds
//    on any later call and the freed channel is never stopped or destroyed a
//    second time.
void MsContext::DestroyAclTdtChannelHandle() {
  if (acl_handle_ == nullptr) {
    MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
    return;
  }
  aclError stopStatus = acltdtStopChannel(acl_handle_);
  if (stopStatus != ACL_SUCCESS) {
    MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl;
    return;
  }
  MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";

  aclError destroydStatus = acltdtDestroyChannel(acl_handle_);
  if (destroydStatus != ACL_SUCCESS) {
    MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl;
    return;
  }
  TdtHandle::DelHandle(&acl_handle_);
  // The channel no longer exists: drop the stale pointer so it cannot be reused.
  acl_handle_ = nullptr;
  MS_LOG(INFO) << "Succeed destroy acl channel";
}
#endif

std::string MsContext::backend_policy() const {
auto res = std::find_if(
policy_map_.begin(), policy_map_.end(),
@@ -127,21 +164,4 @@ bool MsContext::enable_dump_ir() const {
#endif
}

#ifdef ENABLE_TDTQUE
// Returns the ACL TDT channel handle cached in acl_handle, creating it on
// first use.
//
// When no handle is cached, a channel named "TF_RECEIVE__npu_log" is created on
// the device selected by MS_CTX_DEVICE_ID and stored in acl_handle.
//
// @return The cached or newly created handle, or nullptr if creation failed.
acltdtChannelHandle *MsContext::get_acl_tdt_channel_handle() {
  // Fast path: a handle was already created by an earlier call.
  if (acl_handle != nullptr) {
    return acl_handle;
  }

  const std::string receive_prefix = "TF_RECEIVE_";
  const std::string channel_name = "_npu_log";
  const uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
  acl_handle = acltdtCreateChannel(device_id, (receive_prefix + channel_name).c_str());

  if (acl_handle != nullptr) {
    MS_LOG(INFO) << "Success to create acltdt handle: " << channel_name;
  } else {
    MS_LOG(ERROR) << "Failed to create acltdt handle : " << channel_name;
  }
  // acl_handle is nullptr here exactly when creation failed.
  return acl_handle;
}
#endif
} // namespace mindspore

+ 9
- 3
mindspore/core/utils/ms_context.h View File

@@ -25,9 +25,15 @@
#include <utility>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#ifdef ENABLE_TDTQUE
#include "pybind11/pybind11.h"
#include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h"
using mindspore::dataset::TdtHandle;
#endif
#ifndef NO_DLIB
#include "acl/acl_tdt.h"
#endif

namespace mindspore {
enum MsBackendPolicy {
kMsBackendGeOnly = 0,
@@ -137,7 +143,8 @@ class MsContext {
std::string backend_policy() const;
bool set_backend_policy(const std::string &policy);
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *get_acl_tdt_channel_handle();
acltdtChannelHandle *CreateAclTdtChannelHandle();
void DestroyAclTdtChannelHandle();
#endif
static void device_seter(DeviceSeter device) { seter_ = device; }
static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; }
@@ -175,10 +182,9 @@ class MsContext {
uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];

MsBackendPolicy backend_policy_;
#ifdef ENABLE_TDTQUE
acltdtChannelHandle *acl_handle = nullptr;
acltdtChannelHandle *acl_handle_ = nullptr;
#endif
};



+ 0
- 3
mindspore/core/utils/ms_utils.cc View File

@@ -14,9 +14,6 @@
* limitations under the License.
*/
#include "utils/ms_utils.h"
#include <string>
#include <vector>
#include <atomic>

namespace mindspore {
namespace common {


+ 2
- 0
mindspore/core/utils/ms_utils.h View File

@@ -19,6 +19,8 @@
#include <memory>
#include <utility>
#include <string>
#include <vector>
#include <atomic>

#define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType &) = delete; \


Loading…
Cancel
Save