From: @limingqi107 Reviewed-by: @cristoval, @jjfeing Signed-off-by: @jjfeing tags/v1.1.0
| @@ -24,6 +24,7 @@ if (NOT ENABLE_GPU) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | ||||
| endif() | endif() | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | ||||
| add_subdirectory(ps_cache) | add_subdirectory(ps_cache) | ||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/ps_cache/embedding_hash_map.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||||
| const size_t graph_running_step, size_t *swap_out_size) { | |||||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_size); | |||||
| auto hash_index = Hash(id); | |||||
| auto need_swap = NeedSwap(); | |||||
| size_t loop = 0; | |||||
| while (true) { | |||||
| if (loop++ == hash_capacity_) { | |||||
| return INVALID_INDEX_VALUE; | |||||
| } | |||||
| if (hash_map_unit_[hash_index].IsEmpty()) { | |||||
| hash_count_++; | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_unit_[hash_index].set_id(id); | |||||
| hash_map_unit_[hash_index].set_step(data_step); | |||||
| return hash_index; | |||||
| } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) { | |||||
| // Need swap out from the hash table. | |||||
| swap_out_index[*swap_out_size] = hash_index; | |||||
| swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_; | |||||
| (*swap_out_size)++; | |||||
| (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_); | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_unit_[hash_index].set_id(id); | |||||
| hash_map_unit_[hash_index].set_step(data_step); | |||||
| return hash_index; | |||||
| } | |||||
| hash_index = (hash_index + 1) % hash_capacity_; | |||||
| } | |||||
| } | |||||
| void EmbeddingHashMap::DumpHashMap() { | |||||
| MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_; | |||||
| MS_LOG(INFO) << "Dump hash_id_to_index: "; | |||||
| for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) { | |||||
| MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; | |||||
| } | |||||
| MS_LOG(INFO) << "Dump hash_map_unit: "; | |||||
| for (size_t i = 0; i < hash_map_unit_.size(); i++) { | |||||
| if (!hash_map_unit_[i].IsEmpty()) { | |||||
| MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_; | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "Dump hash map info end."; | |||||
| } | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||||
| #define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||||
| #include <math.h> | |||||
| #include <utility> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include "utils/convert_utils_base.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
// Step value that marks a hash slot as empty (see HashMapElement::IsEmpty).
static constexpr size_t INVALID_STEP_VALUE = 0;
// Index value used to signal that no valid slot index is available.
static constexpr int INVALID_INDEX_VALUE = -1;

// One slot of the host-side hash map: the id stored in the slot and the step recorded for it.
// Both members have default initializers so that an element is empty in every initialization
// form, including plain default-initialization of a local object.
struct HashMapElement {
  int id_{0};
  size_t step_{INVALID_STEP_VALUE};

  // A slot is empty while its step still equals INVALID_STEP_VALUE.
  bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }

  // A slot is expired when `graph_running_step` is strictly greater than the recorded step.
  bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }

  void set_id(int id) { id_ = id; }
  void set_step(size_t step) { step_ = step; }
};
| // Hash table is held in device, HashMap is used to manage hash table in host. | |||||
| class EmbeddingHashMap { | |||||
| public: | |||||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | |||||
| hash_map_unit_.resize(hash_capacity); | |||||
| } | |||||
| virtual ~EmbeddingHashMap() = default; | |||||
| int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||||
| const size_t graph_running_step, size_t *swap_out_size); | |||||
| std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); } | |||||
| bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const { | |||||
| return iter != hash_id_to_index_.end(); | |||||
| } | |||||
| size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; } | |||||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); } | |||||
| void DumpHashMap(); | |||||
| private: | |||||
| int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } | |||||
| bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | |||||
| size_t hash_count_; | |||||
| size_t hash_capacity_; | |||||
| std::vector<HashMapElement> hash_map_unit_; | |||||
| std::unordered_map<int, int> hash_id_to_index_; | |||||
| }; | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||||
| @@ -0,0 +1,110 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) { | |||||
| if (cache_enable_ == false) { | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ")."; | |||||
| auto iter = ps_data_channel_map_.find(channel_name); | |||||
| if (iter != ps_data_channel_map_.end()) { | |||||
| MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; | |||||
| auto channel = iter->second; | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| channel->set_step_num(step_num); | |||||
| } else { | |||||
| auto channel = std::make_shared<PsDataChannel>(channel_name, step_num); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| (void)ps_data_channel_map_.emplace(channel_name, channel); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { | |||||
| auto iter = ps_data_channel_map_.find(channel_name); | |||||
| if (iter == ps_data_channel_map_.end()) { | |||||
| MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||||
| if (cache_enable_ == false) { | |||||
| return; | |||||
| } | |||||
| if (data == nullptr) { | |||||
| MS_LOG(WARNING) << "No data prefetch."; | |||||
| return; | |||||
| } | |||||
| auto channel = ps_data_channel(channel_name); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| channel->set_data(data, data_size); | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| data_ready_ = true; | |||||
| data_process_.notify_one(); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { | |||||
| return; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; | |||||
| } | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||||
| } | |||||
| void PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||||
| if (cache_enable_ == false) { | |||||
| return; | |||||
| } | |||||
| auto channel = ps_data_channel(channel_name); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| channel->ResetData(); | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| data_ready_ = false; | |||||
| data_prefetch_.notify_one(); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { | |||||
| return; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; | |||||
| } | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; | |||||
| } | |||||
| void *PsDataPrefetch::data(const std::string &channel_name) const { | |||||
| auto channel = ps_data_channel(channel_name); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| return channel->data(); | |||||
| } | |||||
| size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | |||||
| auto channel = ps_data_channel(channel_name); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| return channel->data_size(); | |||||
| } | |||||
| void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||||
| auto channel = ps_data_channel(channel_name); | |||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| channel->TryWakeChannel(); | |||||
| } | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||||
| #define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <condition_variable> | |||||
| #include "ps/ps_cache/ps_data/ps_data_channel.h" | |||||
| #define EXPORT __attribute__((visibility("default"))) | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| class PsDataPrefetch { | |||||
| public: | |||||
| EXPORT static PsDataPrefetch &GetInstance() { | |||||
| static PsDataPrefetch instance; | |||||
| return instance; | |||||
| } | |||||
| EXPORT bool cache_enable() const { return cache_enable_; } | |||||
| EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | |||||
| EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | |||||
| EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||||
| EXPORT void FinalizeData(const std::string &channel_name); | |||||
| EXPORT void *data(const std::string &channel_name) const; | |||||
| EXPORT size_t data_size(const std::string &channel_name) const; | |||||
| EXPORT void TryWakeChannel(const std::string &channel_name); | |||||
| private: | |||||
| PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} | |||||
| virtual ~PsDataPrefetch() = default; | |||||
| PsDataPrefetch(const PsDataPrefetch &) = delete; | |||||
| PsDataPrefetch &operator=(const PsDataPrefetch &) = delete; | |||||
| std::shared_ptr<PsDataChannel> ps_data_channel(const std::string &channel_name) const; | |||||
| std::map<std::string, std::shared_ptr<PsDataChannel>> ps_data_channel_map_; | |||||
| bool cache_enable_; | |||||
| bool data_ready_; | |||||
| std::mutex data_mutex_; | |||||
| std::condition_variable data_prefetch_; | |||||
| std::condition_variable data_process_; | |||||
| }; | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||||