Browse Source

!12165 add input data type check for ps cache mode

From: @zyli2020
Reviewed-by: @cristoval
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
a3057441d6
5 changed files with 21 additions and 7 deletions
  1. +8
    -4
      mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
  2. +2
    -1
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
  3. +8
    -1
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
  4. +2
    -1
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
  5. +1
    -0
      mindspore/ccsrc/runtime/device/gpu/blocking_queue.h

+ 8
- 4
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc View File

@@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() {
}

// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
items[0].data_type_)) {
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
@@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
if (sub_item.data_ptr_ == nullptr) {
return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed.");
}
if (curr_row[i] == nullptr) {
MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null";
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr.");
}
sub_item.data_type_ = curr_row[i]->type().ToString();
const unsigned char *column_data = curr_row[i]->GetBuffer();
if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data,
static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) {


+ 2
- 1
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc View File

@@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
#if ENABLE_D
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_,
items[0].tensorType_)) {
return FAILED;
}
}


+ 8
- 1
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc View File

@@ -44,10 +44,17 @@ std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string
return iter->second;
}

bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
const std::string &data_type) {
if (cache_enable_ == false) {
return true;
}
// In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32'
const std::string supported_data_type = "int32";
if (data_type != supported_data_type) {
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
return false;
}
if (data == nullptr) {
MS_LOG(WARNING) << "No data prefetch.";
return true;


+ 2
- 1
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h View File

@@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch {
EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size,
const std::string &data_type);
EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const;


+ 1
- 0
mindspore/ccsrc/runtime/device/gpu/blocking_queue.h View File

@@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,

struct DataItemGpu {
int32_t worker_id_;
std::string data_type_;
size_t data_len_;
void *data_ptr_;
};


Loading…
Cancel
Save