Browse Source

fix ps cache process data thread can not exit

tags/v1.2.0-rc1
lizhenyu 4 years ago
parent
commit
8bddfba9e2
3 changed files with 18 additions and 5 deletions
  1. +4
    -1
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
  2. +12
    -3
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
  3. +2
    -1
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h

+ 4
- 1
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -353,7 +353,10 @@ bool PsCacheManager::ProcessData() {
struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr);
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
void *data = nullptr;
if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) {
return false;
}
if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
std::unique_lock<std::mutex> locker(data_mutex_);


+ 12
- 3
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc View File

@@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c
const std::string supported_data_type = "int32";
if (data_type != supported_data_type) {
MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]";
invalid_data_type_ = true;
return false;
}
if (data == nullptr) {
@@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
return false;
}

void *PsDataPrefetch::data(const std::string &channel_name) const {
bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const {
if (invalid_data_type_) {
return false;
}
if (data_ptr == nullptr) {
return false;
}
auto channel = ps_data_channel(channel_name);
if (channel == nullptr) {
return nullptr;
*data_ptr = nullptr;
return true;
}
return channel->data();
*data_ptr = channel->data();
return true;
}

size_t PsDataPrefetch::data_size(const std::string &channel_name) const {


+ 2
- 1
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h View File

@@ -41,7 +41,7 @@ class EXPORT PsDataPrefetch {
const std::string &data_type);
EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const;
EXPORT bool QueryData(const std::string &channel_name, void **data_ptr) const;
EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT bool TryWakeChannel(const std::string &channel_name);

@@ -59,6 +59,7 @@ class EXPORT PsDataPrefetch {
std::condition_variable data_prefetch_;
std::condition_variable data_process_;
std::atomic_bool need_wait_{true};
std::atomic_bool invalid_data_type_{false};
};
} // namespace ps
} // namespace mindspore


Loading…
Cancel
Save