| @@ -52,9 +52,9 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), | |||
| SizeToLong(output_shape.size())}; | |||
| if (mindspore::ps::PSContext::instance()->is_worker()) { | |||
| mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]); | |||
| mindspore::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]); | |||
| mindspore::ps::ParamInitInfoMessage info; | |||
| mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info); | |||
| mindspore::ps::Worker::GetInstance().InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info); | |||
| } | |||
| } | |||
| @@ -81,7 +81,8 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| mindspore::ps::Worker::GetInstance().DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, | |||
| mindspore::ps::kEmbeddingLookupCmd); | |||
| auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size); | |||
| if (ret2 != EOK) { | |||
| @@ -36,10 +36,10 @@ class PullKernel : public CPUKernel { | |||
| if (inputs.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Inputs size is " << inputs.size() << ", but PullKernel needs 2."; | |||
| } | |||
| bool init_in_server = mindspore::ps::worker.GetParamInitInServer(param_name_); | |||
| bool init_in_server = mindspore::ps::Worker::GetInstance().GetParamInitInServer(param_name_); | |||
| // If init_in_server, forward kernel should run in server too. | |||
| if (!init_in_server) { | |||
| mindspore::ps::worker.Pull(key_, inputs[1]->addr, inputs[1]->size); | |||
| mindspore::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -45,7 +45,7 @@ class PushKernel : public CPUKernel { | |||
| addrs.push_back(reinterpret_cast<uintptr_t>(input->addr)); | |||
| sizes.push_back(SizeToLong(input->size) / sizeof(T)); | |||
| } | |||
| mindspore::ps::worker.Push(keys, addrs, sizes); | |||
| mindspore::ps::Worker::GetInstance().Push(keys, addrs, sizes); | |||
| auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, &key_, sizeof(size_t)); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| @@ -63,7 +63,7 @@ class PushKernel : public CPUKernel { | |||
| MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices; | |||
| for (size_t i = 0; i < optim_input_shapes.size(); i++) { | |||
| auto shape = optim_input_shapes[i]; | |||
| mindspore::ps::worker.SetOptimInputShapes(key_, shape); | |||
| mindspore::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); | |||
| if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { | |||
| size_t size = sizeof(T); | |||
| for (size_t j = 0; j < shape.size(); j++) { | |||
| @@ -2515,7 +2515,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||
| if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { | |||
| size_t embedding_table_idx = 0; | |||
| auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx); | |||
| size_t key = ps::worker.SetParamKey(embedding_table->fullname_with_scope()); | |||
| size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); | |||
| } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) { | |||
| auto pull_node = FindPullNode(node, node_list); | |||
| @@ -2526,12 +2526,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||
| // Second input of Pull node is the trainable parameter. | |||
| size_t parameter_index = 1; | |||
| auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index); | |||
| size_t key = ps::worker.SetParamKey(parameter_node->fullname_with_scope()); | |||
| size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope()); | |||
| AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node); | |||
| AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node); | |||
| std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType); | |||
| ps::worker.SetKeyOptimId(key, optimizer_name); | |||
| ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name); | |||
| } | |||
| } | |||
| } | |||
| @@ -2553,7 +2553,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, | |||
| auto input_node = input_nodes[i]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| ps::worker.InitPSParamAndOptim(input_node, tensor); | |||
| ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor); | |||
| } | |||
| } | |||
| } | |||
| @@ -544,7 +544,7 @@ bool ExecuteAction(const ResourcePtr &res) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| bool StartPSWorkerAction(const ResourcePtr &res) { | |||
| ps::worker.Run(); | |||
| ps::Worker::GetInstance().Run(); | |||
| return true; | |||
| } | |||
| @@ -1159,7 +1159,7 @@ void ClearResAtexit() { | |||
| ps::ps_cache_instance.Finalize(); | |||
| } | |||
| MS_LOG(INFO) << "ps::worker.Finalize"; | |||
| ps::worker.Finalize(); | |||
| ps::Worker::GetInstance().Finalize(); | |||
| } | |||
| #endif | |||
| ad::g_k_prims.clear(); | |||
| @@ -188,7 +188,7 @@ bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret; | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -199,7 +199,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret; | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -144,8 +144,8 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) | |||
| void PsCacheManager::Initialize() { | |||
| MS_LOG(INFO) << "PS cache initialize."; | |||
| if (!worker.running()) { | |||
| worker.Run(); | |||
| if (!Worker::GetInstance().running()) { | |||
| Worker::GetInstance().Run(); | |||
| } | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_); | |||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_); | |||
| @@ -159,10 +159,10 @@ void PsCacheManager::Initialize() { | |||
| void PsCacheManager::AddEmbeddingTable() const { | |||
| for (const auto &item : hash_tables_) { | |||
| const auto &param_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| size_t key = Worker::GetInstance().SetParamKey(param_name); | |||
| size_t row_count = item.second.vocab_size; | |||
| // if worker role | |||
| worker.AddEmbeddingTable(key, row_count); | |||
| Worker::GetInstance().AddEmbeddingTable(key, row_count); | |||
| } | |||
| } | |||
| @@ -175,7 +175,7 @@ void PsCacheManager::InitParameterServer() { | |||
| } | |||
| for (const auto &item : hash_tables_) { | |||
| const auto &param_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| size_t key = Worker::GetInstance().SetParamKey(param_name); | |||
| const auto &hash_table_info = item.second; | |||
| const auto &param_init_info = hash_table_info.param_init_info_; | |||
| @@ -188,7 +188,7 @@ void PsCacheManager::InitParameterServer() { | |||
| info.set_global_seed(param_init_info.global_seed_); | |||
| info.set_op_seed(param_init_info.op_seed_); | |||
| // if worker role | |||
| worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info); | |||
| Worker::GetInstance().InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info); | |||
| } | |||
| finish_init_parameter_server_ = true; | |||
| @@ -380,7 +380,7 @@ bool PsCacheManager::ProcessData() { | |||
| return false; | |||
| } | |||
| for (const auto &item : hash_tables_) { | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto key = Worker::GetInstance().GetParamKey(item.first); | |||
| auto hash_info = item.second; | |||
| RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info)); | |||
| RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info)); | |||
| @@ -839,7 +839,7 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_ | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| return true; | |||
| } | |||
| @@ -861,7 +861,7 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, | |||
| lookup_result.data(), host_hash_table_addr)); | |||
| return true; | |||
| @@ -915,7 +915,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| // Hash swap-in in device. | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), | |||
| @@ -945,7 +945,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_dat | |||
| } | |||
| // Need synchronize event to ensure that the swap-out in device is completed. | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent()); | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| return true; | |||
| } | |||
| @@ -987,7 +987,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() { | |||
| if (hash_info.param_init_info_.param_type_ != kWeight) { | |||
| continue; | |||
| } | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto key = Worker::GetInstance().GetParamKey(item.first); | |||
| std::vector<int> lookup_ids(swap_indices_lens, 0); | |||
| std::vector<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| @@ -1003,7 +1003,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() { | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -1031,7 +1031,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { | |||
| if (hash_info.param_init_info_.param_type_ != kWeight) { | |||
| continue; | |||
| } | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto key = Worker::GetInstance().GetParamKey(item.first); | |||
| std::vector<int> lookup_ids(swap_indices_lens, 0); | |||
| std::vector<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| @@ -1055,7 +1055,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -148,8 +148,6 @@ class Worker { | |||
| std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_; | |||
| }; | |||
| static Worker &worker = Worker::GetInstance(); | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_WORKER_H_ | |||