From d29f2b2634c2b2e6c9cc4127ed7cd0a2fb2af8f7 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Tue, 9 Mar 2021 11:30:57 +0800 Subject: [PATCH] remove static worker --- .../cpu/ps/embedding_look_up_proxy_kernel.cc | 7 +++-- .../kernel_compiler/cpu/ps/pull_kernel.h | 4 +-- .../kernel_compiler/cpu/ps/push_kernel.h | 4 +-- .../ccsrc/backend/session/session_basic.cc | 8 ++--- mindspore/ccsrc/pipeline/jit/action.cc | 2 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 2 +- .../ps/ps_cache/ascend/ascend_ps_cache.cc | 4 +-- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 30 +++++++++---------- mindspore/ccsrc/ps/worker.h | 2 -- 9 files changed, 31 insertions(+), 32 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index f533195803..11b44db9ed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -52,9 +52,9 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { std::vector lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), SizeToLong(output_shape.size())}; if (mindspore::ps::PSContext::instance()->is_worker()) { - mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]); + mindspore::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]); mindspore::ps::ParamInitInfoMessage info; - mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info); + mindspore::ps::Worker::GetInstance().InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info); } } @@ -81,7 +81,8 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector &i MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; return false; } - mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); + 
mindspore::ps::Worker::GetInstance().DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, + mindspore::ps::kEmbeddingLookupCmd); auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size); if (ret2 != EOK) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h index 221d8cbcea..bfafe83d41 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h @@ -36,10 +36,10 @@ class PullKernel : public CPUKernel { if (inputs.size() != 2) { MS_LOG(EXCEPTION) << "Inputs size is " << inputs.size() << ", but PullKernel needs 2."; } - bool init_in_server = mindspore::ps::worker.GetParamInitInServer(param_name_); + bool init_in_server = mindspore::ps::Worker::GetInstance().GetParamInitInServer(param_name_); // If init_in_server, forward kernel should run in server too. if (!init_in_server) { - mindspore::ps::worker.Pull(key_, inputs[1]->addr, inputs[1]->size); + mindspore::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size); } return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h index fa60043d2a..b76b25ef3d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h @@ -45,7 +45,7 @@ class PushKernel : public CPUKernel { addrs.push_back(reinterpret_cast(input->addr)); sizes.push_back(SizeToLong(input->size) / sizeof(T)); } - mindspore::ps::worker.Push(keys, addrs, sizes); + mindspore::ps::Worker::GetInstance().Push(keys, addrs, sizes); auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, &key_, sizeof(size_t)); if (ret != EOK) { MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; @@ -63,7 +63,7 @@ class PushKernel : public CPUKernel { MS_LOG(INFO) << "Only init shape indices are " << 
only_shape_indices; for (size_t i = 0; i < optim_input_shapes.size(); i++) { auto shape = optim_input_shapes[i]; - mindspore::ps::worker.SetOptimInputShapes(key_, shape); + mindspore::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { size_t size = sizeof(T); for (size_t j = 0; j < shape.size(); j++) { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 3aca8f01e5..14ad088a13 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -2515,7 +2515,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { size_t embedding_table_idx = 0; auto embedding_table = AnfAlgo::GetInputNode(node->cast(), embedding_table_idx); - size_t key = ps::worker.SetParamKey(embedding_table->fullname_with_scope()); + size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope()); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) { auto pull_node = FindPullNode(node, node_list); @@ -2526,12 +2526,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { // Second input of Pull node is the trainable parameter. 
size_t parameter_index = 1; auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast(), parameter_index); - size_t key = ps::worker.SetParamKey(parameter_node->fullname_with_scope()); + size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope()); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node); std::string optimizer_name = AnfAlgo::GetNodeAttr(node, kAttrOptimizerType); - ps::worker.SetKeyOptimId(key, optimizer_name); + ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name); } } } @@ -2553,7 +2553,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, auto input_node = input_nodes[i]; MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - ps::worker.InitPSParamAndOptim(input_node, tensor); + ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor); } } } diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index aaa76c17e3..c0a35223f2 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -544,7 +544,7 @@ bool ExecuteAction(const ResourcePtr &res) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) bool StartPSWorkerAction(const ResourcePtr &res) { - ps::worker.Run(); + ps::Worker::GetInstance().Run(); return true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 0c33a6c2a2..f0bd195f2c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -1159,7 +1159,7 @@ void ClearResAtexit() { ps::ps_cache_instance.Finalize(); } MS_LOG(INFO) << "ps::worker.Finalize"; - ps::worker.Finalize(); + ps::Worker::GetInstance().Finalize(); } #endif ad::g_k_prims.clear(); diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc index 
0eb8c02cf0..f4988048e7 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc @@ -188,7 +188,7 @@ bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { MS_ERROR_IF_NULL(src); auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtMemcpyAsync failed"; + MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret; return false; } return true; @@ -199,7 +199,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { MS_ERROR_IF_NULL(src); auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtMemcpyAsync failed"; + MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret; return false; } return true; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 720e026d7e..50369de543 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -144,8 +144,8 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) void PsCacheManager::Initialize() { MS_LOG(INFO) << "PS cache initialize."; - if (!worker.running()) { - worker.Run(); + if (!Worker::GetInstance().running()) { + Worker::GetInstance().Run(); } embedding_device_cache_ = std::make_shared(batch_elements_, vocab_cache_size_); embedding_host_cache_ = std::make_shared(batch_elements_, host_vocab_cache_size_); @@ -159,10 +159,10 @@ void PsCacheManager::Initialize() { void PsCacheManager::AddEmbeddingTable() const { for (const auto &item : hash_tables_) { const auto &param_name = item.first; - size_t key = worker.SetParamKey(param_name); + size_t key = Worker::GetInstance().SetParamKey(param_name); size_t row_count = item.second.vocab_size; // if worker role - worker.AddEmbeddingTable(key, row_count); + 
Worker::GetInstance().AddEmbeddingTable(key, row_count); } } @@ -175,7 +175,7 @@ void PsCacheManager::InitParameterServer() { } for (const auto &item : hash_tables_) { const auto &param_name = item.first; - size_t key = worker.SetParamKey(param_name); + size_t key = Worker::GetInstance().SetParamKey(param_name); const auto &hash_table_info = item.second; const auto &param_init_info = hash_table_info.param_init_info_; @@ -188,7 +188,7 @@ void PsCacheManager::InitParameterServer() { info.set_global_seed(param_init_info.global_seed_); info.set_op_seed(param_init_info.op_seed_); // if worker role - worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info); + Worker::GetInstance().InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info); } finish_init_parameter_server_ = true; @@ -380,7 +380,7 @@ bool PsCacheManager::ProcessData() { return false; } for (const auto &item : hash_tables_) { - auto key = worker.GetParamKey(item.first); + auto key = Worker::GetInstance().GetParamKey(item.first); auto hash_info = item.second; RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info)); RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info)); @@ -839,7 +839,7 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_ MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; } - worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); return true; } @@ -861,7 +861,7 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; } - worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); + Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), 
host_hash_table_addr)); return true; @@ -915,7 +915,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; } - worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); + Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); // Hash swap-in in device. RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), @@ -945,7 +945,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector &swap_out_dat } // Need synchronize event to ensure that the swap-out in device is completed. RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent()); - worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); return true; } @@ -987,7 +987,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() { if (hash_info.param_init_info_.param_type_ != kWeight) { continue; } - auto key = worker.GetParamKey(item.first); + auto key = Worker::GetInstance().GetParamKey(item.first); std::vector lookup_ids(swap_indices_lens, 0); std::vector swap_out_data; auto embedding_size = hash_info.embedding_size; @@ -1003,7 +1003,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; } - worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); } return true; } @@ -1031,7 +1031,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { if (hash_info.param_init_info_.param_type_ != kWeight) { continue; } - auto key = worker.GetParamKey(item.first); + auto key = Worker::GetInstance().GetParamKey(item.first); std::vector lookup_ids(swap_indices_lens, 0); std::vector swap_out_data; auto embedding_size = 
hash_info.embedding_size; @@ -1055,7 +1055,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; } - worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); } return true; } diff --git a/mindspore/ccsrc/ps/worker.h b/mindspore/ccsrc/ps/worker.h index bef27a9c88..efbabd33fc 100644 --- a/mindspore/ccsrc/ps/worker.h +++ b/mindspore/ccsrc/ps/worker.h @@ -148,8 +148,6 @@ class Worker { std::unordered_map>> embedding_table_ranges_; }; - -static Worker &worker = Worker::GetInstance(); } // namespace ps } // namespace mindspore #endif // MINDSPORE_CCSRC_PS_WORKER_H_