|
|
|
@@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c |
|
|
|
const std::string supported_data_type = "int32"; |
|
|
|
if (data_type != supported_data_type) { |
|
|
|
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; |
|
|
|
invalid_data_type_ = true; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (data == nullptr) { |
|
|
|
@@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void *PsDataPrefetch::data(const std::string &channel_name) const { |
|
|
|
bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const { |
|
|
|
if (invalid_data_type_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (data_ptr == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto channel = ps_data_channel(channel_name); |
|
|
|
if (channel == nullptr) { |
|
|
|
return nullptr; |
|
|
|
*data_ptr = nullptr; |
|
|
|
return true; |
|
|
|
} |
|
|
|
return channel->data(); |
|
|
|
*data_ptr = channel->data(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
size_t PsDataPrefetch::data_size(const std::string &channel_name) const { |
|
|
|
|