From: @limingqi107 Reviewed-by: @cristoval,@jjfeing Signed-off-by: @jjfeingtags/v1.1.0
| @@ -24,6 +24,7 @@ if (NOT ENABLE_GPU) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | |||
| endif() | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | |||
| add_subdirectory(ps_cache) | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/ps_cache/embedding_hash_map.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||
| const size_t graph_running_step, size_t *swap_out_size) { | |||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||
| MS_EXCEPTION_IF_NULL(swap_out_size); | |||
| auto hash_index = Hash(id); | |||
| auto need_swap = NeedSwap(); | |||
| size_t loop = 0; | |||
| while (true) { | |||
| if (loop++ == hash_capacity_) { | |||
| return INVALID_INDEX_VALUE; | |||
| } | |||
| if (hash_map_unit_[hash_index].IsEmpty()) { | |||
| hash_count_++; | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_unit_[hash_index].set_id(id); | |||
| hash_map_unit_[hash_index].set_step(data_step); | |||
| return hash_index; | |||
| } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) { | |||
| // Need swap out from the hash table. | |||
| swap_out_index[*swap_out_size] = hash_index; | |||
| swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_; | |||
| (*swap_out_size)++; | |||
| (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_); | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_unit_[hash_index].set_id(id); | |||
| hash_map_unit_[hash_index].set_step(data_step); | |||
| return hash_index; | |||
| } | |||
| hash_index = (hash_index + 1) % hash_capacity_; | |||
| } | |||
| } | |||
| void EmbeddingHashMap::DumpHashMap() { | |||
| MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_; | |||
| MS_LOG(INFO) << "Dump hash_id_to_index: "; | |||
| for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) { | |||
| MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; | |||
| } | |||
| MS_LOG(INFO) << "Dump hash_map_unit: "; | |||
| for (size_t i = 0; i < hash_map_unit_.size(); i++) { | |||
| if (!hash_map_unit_[i].IsEmpty()) { | |||
| MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Dump hash map info end."; | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||
| #define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||
| #include <math.h> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "utils/convert_utils_base.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| static const size_t INVALID_STEP_VALUE = 0; | |||
| static const int INVALID_INDEX_VALUE = -1; | |||
| struct HashMapElement { | |||
| int id_; | |||
| size_t step_; | |||
| bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | |||
| bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | |||
| void set_id(int id) { id_ = id; } | |||
| void set_step(size_t step) { step_ = step; } | |||
| }; | |||
| // Hash table is held in device, HashMap is used to manage hash table in host. | |||
| class EmbeddingHashMap { | |||
| public: | |||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | |||
| hash_map_unit_.resize(hash_capacity); | |||
| } | |||
| virtual ~EmbeddingHashMap() = default; | |||
| int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||
| const size_t graph_running_step, size_t *swap_out_size); | |||
| std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); } | |||
| bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const { | |||
| return iter != hash_id_to_index_.end(); | |||
| } | |||
| size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; } | |||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); } | |||
| void DumpHashMap(); | |||
| private: | |||
| int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } | |||
| bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | |||
| size_t hash_count_; | |||
| size_t hash_capacity_; | |||
| std::vector<HashMapElement> hash_map_unit_; | |||
| std::unordered_map<int, int> hash_id_to_index_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) { | |||
| if (cache_enable_ == false) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ")."; | |||
| auto iter = ps_data_channel_map_.find(channel_name); | |||
| if (iter != ps_data_channel_map_.end()) { | |||
| MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; | |||
| auto channel = iter->second; | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| channel->set_step_num(step_num); | |||
| } else { | |||
| auto channel = std::make_shared<PsDataChannel>(channel_name, step_num); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| (void)ps_data_channel_map_.emplace(channel_name, channel); | |||
| } | |||
| } | |||
| std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { | |||
| auto iter = ps_data_channel_map_.find(channel_name); | |||
| if (iter == ps_data_channel_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; | |||
| } | |||
| return iter->second; | |||
| } | |||
| void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||
| if (cache_enable_ == false) { | |||
| return; | |||
| } | |||
| if (data == nullptr) { | |||
| MS_LOG(WARNING) << "No data prefetch."; | |||
| return; | |||
| } | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| channel->set_data(data, data_size); | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_ready_ = true; | |||
| data_process_.notify_one(); | |||
| for (int i = 0; i < 10; i++) { | |||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { | |||
| return; | |||
| } else { | |||
| MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||
| } | |||
| void PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||
| if (cache_enable_ == false) { | |||
| return; | |||
| } | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| channel->ResetData(); | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_ready_ = false; | |||
| data_prefetch_.notify_one(); | |||
| for (int i = 0; i < 10; i++) { | |||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { | |||
| return; | |||
| } else { | |||
| MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; | |||
| } | |||
| void *PsDataPrefetch::data(const std::string &channel_name) const { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| return channel->data(); | |||
| } | |||
| size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| return channel->data_size(); | |||
| } | |||
| void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| channel->TryWakeChannel(); | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,60 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||
| #define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <condition_variable> | |||
| #include "ps/ps_cache/ps_data/ps_data_channel.h" | |||
| #define EXPORT __attribute__((visibility("default"))) | |||
| namespace mindspore { | |||
| namespace ps { | |||
| class PsDataPrefetch { | |||
| public: | |||
| EXPORT static PsDataPrefetch &GetInstance() { | |||
| static PsDataPrefetch instance; | |||
| return instance; | |||
| } | |||
| EXPORT bool cache_enable() const { return cache_enable_; } | |||
| EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | |||
| EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | |||
| EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||
| EXPORT void FinalizeData(const std::string &channel_name); | |||
| EXPORT void *data(const std::string &channel_name) const; | |||
| EXPORT size_t data_size(const std::string &channel_name) const; | |||
| EXPORT void TryWakeChannel(const std::string &channel_name); | |||
| private: | |||
| PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} | |||
| virtual ~PsDataPrefetch() = default; | |||
| PsDataPrefetch(const PsDataPrefetch &) = delete; | |||
| PsDataPrefetch &operator=(const PsDataPrefetch &) = delete; | |||
| std::shared_ptr<PsDataChannel> ps_data_channel(const std::string &channel_name) const; | |||
| std::map<std::string, std::shared_ptr<PsDataChannel>> ps_data_channel_map_; | |||
| bool cache_enable_; | |||
| bool data_ready_; | |||
| std::mutex data_mutex_; | |||
| std::condition_variable data_prefetch_; | |||
| std::condition_variable data_process_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ | |||