From: @zyli2020 Reviewed-by: @cristoval Signed-off-by:tags/v1.2.0-rc1
| @@ -301,10 +301,9 @@ Status DeviceQueueOp::PushDataToGPU() { | |||||
| } | } | ||||
| // Data prefetch only when PS mode enables cache. | // Data prefetch only when PS mode enables cache. | ||||
| if (items.size() > 0) { | |||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { | |||||
| return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); | |||||
| } | |||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_, | |||||
| items[0].data_type_)) { | |||||
| return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); | |||||
| } | } | ||||
| while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { | while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { | ||||
| BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | ||||
| @@ -434,6 +433,11 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items, | |||||
| if (sub_item.data_ptr_ == nullptr) { | if (sub_item.data_ptr_ == nullptr) { | ||||
| return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed."); | return Status(StatusCode::kMDOutOfMemory, __LINE__, __FILE__, "Memory malloc failed."); | ||||
| } | } | ||||
| if (curr_row[i] == nullptr) { | |||||
| MS_LOG(ERROR) << "The pointer curr_row[" << i << "] is null"; | |||||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "TensorRow 'curr_row' contains nullptr."); | |||||
| } | |||||
| sub_item.data_type_ = curr_row[i]->type().ToString(); | |||||
| const unsigned char *column_data = curr_row[i]->GetBuffer(); | const unsigned char *column_data = curr_row[i]->GetBuffer(); | ||||
| if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, | if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, | ||||
| static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) { | static_cast<uint32_t>(curr_row[i++]->SizeInBytes())) != 0) { | ||||
| @@ -55,7 +55,8 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe | |||||
| #if ENABLE_D | #if ENABLE_D | ||||
| // Data prefetch only when PS mode enables cache. | // Data prefetch only when PS mode enables cache. | ||||
| if (items.size() > 0) { | if (items.size() > 0) { | ||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { | |||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_, | |||||
| items[0].tensorType_)) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
| @@ -44,10 +44,17 @@ std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string | |||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||||
| bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size, | |||||
| const std::string &data_type) { | |||||
| if (cache_enable_ == false) { | if (cache_enable_ == false) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| // In ps cache mode, input ids are from dataset and data type transmitted from minddata must be 'int32' | |||||
| const std::string supported_data_type = "int32"; | |||||
| if (data_type != supported_data_type) { | |||||
| MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; | |||||
| return false; | |||||
| } | |||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(WARNING) << "No data prefetch."; | MS_LOG(WARNING) << "No data prefetch."; | ||||
| return true; | return true; | ||||
| @@ -37,7 +37,8 @@ class EXPORT PsDataPrefetch { | |||||
| EXPORT bool cache_enable() const { return cache_enable_; } | EXPORT bool cache_enable() const { return cache_enable_; } | ||||
| EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | ||||
| EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | ||||
| EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||||
| EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size, | |||||
| const std::string &data_type); | |||||
| EXPORT bool FinalizeData(const std::string &channel_name); | EXPORT bool FinalizeData(const std::string &channel_name); | ||||
| EXPORT void NotifyFinalize(); | EXPORT void NotifyFinalize(); | ||||
| EXPORT void *data(const std::string &channel_name) const; | EXPORT void *data(const std::string &channel_name) const; | ||||
| @@ -34,6 +34,7 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, | |||||
| struct DataItemGpu { | struct DataItemGpu { | ||||
| int32_t worker_id_; | int32_t worker_id_; | ||||
| std::string data_type_; | |||||
| size_t data_len_; | size_t data_len_; | ||||
| void *data_ptr_; | void *data_ptr_; | ||||
| }; | }; | ||||