diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index 5fafca3f1f..c45c7d281f 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -18,15 +18,18 @@ endif () if (NOT ENABLE_D) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc") + list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") endif() if (NOT ENABLE_GPU) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") + list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") endif() list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") add_subdirectory(ps_cache) + set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc new file mode 100644 index 0000000000..f2c66a3da5 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -0,0 +1,217 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <algorithm>
+#include "ps/ps_cache/ps_cache_manager.h"
+#include "utils/log_adapter.h"
+#include "utils/ms_utils.h"
+
+using mindspore::kernel::Address;
+namespace mindspore {
+namespace ps {
+// Records the sizes of the hash table `param_name`, creating the entry if it does not exist yet.
+// All three sizes must be non-zero, otherwise the error is reported through MS_LOG(EXCEPTION).
+// The host cache size is derived as cache_vocab_size * kHostCacheScaleFactor.
+void PsCacheManager::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
+                                         size_t vocab_size) {
+  if (cache_vocab_size == 0 || embedding_size == 0 || vocab_size == 0) {
+    MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
+  }
+  // Resolve the entry once (operator[] creates it on first use) instead of once per field.
+  auto &hash_table_info = hash_tables_[param_name];
+  hash_table_info.cache_vocab_size = cache_vocab_size;
+  hash_table_info.host_cache_vocab_size = cache_vocab_size * kHostCacheScaleFactor;
+  hash_table_info.embedding_size = embedding_size;
+  hash_table_info.vocab_size = vocab_size;
+
+  // The manager-wide sizes are only written while they are still zero, i.e. by the first table registered.
+  if (vocab_size_ == 0) {
+    vocab_size_ = vocab_size;
+  }
+  if (cache_vocab_size_ == 0) {
+    cache_vocab_size_ = cache_vocab_size;
+  }
+  if (host_cache_vocab_size_ == 0) {
+    host_cache_vocab_size_ = cache_vocab_size * kHostCacheScaleFactor;
+  }
+}
+
+// Moves the entry stored under `cur_param_name` to `new_param_name`.
+// If `cur_param_name` is unknown, an entry for `new_param_name` is created (or updated) with only
+// cache_vocab_size and embedding_size set; its other fields keep whatever value they had.
+// Zero sizes and empty names are reported through MS_LOG(EXCEPTION); identical names are a no-op.
+void PsCacheManager::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
+                                           size_t cache_vocab_size, size_t embedding_size) {
+  if (cache_vocab_size == 0 || embedding_size == 0) {
+    MS_LOG(EXCEPTION) << "The size of hash table can not equal to zero.";
+  }
+  if (new_param_name.empty() || cur_param_name.empty()) {
+    MS_LOG(EXCEPTION) << "Parameter name can not be empty.";
+  }
+  if (new_param_name == cur_param_name) {
+    return;
+  }
+  auto iter = hash_tables_.find(cur_param_name);
+  if (iter != hash_tables_.end()) {
+    // emplace() leaves an already existing `new_param_name` entry untouched; the old key is erased either way.
+    hash_tables_.emplace(new_param_name, iter->second);
+    hash_tables_.erase(iter);
+  } else {
+    auto &hash_table_info = hash_tables_[new_param_name];
+    hash_table_info.cache_vocab_size = cache_vocab_size;
+    hash_table_info.embedding_size = embedding_size;
+  }
+}
+
+// Marks the registered table `param_name` as a weight and stores the two seeds in its init info.
+// An unknown `param_name` is reported through MS_LOG(EXCEPTION).
+void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) {
+  auto iter = hash_tables_.find(param_name);
+  if (iter == hash_tables_.end()) {
+    MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
+  }
+  auto &hash_table_info = iter->second;
+  hash_table_info.param_init_info_.param_type_ = kWeight;
+  hash_table_info.param_init_info_.global_seed_ = global_seed;
+  hash_table_info.param_init_info_.op_seed_ = op_seed;
+}
+
+// Marks the registered table `param_name` as an accumulation and stores its initial value.
+// An unknown `param_name` is reported through MS_LOG(EXCEPTION).
+void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float init_val) {
+  auto iter = hash_tables_.find(param_name);
+  if (iter == hash_tables_.end()) {
+    MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table.";
+  }
+  auto &hash_table_info = iter->second;
+  hash_table_info.param_init_info_.param_type_ = kAccumulation;
+  hash_table_info.param_init_info_.init_val_ = init_val;
+}
+
+// Copies the entry of `src_param_name` to `dest_param_name`.
+// Nothing happens when both names are equal, and emplace() keeps an already existing destination entry.
+// A missing source entry is reported through MS_LOG(EXCEPTION).
+void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) {
+  if (dest_param_name == src_param_name) {
+    MS_LOG(INFO) << "The dest_param_name is same as src_param_name";
+    return;
+  }
+  auto iter = hash_tables_.find(src_param_name);
+  if (iter == hash_tables_.end()) {
+    MS_LOG(EXCEPTION) << "The source hash table[" << src_param_name << "] does not exist, clone failed.";
+  }
+  hash_tables_.emplace(dest_param_name, iter->second);
+}
+
+// Returns the device address recorded for `param_name`.
+// An unknown `param_name` is reported through MS_LOG(EXCEPTION).
+const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const {
+  auto iter = hash_tables_.find(param_name);
+  if (iter == hash_tables_.end()) {
+    MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name;
+  }
+  return iter->second.device_address;
+}
+
+// Starts the worker if it is not running yet, builds the device and host caches from the sizes registered
+// so far, initializes the tables on the server side, allocates the cache memory and computes the local id
+// range. initialized_ps_cache_ is set only after all of these steps have completed.
+void PsCacheManager::Initialize() {
+  MS_LOG(INFO) << "PS cache initialize.";
+  if (!worker.running()) {
+    Util::SetInternalEnvVar();
+    worker.Run();
+  }
+  embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_);
+  embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_);
+  InitParameterServer();
+  AllocMemForHashTable();
+  SetLocalIdRank();
+  initialized_ps_cache_ = true;
+}
+
+// For every registered table: obtains its key from the worker, registers the table with the worker and sends
+// the shape/initialization description built below.
+void PsCacheManager::InitParameterServer() {
+  for (const auto &item : hash_tables_) {
+    const auto &param_name = item.first;
+    const auto &hash_table_info = item.second;
+    const auto &param_init_info = hash_table_info.param_init_info_;
+    size_t key = worker.SetParamKey(param_name);
+    size_t row_count = hash_table_info.vocab_size;
+    std::vector<size_t> keys{key, key, key, key};
+    std::vector<float> values{
+      SizeToFloat(hash_table_info.vocab_size), SizeToFloat(hash_table_info.embedding_size), 1, 1, 1, 1, 1};
+    // NOTE(review): the element types of keys/values/lens are reconstructed; confirm them against the
+    // signature of InitPSEmbeddingTable.
+    std::vector<int64_t> lens{2, 2, 3};
+    // The fourth length entry and the trailing values depend on how the parameter is initialized.
+    if (param_init_info.param_type_ == kWeight) {
+      lens.push_back(0);
+      values.push_back(SizeToFloat(param_init_info.global_seed_));
+      values.push_back(SizeToFloat(param_init_info.op_seed_));
+    } else if (param_init_info.param_type_ == kAccumulation) {
+      lens.push_back(1);
+      values.push_back(param_init_info.init_val_);
+    }
+    // if worker role
+    worker.AddEmbeddingTable(key, row_count);
+    worker.InitPSEmbeddingTable(keys, values, lens);
+  }
+}
+
+// Allocates, for every registered table, a device buffer of cache_vocab_size_ * embedding_size floats and a
+// host buffer of host_cache_vocab_size_ * embedding_size elements, then the two swap buffers (sized for the
+// widest embedding) and the constant memory of the device cache.
+void PsCacheManager::AllocMemForHashTable() {
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_);
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
+  size_t max_embedding_size = 0;
+  for (auto &item : hash_tables_) {
+    size_t embedding_size = item.second.embedding_size;
+    auto &device_address = item.second.device_address;
+    device_address.size = cache_vocab_size_ * embedding_size * sizeof(float);
+    auto addr = embedding_device_cache_->cache_->MallocMemory(device_address.size);
+    MS_EXCEPTION_IF_NULL(addr);
+    device_address.addr = addr;
+
+    // new[] throws on failure instead of returning nullptr, so only the resulting smart pointer is checked.
+    // NOTE(review): the buffer holds int elements although the host hash table helpers take float
+    // pointers; confirm the intended element type.
+    auto &host_address = item.second.host_address;
+    host_address.reset(new int[host_cache_vocab_size_ * embedding_size], std::default_delete<int[]>());
+    MS_EXCEPTION_IF_NULL(host_address);
+
+    max_embedding_size = std::max(max_embedding_size, embedding_size);
+  }
+  embedding_device_cache_->hash_swap_index_addr_ =
+    reinterpret_cast<int *>(embedding_device_cache_->cache_->MallocMemory(batch_elements_ * sizeof(int)));
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_index_addr_);
+  embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
+    embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
+  MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
+  embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_);
+}
+
+// Computes range_bound_, the [first, second) slice of the vocabulary assigned to this worker: the
+// vocabulary is cut into ceil(vocab_size_ / worker_num) sized shards and the end is clamped to vocab_size_.
+void PsCacheManager::SetLocalIdRank() {
+  const auto worker_num = ::ps::NumWorkers();
+  const auto worker_id = ::ps::MyRank();
+  if (worker_num <= 0 || worker_id < 0) {
+    MS_LOG(EXCEPTION) << "Invalid worker num:" << worker_num << " or worker id:" << worker_id;
+  }
+  const auto shard_count = static_cast<size_t>(worker_num);
+  // Integer ceiling division: a float cannot represent every size above 2^24 exactly, so rounding through
+  // float could yield a shard size that is off for large vocabularies. This form cannot overflow either.
+  const size_t local_shard_size = vocab_size_ / shard_count + (vocab_size_ % shard_count == 0 ? 0 : 1);
+  range_bound_.first = local_shard_size * static_cast<size_t>(worker_id);
+  range_bound_.second = std::min(range_bound_.first + local_shard_size, vocab_size_);
+  MS_LOG(INFO) << "Worker num:" << worker_num << ", worker id:" << worker_id << ", rank id begin:" << range_bound_.first
+               << ", rank id end:" << range_bound_.second;
+}
+
+// Returns a copy of the current channel name, read under channel_mutex_.
+std::string PsCacheManager::channel_name() {
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  return channel_name_;
+}
+
+// Stores `channel_name` if it differs from the current one.
+void PsCacheManager::set_channel_name(const std::string &channel_name) {
+  // Take the lock before the comparison: reading channel_name_ unlocked would race with a concurrent writer.
+  std::lock_guard<std::mutex> locker(channel_mutex_);
+  if (channel_name_ == channel_name) {
+    return;
+  }
+  channel_name_ = channel_name;
+}
+
+// Advances data_step_ by one and snapshots graph_step_ into graph_running_step_.
+void PsCacheManager::IncreaseStep() {
+  if (data_step_ >= UINT64_MAX) {
+    MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
+  }
+  data_step_++;
+  set_current_graph_step();
+}
+
+// Advances graph_step_ by one, records `channel_name`, asks the data prefetcher to wake that channel and
+// notifies one waiter of data_prase_.
+void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
+  if (graph_step_ >= UINT64_MAX) {
+    MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
+  }
+  graph_step_++;
+  set_channel_name(channel_name);
+  PsDataPrefetch::GetInstance().TryWakeChannel(channel_name);
+  data_prase_.notify_one();
+}
+}  // namespace ps
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
new file mode 100644
index 0000000000..c36a50b2ee
--- /dev/null
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -0,0 +1,186 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
+#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
+
+#include <atomic>
+#include <condition_variable>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <utility>
+#include <vector>
+#include "utils/ms_context.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "utils/shape_utils.h"
+#include "ir/tensor.h"
+#include "ps/ps.h"
+#include "ps/common.h"
+#include "ps/worker.h"
+#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
+#include "ps/ps_cache/embedding_hash_map.h"
+#include "ps/ps_cache/ps_cache_factory.h"
+
+namespace mindspore {
+namespace ps {
+// Factor by which the host cache vocabulary is larger than the device cache vocabulary.
+constexpr size_t kHostCacheScaleFactor = 10;
+constexpr size_t kMaxThreadNum = 16;
+using mindspore::kernel::Address;
+
+// Sizes and buffers of one embedding hash table, keyed by parameter name in PsCacheManager.
+// NOTE(review): ps_cache_manager.cc reads and writes a `param_init_info_` member (with param_type_,
+// global_seed_, op_seed_ and init_val_) on this struct, but no such member is declared here; confirm
+// where it is meant to be defined before the .cc is enabled in the build.
+struct HashTableInfo {
+  size_t cache_vocab_size{0};
+  size_t host_cache_vocab_size{0};
+  size_t embedding_size{0};
+  size_t vocab_size{0};
+  Address device_address{nullptr, 0};
+  std::shared_ptr<int[]> host_address{nullptr};
+};
+
+// Device side cache: four index/id buffers of `batch_elements` ints, a hash map with a capacity of
+// `cache_vocab_size`, and the cache backend selected by the device target of the current MsContext.
+struct EmbeddingDeviceCache {
+  EmbeddingDeviceCache(size_t batch_elements, size_t cache_vocab_size) {
+    device_to_host_index = std::make_unique<int[]>(batch_elements);
+    device_to_host_ids = std::make_unique<int[]>(batch_elements);
+    host_to_device_index = std::make_unique<int[]>(batch_elements);
+    host_to_device_ids = std::make_unique<int[]>(batch_elements);
+    device_hash_map_ = std::make_shared<EmbeddingHashMap>(0, cache_vocab_size);
+    auto context_ptr = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context_ptr);
+    auto device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+    cache_ = PsCacheFactory::Get().ps_cache(device_target);
+  }
+  std::unique_ptr<int[]> device_to_host_index;
+  std::unique_ptr<int[]> device_to_host_ids;
+  std::unique_ptr<int[]> host_to_device_index;
+  std::unique_ptr<int[]> host_to_device_ids;
+  // Assigned by PsCacheManager::AllocMemForHashTable(); null until then instead of indeterminate.
+  int *hash_swap_index_addr_{nullptr};
+  float *hash_swap_value_addr_{nullptr};
+  std::shared_ptr<EmbeddingHashMap> device_hash_map_;
+  std::shared_ptr<PsCacheBasic> cache_;
+};
+
+// Host side cache: six index/id buffers of `batch_elements` ints and a hash map with a capacity of
+// `host_cache_vocab_size`.
+struct EmbeddingHostCache {
+  EmbeddingHostCache(size_t batch_elements, size_t host_cache_vocab_size) {
+    host_to_server_index = std::make_unique<int[]>(batch_elements);
+    host_to_server_ids = std::make_unique<int[]>(batch_elements);
+    server_to_host_index = std::make_unique<int[]>(batch_elements);
+    server_to_host_ids = std::make_unique<int[]>(batch_elements);
+    host_to_device_index = std::make_unique<int[]>(batch_elements);
+    device_to_host_index = std::make_unique<int[]>(batch_elements);
+    host_hash_map_ = std::make_shared<EmbeddingHashMap>(0, host_cache_vocab_size);
+  }
+  std::unique_ptr<int[]> host_to_server_index;
+  std::unique_ptr<int[]> host_to_server_ids;
+  std::unique_ptr<int[]> server_to_host_index;
+  std::unique_ptr<int[]> server_to_host_ids;
+  std::unique_ptr<int[]> host_to_device_index;
+  std::unique_ptr<int[]> device_to_host_index;
+  std::shared_ptr<EmbeddingHashMap> host_hash_map_;
+};
+
+// Plain counters, all starting at zero.
+struct PsCacheStatisticsInfo {
+  size_t batch_id_unique_count_{0};
+  size_t device_to_host_size_{0};
+  size_t host_to_device_size_{0};
+  size_t host_to_server_size_{0};
+  size_t server_to_host_size_{0};
+  size_t hash_hit_count_{0};
+  size_t mem_cache_swap_out_size_{0};
+  size_t mem_cache_swap_in_size_{0};
+  size_t mem_cache_hit_count_{0};
+};
+
+// Process-wide singleton that keeps the registered embedding hash tables together with the device and
+// host caches built for them. Several members declared below (DoProcessData, DumpHashTables and the
+// Process*/Parse*/HashSwap*/LookUp* helpers) are not defined in the ps_cache_manager.cc added by this patch.
+class PsCacheManager {
+ public:
+  static PsCacheManager &GetInstance() {
+    static PsCacheManager instance;
+    return instance;
+  }
+  void Initialize();
+  void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
+                           size_t vocab_size);
+  void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed);
+  void InsertAccumuInitInfo(const std::string &param_name, float init_val);
+  void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
+                             size_t cache_vocab_size, size_t embedding_size);
+  void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name);
+  const Address &QueryHashTableAddr(const std::string &param_name) const;
+  // True when a table has been registered under `param_name`.
+  bool IsHashTable(const std::string &param_name) const { return hash_tables_.count(param_name) != 0; }
+  void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; }
+  bool initialized_ps_cache() const { return initialized_ps_cache_; }
+  void DoProcessData(uint32_t device_id, void *context);
+  void IncreaseGraphStep(const std::string &channel_name);
+  void DumpHashTables() const;
+
+ private:
+  PsCacheManager() = default;
+  ~PsCacheManager() = default;
+  PsCacheManager(const PsCacheManager &) = delete;
+  PsCacheManager &operator=(const PsCacheManager &) = delete;
+  void IncreaseStep();
+  void set_current_graph_step() { graph_running_step_ = graph_step_; }
+  std::string channel_name();
+  void set_channel_name(const std::string &channel_name);
+  void InitParameterServer();
+  void AllocMemForHashTable();
+  void SetLocalIdRank();
+  void ProcessDataTask(uint32_t device_id, void *context);
+  void ProcessData();
+  void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
+  void WaitGraphRun();
+  int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
+  void ParseHostDataHostToDevice(size_t id);
+  void ParseHostDataDeviceToHost(size_t id);
+  void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
+  void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
+  void HashSwapHostToDevice(const HashTableInfo &hash_info);
+  void HashSwapDeviceToHost(const HashTableInfo &hash_info);
+  void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
+  void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
+  void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
+                           float *hash_table_addr);
+  void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
+                           const int *indices_addr, float *output_addr);
+  void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
+  void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
+                       const int *indices_addr, float *output_addr);
+
+  bool initialized_ps_cache_{false};
+  // Guarded by channel_mutex_.
+  std::string channel_name_;
+  std::mutex channel_mutex_;
+  std::atomic_ulong graph_step_{0};
+  size_t graph_running_step_{0};
+  size_t data_step_{0};
+  std::mutex data_mutex_;
+  std::condition_variable data_prase_;
+
+  // Registered tables, keyed by parameter name.
+  std::map<std::string, HashTableInfo> hash_tables_;
+  std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
+  std::shared_ptr<EmbeddingHostCache> embedding_host_cache_;
+
+  size_t vocab_size_{0};
+  size_t cache_vocab_size_{0};
+  size_t host_cache_vocab_size_{0};
+  size_t batch_elements_{0};
+  PsCacheStatisticsInfo statistics_info_;
+  // [first, second) id range of this worker, computed by SetLocalIdRank().
+  std::pair<size_t, size_t> range_bound_;
+};
+
+static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();
+}  // namespace ps
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PS_PS_CACHE_PS_CACHE_MANAGER_H_
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index e6fd916db8..9950351f93 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -141,6 +141,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc")