|
|
|
@@ -116,6 +116,7 @@ size_t PsDataPrefetch::data_size(const std::string &channel_name) const { |
|
|
|
|
|
|
|
void PsDataPrefetch::NotifyFinalize() { |
|
|
|
need_wait_ = false; |
|
|
|
WakeAllChannel(); |
|
|
|
data_prefetch_.notify_one(); |
|
|
|
data_process_.notify_one(); |
|
|
|
} |
|
|
|
@@ -128,5 +129,15 @@ bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { |
|
|
|
channel->TryWakeChannel(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void PsDataPrefetch::WakeAllChannel() { |
|
|
|
for (auto iter = ps_data_channel_map_.begin(); iter != ps_data_channel_map_.end(); iter++) { |
|
|
|
auto channel = iter->second; |
|
|
|
if (channel == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
channel->TryWakeChannel(true); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ps |
|
|
|
} // namespace mindspore |