| @@ -18,15 +18,18 @@ endif () | |||
# Drop the backend-specific cache implementations that the current build does not enable.
if (NOT ENABLE_D)
    list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc")
endif()
if (NOT ENABLE_GPU)
    list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
endif()
# The cache manager is backend independent: exclude it only when no device cache backend is built at all.
# Removing it in each branch above would exclude it from every build that enables just one backend.
if (NOT ENABLE_D AND NOT ENABLE_GPU)
    list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc")
endif()
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | |||
| add_subdirectory(ps_cache) | |||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | |||
| add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | |||
| @@ -0,0 +1,217 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| using mindspore::kernel::Address; | |||
| namespace mindspore { | |||
| namespace ps { | |||
| void PsCacheManager::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||
| size_t vocab_size) { | |||
| if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) { | |||
| MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero."; | |||
| } | |||
| hash_tables_[param_name].cache_vocab_size = cache_vocab_size; | |||
| hash_tables_[param_name].host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor; | |||
| hash_tables_[param_name].embedding_size = embedding_size; | |||
| hash_tables_[param_name].vocab_size = vocab_size; | |||
| if (vocab_size_ == 0) { | |||
| vocab_size_ = vocab_size; | |||
| } | |||
| if (cache_vocab_size_ == 0) { | |||
| cache_vocab_size_ = cache_vocab_size; | |||
| } | |||
| if (host_cache_vocab_size_ == 0) { | |||
| host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor; | |||
| } | |||
| } | |||
| void PsCacheManager::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||
| size_t cache_vocab_size, size_t embedding_size) { | |||
| if (cache_vocab_size == 0 || embedding_size == 0) { | |||
| MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero."; | |||
| } | |||
| if (new_param_name.empty() || cur_param_name.empty()) { | |||
| MS_LOG(EXCEPTION) << "Parameter name can not be empty."; | |||
| } | |||
| if (new_param_name == cur_param_name) { | |||
| return; | |||
| } | |||
| auto iter = hash_tables_.find(cur_param_name); | |||
| if (iter != hash_tables_.end()) { | |||
| hash_tables_.emplace(new_param_name, iter->second); | |||
| hash_tables_.erase(iter); | |||
| } else { | |||
| hash_tables_[new_param_name].cache_vocab_size = cache_vocab_size; | |||
| hash_tables_[new_param_name].embedding_size = embedding_size; | |||
| } | |||
| } | |||
| void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) { | |||
| auto iter = hash_tables_.find(param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | |||
| } | |||
| auto &hash_table_info = iter->second; | |||
| hash_table_info.param_init_info_.param_type_ = kWeight; | |||
| hash_table_info.param_init_info_.global_seed_ = global_seed; | |||
| hash_table_info.param_init_info_.op_seed_ = op_seed; | |||
| } | |||
| void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) { | |||
| auto iter = hash_tables_.find(param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | |||
| } | |||
| auto &hash_table_info = iter->second; | |||
| hash_table_info.param_init_info_.param_type_ = kAccumulation; | |||
| hash_table_info.param_init_info_.init_val_ = init_val; | |||
| } | |||
| void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) { | |||
| if (dest_param_name == src_param_name) { | |||
| MS_LOG(INFO) << "The dest_param_name is same as src_param_name"; | |||
| return; | |||
| } | |||
| auto iter = hash_tables_.find(src_param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed."; | |||
| } | |||
| hash_tables_.emplace(dest_param_name, iter->second); | |||
| } | |||
| const Address &PsCacheManager::QueryHashTableAddr(const std::string ¶m_name) const { | |||
| auto iter = hash_tables_.find(param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name; | |||
| } | |||
| return iter->second.device_address; | |||
| } | |||
| void PsCacheManager::Initialize() { | |||
| MS_LOG(INFO) << "PS cache initialize."; | |||
| if (!worker.running()) { | |||
| Util::SetInternalEnvVar(); | |||
| worker.Run(); | |||
| } | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_); | |||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_); | |||
| InitParameterServer(); | |||
| AllocMemForHashTable(); | |||
| SetLocalIdRank(); | |||
| initialized_ps_cache_ = true; | |||
| } | |||
| void PsCacheManager::InitParameterServer() { | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| size_t row_count = item.second.vocab_size; | |||
| std::vector<size_t> keys{key, key, key, key}; | |||
| std::vector<float> values{ | |||
| SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; | |||
| std::vector<int64_t> lens{2, 2, 3}; | |||
| const auto &hash_table_info = item.second; | |||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | |||
| if (param_init_info.param_type_ == kWeight) { | |||
| lens.push_back(0); | |||
| values.push_back(SizeToFloat(param_init_info.global_seed_)); | |||
| values.push_back(SizeToFloat(param_init_info.op_seed_)); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| lens.push_back(1); | |||
| values.push_back(param_init_info.init_val_); | |||
| } | |||
| // if worker role | |||
| worker.AddEmbeddingTable(key, row_count); | |||
| worker.InitPSEmbeddingTable(keys, values, lens); | |||
| } | |||
| } | |||
| void PsCacheManager::AllocMemForHashTable() { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| size_t max_embedding_size = 0; | |||
| for (auto &item : hash_tables_) { | |||
| size_t embedding_size = item.second.embedding_size; | |||
| auto &device_address = item.second.device_address; | |||
| device_address.size = cache_vocab_size_ * embedding_size * sizeof(float); | |||
| auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size); | |||
| MS_EXCEPTION_IF_NULL(addr); | |||
| device_address.addr = addr; | |||
| auto &host_address = item.second.host_address; | |||
| auto host_address_ptr = new int[host_cache_vocab_size_ * embedding_size]; | |||
| MS_EXCEPTION_IF_NULL(host_address_ptr); | |||
| host_address = std::shared_ptr<int[]>(host_address_ptr, std::default_delete<int[]>()); | |||
| MS_EXCEPTION_IF_NULL(host_address); | |||
| max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; | |||
| } | |||
| embedding_device_cache_->hash_swap_index_addr_ = | |||
| reinterpret_cast<int *>(embedding_device_cache_->cache_->MallocMemory(batch_elements_ * sizeof(int))); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_); | |||
| embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>( | |||
| embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); | |||
| embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_); | |||
| } | |||
| void PsCacheManager::SetLocalIdRank() { | |||
| auto worker_num = ::ps::NumWorkers(); | |||
| auto worker_id = ::ps::MyRank(); | |||
| auto local_shard_size = FloatToSize(std::ceil(SizeToFloat(vocab_size_) / worker_num)); | |||
| range_bound_.first = local_shard_size * worker_id; | |||
| range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_); | |||
| MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first | |||
| << ", rank id end:" << range_bound_.second; | |||
| } | |||
| std::string PsCacheManager::channel_name() { | |||
| std::lock_guard<std::mutex> locker(channel_mutex_); | |||
| return channel_name_; | |||
| } | |||
| void PsCacheManager::set_channel_name(const std::string channel_name) { | |||
| if (channel_name_ == channel_name) { | |||
| return; | |||
| } | |||
| std::lock_guard<std::mutex> locker(channel_mutex_); | |||
| channel_name_ = channel_name; | |||
| } | |||
| void PsCacheManager::IncreaseStep() { | |||
| if (data_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; | |||
| } | |||
| data_step_++; | |||
| set_current_graph_step(); | |||
| } | |||
| void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| if (graph_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | |||
| } | |||
| graph_step_++; | |||
| set_channel_name(channel_name); | |||
| PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | |||
| data_prase_.notify_one(); | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,186 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ | |||
| #define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <thread> | |||
| #include <atomic> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <condition_variable> | |||
| #include "utils/ms_context.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ir/tensor.h" | |||
| #include "ps/ps.h" | |||
| #include "ps/common.h" | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/ps_cache/embedding_hash_map.h" | |||
| #include "ps/ps_cache/ps_cache_factory.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
// Multiplier applied to a table's device cache vocab size to obtain its host cache vocab size.
constexpr size_t kHostCacheScaleFactor{10};
// Thread count limit; its users are not part of this header.
constexpr size_t kMaxThreadNum{16};
| using mindspore::kernel::Address; | |||
| struct HashTableInfo { | |||
| size_t cache_vocab_size{0}; | |||
| size_t host_cache_vocab_size{0}; | |||
| size_t embedding_size{0}; | |||
| size_t vocab_size{0}; | |||
| Address device_address{nullptr, 0}; | |||
| std::shared_ptr<int[]> host_address{nullptr}; | |||
| }; | |||
| struct EmbeddingDeviceCache { | |||
| EmbeddingDeviceCache(size_t batch_elements, size_t cache_vocab_size) { | |||
| device_to_host_index = std::make_unique<int[]>(batch_elements); | |||
| device_to_host_ids = std::make_unique<int[]>(batch_elements); | |||
| host_to_device_index = std::make_unique<int[]>(batch_elements); | |||
| host_to_device_ids = std::make_unique<int[]>(batch_elements); | |||
| device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| cache_ = PsCacheFactory::Get().ps_cache(devcie_target); | |||
| } | |||
| std::unique_ptr<int[]> device_to_host_index; | |||
| std::unique_ptr<int[]> device_to_host_ids; | |||
| std::unique_ptr<int[]> host_to_device_index; | |||
| std::unique_ptr<int[]> host_to_device_ids; | |||
| int *hash_swap_index_addr_; | |||
| float *hash_swap_value_addr_; | |||
| std::shared_ptr<EmbeddingHashMap> device_hash_map_; | |||
| std::shared_ptr<PsCacheBasic> cache_; | |||
| }; | |||
| struct EmbeddingHostCache { | |||
| EmbeddingHostCache(size_t batch_elements, size_t host_cache_vocab_size) { | |||
| host_to_server_index = std::make_unique<int[]>(batch_elements); | |||
| host_to_server_ids = std::make_unique<int[]>(batch_elements); | |||
| server_to_host_index = std::make_unique<int[]>(batch_elements); | |||
| server_to_host_ids = std::make_unique<int[]>(batch_elements); | |||
| host_to_device_index = std::make_unique<int[]>(batch_elements); | |||
| device_to_host_index = std::make_unique<int[]>(batch_elements); | |||
| host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size); | |||
| } | |||
| std::unique_ptr<int[]> host_to_server_index; | |||
| std::unique_ptr<int[]> host_to_server_ids; | |||
| std::unique_ptr<int[]> server_to_host_index; | |||
| std::unique_ptr<int[]> server_to_host_ids; | |||
| std::unique_ptr<int[]> host_to_device_index; | |||
| std::unique_ptr<int[]> device_to_host_index; | |||
| std::shared_ptr<EmbeddingHashMap> host_hash_map_; | |||
| }; | |||
// Plain counters describing cache activity; every counter starts at zero.
struct PsCacheStatisticsInfo {
  size_t batch_id_unique_count_ = 0;
  size_t device_to_host_size_ = 0;
  size_t host_to_device_size_ = 0;
  size_t host_to_server_size_ = 0;
  size_t server_to_host_size_ = 0;
  size_t hash_hit_count_ = 0;
  size_t mem_cache_swap_out_size_ = 0;
  size_t mem_cache_swap_in_size_ = 0;
  size_t mem_cache_hit_count_ = 0;
};
| class PsCacheManager { | |||
| public: | |||
| static PsCacheManager &GetInstance() { | |||
| static PsCacheManager instance; | |||
| return instance; | |||
| } | |||
| void Initialize(); | |||
| void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||
| size_t vocab_size); | |||
| void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed); | |||
| void InsertAccumuInitInfo(const std::string ¶m_name, float init_val); | |||
| void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||
| size_t cache_vocab_size, size_t embedding_size); | |||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); | |||
| const Address &QueryHashTableAddr(const std::string ¶m_name) const; | |||
| bool IsHashTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } | |||
| void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } | |||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | |||
| void DoProcessData(uint32_t device_id, void *context); | |||
| void IncreaseGraphStep(const std::string &channel_name); | |||
| void DumpHashTables() const; | |||
| private: | |||
| PsCacheManager() = default; | |||
| ~PsCacheManager() = default; | |||
| PsCacheManager(const PsCacheManager &) = delete; | |||
| PsCacheManager &operator=(const PsCacheManager &) = delete; | |||
| void IncreaseStep(); | |||
| void set_current_graph_step() { graph_running_step_ = graph_step_; } | |||
| std::string channel_name(); | |||
| void set_channel_name(const std::string channel_name); | |||
| void InitParameterServer(); | |||
| void AllocMemForHashTable(); | |||
| void SetLocalIdRank(); | |||
| void ProcessDataTask(uint32_t device_id, void *context); | |||
| void ProcessData(); | |||
| void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||
| void WaitGraphRun(); | |||
| int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); | |||
| void ParseHostDataHostToDevice(size_t id); | |||
| void ParseHostDataDeviceToHost(size_t id); | |||
| void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| void HashSwapHostToDevice(const HashTableInfo &hash_info); | |||
| void HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||
| void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||
| void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||
| void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||
| float *hash_table_addr); | |||
| void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool initialized_ps_cache_{false}; | |||
| std::string channel_name_; | |||
| std::mutex channel_mutex_; | |||
| std::atomic_ulong graph_step_{0}; | |||
| size_t graph_running_step_{0}; | |||
| size_t data_step_{0}; | |||
| std::mutex data_mutex_; | |||
| std::condition_variable data_prase_; | |||
| std::map<std::string, HashTableInfo> hash_tables_; | |||
| std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | |||
| std::shared_ptr<EmbeddingHostCache> embedding_host_cache_; | |||
| size_t vocab_size_{0}; | |||
| size_t cache_vocab_size_{0}; | |||
| size_t host_cache_vocab_size_{0}; | |||
| size_t batch_elements_{0}; | |||
| PsCacheStatisticsInfo statistics_info_; | |||
| std::pair<size_t, size_t> range_bound_; | |||
| }; | |||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_ | |||
| @@ -141,6 +141,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info. | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc") | |||