Browse Source

Remove the static `worker` reference from ps/worker.h; call `Worker::GetInstance()` directly at each use site

tags/v1.2.0-rc1
chendongsheng 5 years ago
parent
commit
d29f2b2634
9 changed files with 31 additions and 32 deletions
  1. +4
    -3
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
  2. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h
  4. +4
    -4
      mindspore/ccsrc/backend/session/session_basic.cc
  5. +1
    -1
      mindspore/ccsrc/pipeline/jit/action.cc
  6. +1
    -1
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  7. +2
    -2
      mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc
  8. +15
    -15
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
  9. +0
    -2
      mindspore/ccsrc/ps/worker.h

+ 4
- 3
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc View File

@@ -52,9 +52,9 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()),
SizeToLong(output_shape.size())};
if (mindspore::ps::PSContext::instance()->is_worker()) {
mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]);
mindspore::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape[axis]);
mindspore::ps::ParamInitInfoMessage info;
mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info);
mindspore::ps::Worker::GetInstance().InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info);
}
}

@@ -81,7 +81,8 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
return false;
}
mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
mindspore::ps::Worker::GetInstance().DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result,
mindspore::ps::kEmbeddingLookupCmd);

auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size);
if (ret2 != EOK) {


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h View File

@@ -36,10 +36,10 @@ class PullKernel : public CPUKernel {
if (inputs.size() != 2) {
MS_LOG(EXCEPTION) << "Inputs size is " << inputs.size() << ", but PullKernel needs 2.";
}
bool init_in_server = mindspore::ps::worker.GetParamInitInServer(param_name_);
bool init_in_server = mindspore::ps::Worker::GetInstance().GetParamInitInServer(param_name_);
// If init_in_server, forward kernel should run in server too.
if (!init_in_server) {
mindspore::ps::worker.Pull(key_, inputs[1]->addr, inputs[1]->size);
mindspore::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
}
return true;
}


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h View File

@@ -45,7 +45,7 @@ class PushKernel : public CPUKernel {
addrs.push_back(reinterpret_cast<uintptr_t>(input->addr));
sizes.push_back(SizeToLong(input->size) / sizeof(T));
}
mindspore::ps::worker.Push(keys, addrs, sizes);
mindspore::ps::Worker::GetInstance().Push(keys, addrs, sizes);
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, &key_, sizeof(size_t));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
@@ -63,7 +63,7 @@ class PushKernel : public CPUKernel {
MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices;
for (size_t i = 0; i < optim_input_shapes.size(); i++) {
auto shape = optim_input_shapes[i];
mindspore::ps::worker.SetOptimInputShapes(key_, shape);
mindspore::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape);
if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) {
size_t size = sizeof(T);
for (size_t j = 0; j < shape.size(); j++) {


+ 4
- 4
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -2515,7 +2515,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
size_t embedding_table_idx = 0;
auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
size_t key = ps::worker.SetParamKey(embedding_table->fullname_with_scope());
size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
} else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
auto pull_node = FindPullNode(node, node_list);
@@ -2526,12 +2526,12 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
// Second input of Pull node is the trainable parameter.
size_t parameter_index = 1;
auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
size_t key = ps::worker.SetParamKey(parameter_node->fullname_with_scope());
size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);

std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
ps::worker.SetKeyOptimId(key, optimizer_name);
ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name);
}
}
}
@@ -2553,7 +2553,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
auto input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
ps::worker.InitPSParamAndOptim(input_node, tensor);
ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor);
}
}
}


+ 1
- 1
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -544,7 +544,7 @@ bool ExecuteAction(const ResourcePtr &res) {

#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
bool StartPSWorkerAction(const ResourcePtr &res) {
ps::worker.Run();
ps::Worker::GetInstance().Run();
return true;
}



+ 1
- 1
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -1159,7 +1159,7 @@ void ClearResAtexit() {
ps::ps_cache_instance.Finalize();
}
MS_LOG(INFO) << "ps::worker.Finalize";
ps::worker.Finalize();
ps::Worker::GetInstance().Finalize();
}
#endif
ad::g_k_prims.clear();


+ 2
- 2
mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc View File

@@ -188,7 +188,7 @@ bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(src);
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtMemcpyAsync failed";
MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret;
return false;
}
return true;
@@ -199,7 +199,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(src);
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtMemcpyAsync failed";
MS_LOG(ERROR) << "rtMemcpyAsync failed, the error num is:" << ret;
return false;
}
return true;


+ 15
- 15
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -144,8 +144,8 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name)

void PsCacheManager::Initialize() {
MS_LOG(INFO) << "PS cache initialize.";
if (!worker.running()) {
worker.Run();
if (!Worker::GetInstance().running()) {
Worker::GetInstance().Run();
}
embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_);
embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_vocab_cache_size_);
@@ -159,10 +159,10 @@ void PsCacheManager::Initialize() {
void PsCacheManager::AddEmbeddingTable() const {
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
size_t key = Worker::GetInstance().SetParamKey(param_name);
size_t row_count = item.second.vocab_size;
// if worker role
worker.AddEmbeddingTable(key, row_count);
Worker::GetInstance().AddEmbeddingTable(key, row_count);
}
}

@@ -175,7 +175,7 @@ void PsCacheManager::InitParameterServer() {
}
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
size_t key = Worker::GetInstance().SetParamKey(param_name);
const auto &hash_table_info = item.second;
const auto &param_init_info = hash_table_info.param_init_info_;

@@ -188,7 +188,7 @@ void PsCacheManager::InitParameterServer() {
info.set_global_seed(param_init_info.global_seed_);
info.set_op_seed(param_init_info.op_seed_);
// if worker role
worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
Worker::GetInstance().InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info);
}

finish_init_parameter_server_ = true;
@@ -380,7 +380,7 @@ bool PsCacheManager::ProcessData() {
return false;
}
for (const auto &item : hash_tables_) {
auto key = worker.GetParamKey(item.first);
auto key = Worker::GetInstance().GetParamKey(item.first);
auto hash_info = item.second;
RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info));
RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info));
@@ -839,7 +839,7 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
return true;
}

@@ -861,7 +861,7 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
lookup_result.data(), host_hash_table_addr));
return true;
@@ -915,7 +915,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
Worker::GetInstance().DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
// Hash swap-in in device.
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
@@ -945,7 +945,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_dat
}
// Need synchronize event to ensure that the swap-out in device is completed.
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent());
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
return true;
}

@@ -987,7 +987,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
if (hash_info.param_init_info_.param_type_ != kWeight) {
continue;
}
auto key = worker.GetParamKey(item.first);
auto key = Worker::GetInstance().GetParamKey(item.first);
std::vector<int> lookup_ids(swap_indices_lens, 0);
std::vector<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
@@ -1003,7 +1003,7 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
}
return true;
}
@@ -1031,7 +1031,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
if (hash_info.param_init_info_.param_type_ != kWeight) {
continue;
}
auto key = worker.GetParamKey(item.first);
auto key = Worker::GetInstance().GetParamKey(item.first);
std::vector<int> lookup_ids(swap_indices_lens, 0);
std::vector<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
@@ -1055,7 +1055,7 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
Worker::GetInstance().UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
}
return true;
}


+ 0
- 2
mindspore/ccsrc/ps/worker.h View File

@@ -148,8 +148,6 @@ class Worker {

std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
};

static Worker &worker = Worker::GetInstance();
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_WORKER_H_

Loading…
Cancel
Save