From: @zyli2020 Reviewed-by: @cristoval,@jjfeing Signed-off-by: @cristovaltags/v1.2.0-rc1
| @@ -353,7 +353,10 @@ bool PsCacheManager::ProcessData() { | |||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| const uint64_t kUSecondInSecond = 1000000; | const uint64_t kUSecondInSecond = 1000000; | ||||
| (void)gettimeofday(&start_time, nullptr); | (void)gettimeofday(&start_time, nullptr); | ||||
| auto data = PsDataPrefetch::GetInstance().data(channel_name_); | |||||
| void *data = nullptr; | |||||
| if (!PsDataPrefetch::GetInstance().QueryData(channel_name_, &data)) { | |||||
| return false; | |||||
| } | |||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| @@ -53,6 +53,7 @@ bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, c | |||||
| const std::string supported_data_type = "int32"; | const std::string supported_data_type = "int32"; | ||||
| if (data_type != supported_data_type) { | if (data_type != supported_data_type) { | ||||
| MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; | MS_LOG(ERROR) << "Parameter server cache mode need input id with data type[int32], but got[" << data_type << "]"; | ||||
| invalid_data_type_ = true; | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| @@ -105,12 +106,20 @@ bool PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void *PsDataPrefetch::data(const std::string &channel_name) const { | |||||
| bool PsDataPrefetch::QueryData(const std::string &channel_name, void **data_ptr) const { | |||||
| if (invalid_data_type_) { | |||||
| return false; | |||||
| } | |||||
| if (data_ptr == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| if (channel == nullptr) { | if (channel == nullptr) { | ||||
| return nullptr; | |||||
| *data_ptr = nullptr; | |||||
| return true; | |||||
| } | } | ||||
| return channel->data(); | |||||
| *data_ptr = channel->data(); | |||||
| return true; | |||||
| } | } | ||||
| size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | ||||
| @@ -41,7 +41,7 @@ class EXPORT PsDataPrefetch { | |||||
| const std::string &data_type); | const std::string &data_type); | ||||
| EXPORT bool FinalizeData(const std::string &channel_name); | EXPORT bool FinalizeData(const std::string &channel_name); | ||||
| EXPORT void NotifyFinalize(); | EXPORT void NotifyFinalize(); | ||||
| EXPORT void *data(const std::string &channel_name) const; | |||||
| EXPORT bool QueryData(const std::string &channel_name, void **data_ptr) const; | |||||
| EXPORT size_t data_size(const std::string &channel_name) const; | EXPORT size_t data_size(const std::string &channel_name) const; | ||||
| EXPORT bool TryWakeChannel(const std::string &channel_name); | EXPORT bool TryWakeChannel(const std::string &channel_name); | ||||
| @@ -59,6 +59,7 @@ class EXPORT PsDataPrefetch { | |||||
| std::condition_variable data_prefetch_; | std::condition_variable data_prefetch_; | ||||
| std::condition_variable data_process_; | std::condition_variable data_process_; | ||||
| std::atomic_bool need_wait_{true}; | std::atomic_bool need_wait_{true}; | ||||
| std::atomic_bool invalid_data_type_{false}; | |||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||