Browse Source

!12516 added worker

From: @anancds
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
7c32bbe972
13 changed files with 1320 additions and 25 deletions
  1. +4
    -4
      mindspore/ccsrc/pipeline/jit/init.cc
  2. +1
    -0
      mindspore/ccsrc/ps/CMakeLists.txt
  3. +3
    -0
      mindspore/ccsrc/ps/common.h
  4. +133
    -0
      mindspore/ccsrc/ps/internal/constants.h
  5. +974
    -0
      mindspore/ccsrc/ps/internal/worker.cc
  6. +157
    -0
      mindspore/ccsrc/ps/internal/worker.h
  7. +17
    -4
      mindspore/ccsrc/ps/ps_context.cc
  8. +22
    -5
      mindspore/ccsrc/ps/ps_context.h
  9. +4
    -4
      mindspore/ccsrc/ps/util.cc
  10. +4
    -4
      mindspore/parallel/_ps_context.py
  11. +1
    -0
      tests/ut/cpp/CMakeLists.txt
  12. +0
    -2
      tests/ut/cpp/ps/core/http_client_test.cc
  13. +0
    -2
      tests/ut/cpp/ps/core/http_server_test.cc

+ 4
- 4
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -309,11 +309,11 @@ PYBIND11_MODULE(_c_expression, m) {
(void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext") (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
.def_static("get_instance", &PSContext::instance, "Get PS context instance.") .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
.def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.") .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
.def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
.def("is_ps_mode", &PSContext::is_ps_mode, "Get PS mode enable-disable status.")
.def("reset", &PSContext::Reset, "Reset PS context attributes.") .def("reset", &PSContext::Reset, "Reset PS context attributes.")
.def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
.def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
.def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
.def("is_worker", &PSContext::is_worker, "Get whether the role of this process is Worker.")
.def("is_server", &PSContext::is_server, "Get whether the role of this process is PServer.")
.def("is_scheduler", &PSContext::is_scheduler, "Get whether the role of this process is Scheduler.")
.def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.") .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.")
.def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.") .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.")
.def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize, .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize,


+ 1
- 0
mindspore/ccsrc/ps/CMakeLists.txt View File

@@ -21,6 +21,7 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)))
list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc")
list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc")
endif() endif()


if(NOT ENABLE_D) if(NOT ENABLE_D)


+ 3
- 0
mindspore/ccsrc/ps/common.h View File

@@ -17,11 +17,14 @@
#ifndef MINDSPORE_CCSRC_PS_COMMON_H_ #ifndef MINDSPORE_CCSRC_PS_COMMON_H_
#define MINDSPORE_CCSRC_PS_COMMON_H_ #define MINDSPORE_CCSRC_PS_COMMON_H_


#include <limits.h>

#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map> #include <map>
#include <string> #include <string>

#include "ps/ps.h" #include "ps/ps.h"


namespace mindspore { namespace mindspore {


+ 133
- 0
mindspore/ccsrc/ps/internal/constants.h View File

@@ -0,0 +1,133 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_

#include <climits>
#include <iostream>
#include <vector>
#include <memory>
#include <map>
#include <string>

namespace mindspore {
namespace ps {
namespace internal {

constexpr char kEnvCommType[] = "MS_COMM_TYPE";
constexpr char kEnvInterface[] = "MS_INTERFACE";
constexpr char kEnvPServerNum[] = "MS_SERVER_NUM";
constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM";
constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST";
constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT";

constexpr char kCommTypeOfIBVerbs[] = "ibverbs";
constexpr char kRoleOfPServer[] = "server";
constexpr char kRoleOfWorker[] = "worker";
constexpr char kRoleOfScheduler[] = "scheduler";

constexpr char kLearningRate[] = "learning_rate";
constexpr char kMomentum[] = "momentum";

constexpr char kApplyMomentum[] = "ApplyMomentum";
constexpr char kSparseAdam[] = "Adam";
constexpr char kSparseLazyAdam[] = "LazyAdam";
constexpr char kSparseFtrl[] = "Ftrl";
constexpr char kApplyMomentumOp[] = "Momentum";
constexpr char kSparseAdamOp[] = "Adam";
constexpr char kSparseLazyAdamOp[] = "LazyAdam";
constexpr char kSparseFtrlOp[] = "FTRL";

constexpr int64_t kInitWeightsCmd = 10;
constexpr int64_t kInitWeightToOptimIdCmd = 11;
constexpr int64_t kInitOptimInputsShapeCmd = 12;
constexpr int64_t kInitKeyToPushNodeIdCmd = 13;
constexpr int64_t kInitEmbeddingsCmd = 20;
constexpr int64_t kUpdateEmbeddingsCmd = 21;
constexpr int64_t kCheckReadyForPushCmd = 25;
constexpr int64_t kCheckReadyForPullCmd = 26;
constexpr int64_t kEmbeddingLookupCmd = 30;
constexpr int64_t kFinalizeCmd = 40;
constexpr int64_t kPushCmd = 50;
constexpr int64_t kPullCmd = 51;

constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;

using DataPtr = std::shared_ptr<unsigned char>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;
using Keys = std::vector<Key>;
using Values = std::vector<float>;
using ValuesPtr = std::shared_ptr<Values>;
using Weight = std::vector<float>;
using Grad = std::vector<float>;
using LookupIds = std::vector<Key>;
using Lengths = std::vector<int>;
using WeightPtr = std::shared_ptr<Weight>;
using GradPtr = std::shared_ptr<Grad>;
using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>;
using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>;

constexpr size_t INDEX_NOT_SEND = UINT_MAX;
using OptimOriginIdx = std::map<std::string, size_t>;
using OptimPSSendIdx = std::map<std::string, size_t>;

const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}};
const OptimPSSendIdx kMomentumPSSendIdx = {
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}};

const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3},
{"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7},
{"eps", 8}, {"grad", 9}, {"indices", 10}};
const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND},
{"m", INDEX_NOT_SEND},
{"v", INDEX_NOT_SEND},
{"beta1_power", 0},
{"beta2_power", 1},
{"lr", 2},
{"beta1", 3},
{"beta2", 4},
{"eps", 5},
{"grad", 6},
{"indices", 7}};

const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}};
const OptimPSSendIdx kSparseFtrlPSSendIdx = {
{"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}};

const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx},
{kSparseAdam, kSparseAdamOriginIdx},
{kSparseLazyAdam, kSparseAdamOriginIdx},
{kSparseFtrl, kSparseFtrlOriginIdx}};
const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx},
{kSparseAdam, kSparseAdamPSSendIdx},
{kSparseLazyAdam, kSparseAdamPSSendIdx},
{kSparseFtrl, kSparseFtrlPSSendIdx}};

#define EXC_IF_VEC_IDX_OOB(vec, idx) \
{ \
size_t vec_size = vec.size(); \
if (idx >= vec_size) { \
MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \
<< " is out of bound."; \
} \
}
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_

+ 974
- 0
mindspore/ccsrc/ps/internal/worker.cc View File

@@ -0,0 +1,974 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ps/internal/worker.h"

namespace mindspore {
namespace ps {
namespace internal {
void Worker::Run() {
std::lock_guard<std::mutex> lock(running_mutex_);
core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
server_num_ = PSContext::instance()->initial_server_num();
if (running_) {
MS_LOG(INFO) << "'Worker is already running.";
return;
}
if (!PSContext::instance()->is_worker()) {
MS_LOG(EXCEPTION) << "The role is not worker.";
}

Initialize();
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
worker_node_.Start();
MS_LOG(INFO) << "Worker connected successfully.";

running_ = true;
}

void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
if (keys.size() == 0) {
MS_LOG(EXCEPTION) << "key size should be greater than zero";
}
if (key_to_optimId_.count(keys[0]) == 0) {
MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
}
Key key = keys[0];
int64_t optim_id = key_to_optimId_[key];
MS_LOG(INFO) << "The key is:" << key << " the optim_id:" << optim_id;
bool is_sparse = false;
if (optim_id == 1 || optim_id == 2 || optim_id == 3) {
is_sparse = true;
}
int64_t grad_index = -1;
int64_t indice_index = -1;

// Sparse adam gradient
if (optim_id == 1 || optim_id == 2) {
grad_index = 6;
indice_index = 7;

// Sparse ftrl gradient
} else if (optim_id == 3) {
grad_index = 0;
indice_index = 1;
}

size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
std::vector<float> total_buffer(total_size, 0);
size_t offset = 0;
for (size_t i = 0; i < sizes.size(); i++) {
void *dst_data = total_buffer.data() + offset / sizeof(float);
void *src_data = reinterpret_cast<void *>(addrs[i]);
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
int size = sizes[i] * sizeof(float);
auto ret = memcpy_s(dst_data, size, src_data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
offset += size;
}
MS_LOG(INFO) << "The total size is:" << total_size;

while (!IsReadyForPush(keys[0])) {
continue;
}
std::vector<int> sizes_int;
(void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
[](const int64_t &value) { return static_cast<int>(value); });
if (!is_sparse) {
PushData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), kPushCmd);
} else {
std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
int64_t first_dim_size = var_shape[0];
int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
MS_LOG(DEBUG) << "The keys:" << keys << " the total_buffer:" << total_buffer << " the sizes_int:" << sizes_int
<< " the grad_index:" << grad_index << " the indice_index:" << indice_index
<< " the first_dim_size:" << first_dim_size << " the outer_dim_size" << outer_dim_size;
PushSparseData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), grad_index, indice_index,
first_dim_size, outer_dim_size);
}
}

void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
MS_EXCEPTION_IF_NULL(dev_addr);
std::vector<float> variables(size / sizeof(float), 0);
while (!IsReadyForPull(key)) {
continue;
}
PullData({key}, &variables, nullptr, kPullCmd);
MS_LOG(DEBUG) << "The variables:" << variables << " the size is:" << size;
size_t dst_size = size;
size_t src_size = size;
auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
}

size_t Worker::SetParamKey(const std::string &param_name) {
size_t key = UINT64_MAX;
if (param_to_key_.count(param_name)) {
key = param_to_key_[param_name];
MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
} else {
key = key_cnt_++;
param_to_key_[param_name] = key;
MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
}
return key;
}

size_t Worker::GetParamKey(const std::string &param_name) {
size_t key = kInvalidKey;
if (param_to_key_.find(param_name) != param_to_key_.end()) {
key = param_to_key_[param_name];
MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
}
return key;
}

void Worker::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
param_to_init_in_server_[param_name] = init_in_server;
}

bool Worker::GetParamInitInServer(const std::string &param_name) {
if (param_to_init_in_server_.count(param_name) == 0) {
return false;
}
return param_to_init_in_server_[param_name];
}

void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
MS_LOG(INFO) << "SetKeyOptimId key is:" << key << " optimizer_name:" << optimizer_name;
key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
}

void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
key_to_optim_shapes_[key] = {shape};
} else {
key_to_optim_shapes_[key].push_back(shape);
}
}

void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
bool has_init = IsKeyInit(key);
if (has_init) {
return;
}
uint64_t begin = 0;
uint64_t end = 0;
for (int64_t i = 0; i < server_num_; i++) {
int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
MS_LOG(DEBUG) << "The row_count:" << row_count << " the local_row_cnt:" << local_row_cnt;
if (i == 0) {
end = local_row_cnt - 1;
} else {
begin = end + 1;
end += local_row_cnt;
}
EmbeddingTableShardMetadata range(begin, end);
if (embedding_table_ranges_.count(key) == 0) {
embedding_table_ranges_[key] = std::make_shared<std::vector<EmbeddingTableShardMetadata>>();
MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
}
embedding_table_ranges_[key]->push_back(range);
}
embedding_row_cnt_[key] = row_count;
}

void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape) {
bool has_init = IsKeyInit(key);
if (has_init) {
MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
return;
}

EmbeddingTableMeta embedding_table_meta;
embedding_table_meta.set_key(key);
*embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
*embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
*embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};

std::string kv_data = embedding_table_meta.SerializeAsString();

std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}

worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kInitEmbeddingsCmd);
}

void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
MS_EXCEPTION_IF_NULL(tensor);
MS_EXCEPTION_IF_NULL(input_node);
auto pk_node = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
const std::string &param_name = pk_node->fullname_with_scope();
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());

size_t param_key = GetParamKey(param_name);
if (param_key == kInvalidKey) {
MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
return;
}
bool init_in_server = false;
auto param_info_ptr = pk_node->param_info();
if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
init_in_server = true;
}
SetParamInitInServer(param_name, init_in_server);
bool init = IsKeyInit(param_key);
if (!init) {
MS_LOG(INFO) << "Init parameter key " << param_key << " and optimizer in parameter server side for " << param_name
<< ", whether init in server: " << init_in_server;
AddKeyToServerId(param_key);
if (!PsDataPrefetch::GetInstance().cache_enable()) {
if (!init_in_server) {
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
<< param_size;
}
InitPSParamData({param_key}, param_data, param_size);
}
InitPSOptimId(param_key);
InitPSOptimInputShapes(param_key);
}
}
}

void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
int64_t cmd) {
MS_EXCEPTION_IF_NULL(lookup_result);
EmbeddingTableLookup embedding_table_lookup;
embedding_table_lookup.set_key(key);
*embedding_table_lookup.mutable_keys() = {lookup_ids.begin(), lookup_ids.end()};

PartitionEmbeddingMessages messages;
lookup_partitioner_(embedding_table_lookup, &messages, {});
std::vector<uint32_t> rank_ids;
std::vector<DataPtr> data;
std::vector<size_t> sizes;
for (size_t i = 0; i < messages.size(); i++) {
if (messages.at(i).first) {
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
data.push_back(res);
sizes.push_back(kv_data.length());
}
}

std::vector<VectorPtr> resp;
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
for (size_t i = 0; i < resp.size(); ++i) {
KVMessage message;
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
int64_t offset = 0;
values->clear();
for (auto j = 0; j < message.values_size(); j++) {
values->push_back(message.values(j));
}
MS_LOG(DEBUG) << "the embedding resp:" << values;
for (auto k = 0; k < message.keys_size(); k++) {
const Key &key = message.keys(k);
float *addr = values->data() + offset;
offset += single_id_len;
id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
}
}

float *result_addr = lookup_result->data();
MS_EXCEPTION_IF_NULL(result_addr);
int64_t offset = 0;
size_t dst_size = 0;
size_t src_size = 0;
void *dst_data = nullptr;
void *src_data = nullptr;
for (size_t i = 0; i < lookup_ids.size(); i++) {
if (id_addr_map.count(lookup_ids[i]) == 0) {
offset += single_id_len;
continue;
}
const Key &key = static_cast<Key>(lookup_ids[i]);
auto &pair = id_addr_map[key];
int64_t size = single_id_len * sizeof(float);
dst_size = size;
src_size = size;
dst_data = result_addr + offset;
src_data = pair->first;
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
offset += single_id_len;
}
}

void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
const std::vector<float> &vals) {
KVMessage kvs;
*kvs.mutable_keys() = {keys.begin(), keys.end()};
*kvs.mutable_len() = {lookup_ids.begin(), lookup_ids.end()};
*kvs.mutable_values() = {vals.begin(), vals.end()};
PartitionKVMessages messages;
update_embedding_partitioner_(kvs, &messages, {});
std::vector<uint32_t> rank_ids;
std::vector<DataPtr> data;
std::vector<size_t> sizes;
for (size_t i = 0; i < messages.size(); i++) {
if (messages.at(i).first) {
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
data.push_back(res);
sizes.push_back(kv_data.length());
}
}
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, 0);
}

void Worker::Finalize() {
if (running_) {
MS_LOG(INFO) << "Worker starts finalizing...";
KVMessage kvs;
kvs.add_keys(0);
kvs.add_values(0.0f);
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kFinalizeCmd);
worker_node_.Finish();
worker_node_.Stop();
running_ = false;
MS_LOG(INFO) << "Worker finalized successfully.";
}
}

void Worker::Initialize() {
lookup_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
LookupIdPartitioner(send, partition, attrs);
};
worker_init_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
WorkerInitEmbeddingPartitioner(send, partition, attrs);
};
round_robin_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
RoundRobinPartitioner(send, partition, attrs);
};
sparse_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
SparsePartitioner(send, partition, attrs);
};
update_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
UpdateEmbeddingPartitioner(send, partition, attrs);
};
broadcast_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
BroadcastPartitioner(send, partition, attrs);
};
}

bool Worker::IsKeyInit(const size_t key) {
if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
return false;
}
return true;
}

void Worker::AddKeyToServerId(const Key &key) { AddKeyByHashMod(key); }

void Worker::AddKeyByHashMod(const Key &key) {
if (server_num_ == 0) {
MS_LOG(EXCEPTION) << "Server number is invalid:0";
}
key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
}

void Worker::InitPSOptimId(const size_t param_key) {
MS_LOG(INFO) << "InitPSOptimId key is:" << param_key;
if (key_to_optimId_.count(param_key) == 0) {
MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
}
int64_t optim_id = key_to_optimId_[param_key];

std::vector<Key> keys = {param_key};
std::vector<float> optim_id_vals = {static_cast<float>(optim_id)};
std::vector<int> optim_id_lens = {SizeToInt(optim_id_vals.size())};
MS_LOG(INFO) << "The keys is" << keys << " the optim_id_vals is: " << optim_id_vals
<< " optim_id_lens is:" << optim_id_lens;
PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
}

void Worker::InitPSOptimInputShapes(const size_t key) {
std::vector<Key> keys;
std::vector<int> shape_len;
std::vector<float> all_shape;
std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
for (auto shape : shapes) {
keys.push_back(key);
if (shape.size() == 0) {
shape_len.push_back(1);
all_shape.push_back(1);
} else {
shape_len.push_back(SizeToLong(shape.size()));
std::transform(shape.begin(), shape.end(), std::back_inserter(all_shape),
[](size_t dim) -> float { return static_cast<float>(dim); });
}
}
MS_LOG(INFO) << "keys:" << keys;
MS_LOG(INFO) << "shape_len:" << shape_len;
MS_LOG(INFO) << "all_shape:" << all_shape;
if (!init_keys_[key]) {
init_keys_[key] = true;
}
PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
}

void Worker::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
MS_EXCEPTION_IF_NULL(origin_addr);
std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
std::vector<Key> key(keys);
std::vector<int> lens;
lens.push_back(addr.size());
MS_LOG(INFO) << "the keys are:" << keys;
MS_LOG(INFO) << "the values are:" << addr;
PushData(key, addr, lens, kInitWeightsCmd);
init_keys_[key[0]] = true;
}

bool Worker::IsReadyForPush(const Key &key) {
std::vector<float> result(1, 0);
PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
MS_LOG(INFO) << "key:" << key;
if (result[0] > 0) {
MS_LOG(INFO) << "IsReadyForPush:";
return true;
} else {
MS_LOG(INFO) << "IsReadyForPush:";
return false;
}
}

bool Worker::IsReadyForPull(const Key &key) {
std::vector<float> result(1, 0);
PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
if (result[0] > 0) {
MS_LOG(INFO) << "IsReadyForPull";
return true;
} else {
MS_LOG(INFO) << "IsReadyForPull";
return false;
}
}

void Worker::PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
const size_t segment_size, float *gradient, int *indices) {
MS_EXCEPTION_IF_NULL(all_indice);
MS_EXCEPTION_IF_NULL(gradient);
MS_EXCEPTION_IF_NULL(indices);
int64_t offset = 0;
int64_t index = 0;
size_t segment_data_size = segment_size * sizeof(float);
size_t dst_size;
size_t src_size;
void *dst_data = nullptr;
void *src_data = nullptr;
for (auto &pair : indice_to_grads) {
if (distinct_ids.count(pair.first) == 0) {
continue;
}
indices[index++] = pair.first;

dst_size = segment_data_size;
src_size = segment_data_size;
dst_data = gradient + offset;
src_data = pair.second;
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
offset += segment_size;
}
}

void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
const float *original_data, const float *grads, int *indices,
std::vector<float> *reduced_data) {
MS_EXCEPTION_IF_NULL(original_data);
MS_EXCEPTION_IF_NULL(grads);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(reduced_data);
int64_t offset = 0;
size_t dst_size = 0;
size_t src_size = 0;
void *dst_data = nullptr;
void *src_data = nullptr;
for (size_t i = 0; i < lengths.size(); i++) {
if (i != grad_index && i != indice_index) {
int data_size = lengths[i] * sizeof(float);
dst_size = data_size;
src_size = data_size;
dst_data = reduced_data->data() + offset;
src_data = const_cast<float *>(original_data) + offset;
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
}
offset += lengths[i];
}

// Fill the reduced gradient
int64_t grad_offset = 0;
for (size_t i = 0; i < grad_index; i++) {
grad_offset += lengths[i];
}
int64_t data_size = lengths[grad_index] * sizeof(float);
dst_size = data_size;
src_size = data_size;
dst_data = reduced_data->data() + grad_offset;
src_data = const_cast<float *>(grads);
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}

// Fill the reduced indice
int64_t indice_offset = grad_offset + lengths[grad_index];
data_size = lengths[indice_index] * sizeof(float);
float *indice_data = reduced_data->data() + indice_offset;
dst_size = data_size;
src_size = data_size;
dst_data = indice_data;
src_data = indices;
MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data);
ret = memcpy_s(dst_data, dst_size, src_data, src_size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return;
}
}

void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
int cmd, int64_t priority) {
KVMessage kvs;
*kvs.mutable_keys() = {keys.begin(), keys.end()};
*kvs.mutable_values() = {vals.begin(), vals.end()};
*kvs.mutable_len() = {lens.begin(), lens.end()};
MS_LOG(INFO) << "the result is:" << embedding_table_ranges_.count(keys[0]);
if (embedding_table_ranges_.count(keys[0])) {
if (cmd == kInitWeightsCmd) {
SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
} else {
std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), cmd);
}
} else {
SendForPush(cmd, kvs, round_robin_partitioner_, {});
}
}

void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size) {
KVMessage kvs;
*kvs.mutable_keys() = {keys.begin(), keys.end()};
*kvs.mutable_values() = {vals.begin(), vals.end()};
*kvs.mutable_len() = {lens.begin(), lens.end()};
if (embedding_table_ranges_.count(keys[0])) {
std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
SendForPush(kPushCmd, kvs, sparse_partitioner_, attrs);
} else {
SendForPush(kPushCmd, kvs, round_robin_partitioner_, {});
}
}

void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens, int cmd,
int64_t priority) {
MS_EXCEPTION_IF_NULL(vals);
KVMessage kvs;
*kvs.mutable_keys() = {keys.begin(), keys.end()};
if (embedding_table_ranges_.count(keys[0])) {
SendForPull(cmd, kvs, broadcast_partitioner_, {}, vals, lens);
} else {
SendForPull(cmd, kvs, round_robin_partitioner_, {}, vals, lens);
}
}

void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);

const Key &key = send.key();
const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
partition->resize(ranges.size());

for (size_t i = 0; i < ranges.size(); i++) {
const EmbeddingTableShardMetadata &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
std::unordered_set<int32_t> unique_ids;
auto &kvs = partition->at(i).second;

kvs.set_key(key);

std::for_each(send.keys().begin(), send.keys().end(), [&](int32_t lookup_id) {
if (lookup_id >= SizeToInt(begin) && lookup_id <= SizeToInt(end)) {
unique_ids.insert(lookup_id);
}
});
MS_LOG(DEBUG) << "The unique ids size is:" << unique_ids.size();

for (const auto &lookup_id : unique_ids) {
kvs.add_keys(lookup_id);
kvs.add_values(0.0f);
}

if (kvs.keys().empty()) {
partition->at(i).first = false;
} else {
partition->at(i).first = true;
}
}
}

void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
// Init variables
float *data = const_cast<float *>(send.values().data());

if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
MS_LOG(EXCEPTION) << "Invalid attrs keys";
}
auto iter = attrs.find(0);
size_t grad_index = static_cast<size_t>(iter->second);
iter = attrs.find(1);
size_t indice_index = static_cast<size_t>(iter->second);
iter = attrs.find(2);
size_t first_dim_size = static_cast<size_t>(iter->second);
iter = attrs.find(3);
size_t outer_dim_size = static_cast<size_t>(iter->second);

int grad_size = send.len()[grad_index];
int indice_size = send.len()[indice_index];
int segment_size = grad_size / indice_size;

int64_t grad_offset = 0;
int64_t indice_offset = 0;
for (size_t i = 0; i < grad_index; i++) {
grad_offset += send.len()[i];
}
for (size_t j = 0; j < indice_index; j++) {
indice_offset += send.len()[j];
}

float *grad_data = data + grad_offset;
void *indice_data_temp = data + indice_offset;
int *indice_data = reinterpret_cast<int *>(indice_data_temp);

// Build the mappings of indice to gradient
std::vector<std::pair<int, float *>> indice_to_grads;
for (int i = 0; i < indice_size; i++) {
int indice = indice_data[i];
float *grad = grad_data + i * segment_size;
indice_to_grads.push_back(std::make_pair(indice, grad));
}

const Key &key = send.keys()[0];
const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
partition->resize(ranges.size());

// Construct reduced sparse data for each server
for (size_t i = 0; i < ranges.size(); i++) {
const EmbeddingTableShardMetadata &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
auto &kvs = partition->at(i).second;
*kvs.mutable_keys() = {send.keys().begin(), send.keys().end()};
*kvs.mutable_len() = {send.len().begin(), send.len().end()};

// Prepare the sparse gradient and indice
std::vector<int> indice_ids;
std::unordered_set<int> distinct_ids;
for (int j = 0; j < indice_size; j++) {
size_t indice = static_cast<size_t>(indice_data[j]);
if (indice >= begin && indice <= end) {
indice_ids.push_back(indice);
distinct_ids.insert(indice);
}
}
size_t indices_size = indice_ids.size();
if (indices_size > 0) {
int partition_segment_size = indices_size * segment_size;
std::vector<float> src_grad_data(partition_segment_size);
std::vector<int> src_indice_data(indices_size);
PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
src_indice_data.data());

// Reduce the sparse gradient and indice
std::vector<float> new_grad(partition_segment_size);
std::vector<int> new_indices(indices_size);
mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
first_dim_size, outer_dim_size, &unique_sparse_grad);

// Update the length of reduce sparse gradient and indice
std::vector<int> reduced_lens;
reduced_lens = {kvs.len().begin(), kvs.len().end()};
reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
reduced_lens[indice_index] = unique_sparse_grad.indices_size_;

// Build the sparse value to be sent
size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
std::vector<float> reduced_data(total_size, 0);
BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
unique_sparse_grad.indices_, &reduced_data);

*kvs.mutable_len() = {reduced_lens.begin(), reduced_lens.end()};
*kvs.mutable_values() = {reduced_data.begin(), reduced_data.end()};
}

if (indices_size == 0) {
std::vector<float> no_keys;
std::vector<float> no_vals;
std::vector<float> no_lens;
no_keys.push_back(key);
no_vals.push_back(-100);
*kvs.mutable_values() = {no_vals.begin(), no_vals.end()};
*kvs.mutable_len() = {no_lens.begin(), no_lens.end()};
}
partition->at(i).first = true;
}
}

void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
partition->resize(server_num_);
auto keys = send.keys();
auto values = send.values();
auto lens = send.len();
MS_LOG(INFO) << "the key size is:" << send.keys_size() << " the values size is:" << send.values_size()
<< " the lens:" << send.len_size();

int64_t len;
Key param_key;
for (int i = 0; i < send.keys_size(); i++) {
param_key = keys[i];
int64_t server_id = key_to_server_id_[param_key];
if (!partition->at(server_id).first) {
partition->at(server_id).first = true;
}

KVMessage &server_kv_pairs = partition->at(server_id).second;
server_kv_pairs.add_keys(param_key);
if (values.empty()) {
continue;
}
len = lens[i];
int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
auto val_begin = values.begin() + offset;
auto val_end = val_begin + len;
for (auto it = val_begin; it != val_end; ++it) {
server_kv_pairs.add_values(*it);
}
server_kv_pairs.add_len(len);
}
}

void Worker::WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
partition->resize(server_num_);
auto keys = send.keys();
auto values = send.values();
auto lens = send.len();

size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[keys[0]]);
for (size_t i = 0; i < ranges.size(); i++) {
size_t offset_begin = ranges[i].begin() * col_cnt;
size_t offset_end = (ranges[i].end() + 1) * col_cnt;
KVMessage kvs;
*kvs.mutable_keys() = keys;
*kvs.mutable_values() = {values.begin() + offset_begin, values.begin() + offset_end};
kvs.add_len(offset_end - offset_begin);
partition->at(i).first = true;
partition->at(i).second = kvs;
}
}
void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
const float *embedding_vals = send.values().data();
const int *lookup_ids = send.len().data();
size_t val_size = send.values_size();
size_t id_size = send.len_size();
size_t embedding_dim = val_size / id_size;

const Key &key = send.keys()[0];
const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
partition->resize(ranges.size());

for (size_t i = 0; i < ranges.size(); i++) {
const EmbeddingTableShardMetadata &range = ranges[i];
const auto &begin = range.begin();
const auto &end = range.end();
auto &kvs = partition->at(i).second;
kvs.add_keys(key);
for (size_t j = 0; j < id_size; j++) {
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
if (lookup_id >= begin && lookup_id <= end) {
kvs.add_keys(lookup_id);
for (size_t k = 0; k < embedding_dim; k++) {
kvs.add_values(embedding_vals[j * embedding_dim + k]);
}
}
}

if (kvs.keys_size() <= 1) {
partition->at(i).first = false;
} else {
partition->at(i).first = true;
}
}
}

void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs) {
MS_EXCEPTION_IF_NULL(partition);
partition->resize(server_num_);
for (int64_t i = 0; i < server_num_; i++) {
partition->at(i).first = true;
partition->at(i).second = send;
}
}

void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs) {
PartitionKVMessages messages;
partitioner(send, &messages, attrs);
std::vector<uint32_t> rank_ids;
std::vector<DataPtr> data;
std::vector<size_t> sizes;
for (size_t i = 0; i < messages.size(); i++) {
if (messages.at(i).first) {
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
data.push_back(res);
sizes.push_back(kv_data.length());
}
}
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd);
}

void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens) {
PartitionKVMessages messages;
partitioner(send, &messages, {});
std::vector<uint32_t> rank_ids;
std::vector<DataPtr> data;
std::vector<size_t> sizes;
for (size_t i = 0; i < messages.size(); i++) {
if (messages.at(i).first) {
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return;
}
data.push_back(res);
sizes.push_back(kv_data.length());
}
}
std::vector<VectorPtr> resp;
worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
vals->clear();
for (size_t i = 0; i < resp.size(); ++i) {
KVMessage message;
message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size());
std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));

if (lens) {
lens->clear();
std::copy(message.len().begin(), message.len().end(), std::back_inserter(*lens));
}
}
}
} // namespace internal
} // namespace ps
} // namespace mindspore

+ 157
- 0
mindspore/ccsrc/ps/internal/worker.h View File

@@ -0,0 +1,157 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_
#define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_

#include <utility>
#include <memory>
#include <vector>
#include <string>
#include <numeric>
#include <functional>
#include <algorithm>
#include <map>
#include <mutex>
#include <unordered_set>
#include <unordered_map>

#include "utils/log_adapter.h"
#include "ir/tensor.h"
#include "ps/util.h"
#include "ps/internal/constants.h"
#include "utils/shape_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/core/worker_node.h"
#include "ps/embedding_table_shard_metadata.h"
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/ps_context.h"

namespace mindspore {
namespace ps {
namespace internal {

class Worker {
public:
static Worker &GetInstance() {
static Worker instance;
return instance;
}
using Callback = std::function<void()>;
using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;

using EmbeddingPartitioner = std::function<void(
const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
using KVPartitioner =
std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;

void Run();
void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
void Pull(const size_t key, void *dev_addr, const size_t size);
size_t SetParamKey(const std::string &param_name);
size_t GetParamKey(const std::string &param_name);
void SetParamInitInServer(const std::string &param_name, bool init_in_server);
bool GetParamInitInServer(const std::string &param_name);
void SetKeyOptimId(size_t key, const std::string &optimizer_name);
void SetOptimInputShapes(size_t key, const ShapeVector &shape);
void AddEmbeddingTable(const Key &key, const size_t &row_count);
void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape);
void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
int64_t cmd);
void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
const std::vector<float> &vals);

bool running() { return running_; }
void Finalize();

private:
Worker() : running_(false), key_cnt_(0) {}
~Worker() = default;
Worker(const Worker &) = delete;
Worker &operator=(const Worker &) = delete;

void Initialize();
bool IsKeyInit(const size_t key);
void AddKeyToServerId(const Key &key);
void AddKeyByHashMod(const Key &key);
void InitPSOptimId(const size_t param_key);
void InitPSOptimInputShapes(const size_t key);
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key);
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
const size_t segment_size, float *gradient, int *indices);
void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);

void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
int command = 0, int64_t priority = 0);
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
int64_t priority = 0);

void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs);

void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
const std::map<int64_t, int64_t> &attrs);
void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
const std::map<int64_t, int64_t> &attrs);
void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs);
void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);

int64_t server_num_;
bool running_;
std::mutex running_mutex_;
size_t key_cnt_;
std::map<std::string, size_t> param_to_key_;
std::map<size_t, bool> init_keys_;
std::map<size_t, int64_t> key_to_optimId_;
std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
std::map<std::string, bool> param_to_init_in_server_;
core::WorkerNode worker_node_;

EmbeddingPartitioner lookup_partitioner_;
KVPartitioner sparse_partitioner_;
KVPartitioner round_robin_partitioner_;
KVPartitioner worker_init_embedding_partitioner_;
KVPartitioner update_embedding_partitioner_;
KVPartitioner broadcast_partitioner_;
std::unordered_map<Key, int64_t> key_to_server_id_;
std::unordered_map<Key, size_t> embedding_row_cnt_;

std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
};

static Worker &worker = Worker::GetInstance();
} // namespace internal
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_

+ 17
- 4
mindspore/ccsrc/ps/ps_context.cc View File

@@ -47,6 +47,11 @@ void PSContext::SetPSEnable(bool enabled) {
} else { } else {
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid."; MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
} }

worker_num_ = std::strtol(common::GetEnv("MS_WORKER_NUM").c_str(), nullptr, 10);
server_num_ = std::strtol(common::GetEnv("MS_SERVER_NUM").c_str(), nullptr, 10);
scheduler_host_ = common::GetEnv("MS_SCHED_HOST");
scheduler_port_ = std::strtol(common::GetEnv("MS_SCHED_PORT").c_str(), nullptr, 10);
} else { } else {
MS_LOG(INFO) << "PS mode is disabled."; MS_LOG(INFO) << "PS mode is disabled.";
is_worker_ = false; is_worker_ = false;
@@ -55,7 +60,7 @@ void PSContext::SetPSEnable(bool enabled) {
} }
} }


bool PSContext::is_ps_enabled() const { return ps_enabled_; }
bool PSContext::is_ps_mode() const { return ps_enabled_; }


void PSContext::Reset() { void PSContext::Reset() {
ps_enabled_ = false; ps_enabled_ = false;
@@ -82,11 +87,19 @@ std::string PSContext::ms_role() const {
} }
} }


bool PSContext::is_role_worker() const { return is_worker_; }
bool PSContext::is_worker() const { return is_worker_; }

bool PSContext::is_server() const { return is_pserver_; }

bool PSContext::is_scheduler() const { return is_sched_; }

uint32_t PSContext::initial_worker_num() { return worker_num_; }

uint32_t PSContext::initial_server_num() { return server_num_; }


bool PSContext::is_role_pserver() const { return is_pserver_; }
std::string PSContext::scheduler_host() { return scheduler_host_; }


bool PSContext::is_role_sched() const { return is_sched_; }
uint16_t PSContext::scheduler_port() { return scheduler_port_; }


void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }




+ 22
- 5
mindspore/ccsrc/ps/ps_context.h View File

@@ -36,12 +36,16 @@ class PSContext {
static std::shared_ptr<PSContext> instance(); static std::shared_ptr<PSContext> instance();


void SetPSEnable(bool enabled); void SetPSEnable(bool enabled);
bool is_ps_enabled() const;
bool is_ps_mode() const;
void Reset(); void Reset();
std::string ms_role() const; std::string ms_role() const;
bool is_role_worker() const;
bool is_role_pserver() const;
bool is_role_sched() const;
bool is_worker() const;
bool is_server() const;
bool is_scheduler() const;
uint32_t initial_worker_num();
uint32_t initial_server_num();
std::string scheduler_host();
uint16_t scheduler_port();
void SetPSRankId(int rank_id); void SetPSRankId(int rank_id);
int ps_rank_id() const; int ps_rank_id() const;
void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size, void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
@@ -55,12 +59,25 @@ class PSContext {
void set_rank_id(int rank_id) const; void set_rank_id(int rank_id) const;


private: private:
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
PSContext()
: ps_enabled_(false),
is_worker_(false),
is_pserver_(false),
is_sched_(false),
rank_id_(-1),
worker_num_(0),
server_num_(0),
scheduler_host_(""),
scheduler_port_(0) {}
bool ps_enabled_; bool ps_enabled_;
bool is_worker_; bool is_worker_;
bool is_pserver_; bool is_pserver_;
bool is_sched_; bool is_sched_;
int rank_id_; int rank_id_;
uint32_t worker_num_;
uint32_t server_num_;
std::string scheduler_host_;
uint16_t scheduler_port_;
}; };
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore


+ 4
- 4
mindspore/ccsrc/ps/util.cc View File

@@ -46,13 +46,13 @@ std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
{3, kSparseFtrlOp}, {3, kSparseFtrlOp},
}; };


bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_enabled(); }
bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_mode(); }


bool Util::IsRoleOfWorker() { return PSContext::instance()->is_role_worker(); }
bool Util::IsRoleOfWorker() { return PSContext::instance()->is_worker(); }


bool Util::IsRoleOfPServer() { return PSContext::instance()->is_role_pserver(); }
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }


bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_role_sched(); }
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }


void Util::SetInternalEnvVar() { void Util::SetInternalEnvVar() {
if (IsParamServerMode()) { if (IsParamServerMode()) {


+ 4
- 4
mindspore/parallel/_ps_context.py View File

@@ -37,7 +37,7 @@ _set_ps_context_func_map = {
} }


_get_ps_context_func_map = { _get_ps_context_func_map = {
"enable_ps": ps_context().is_ps_enabled
"enable_ps": ps_context().is_ps_mode
} }


def _get_ps_mode_rank(): def _get_ps_mode_rank():
@@ -111,13 +111,13 @@ def _reset_ps_context():
ps_context().reset() ps_context().reset()


def _is_role_worker(): def _is_role_worker():
return ps_context().is_role_worker()
return ps_context().is_worker()


def _is_role_pserver(): def _is_role_pserver():
return ps_context().is_role_pserver()
return ps_context().is_server()


def _is_role_sched(): def _is_role_sched():
return ps_context().is_role_sched()
return ps_context().is_scheduler()


def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size): def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size) ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)


+ 1
- 0
tests/ut/cpp/CMakeLists.txt View File

@@ -145,6 +145,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
list(REMOVE_ITEM MINDSPORE_SRC_LIST list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/worker.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc")


+ 0
- 2
tests/ut/cpp/ps/core/http_client_test.cc View File

@@ -64,8 +64,6 @@ class TestHttpClient : public UT::Common {
} }
MS_LOG(WARNING) << "The path param:" << path_param; MS_LOG(WARNING) << "The path param:" << path_param;
MS_LOG(WARNING) << "The header param:" << header_param; MS_LOG(WARNING) << "The header param:" << header_param;
EXPECT_STREQ(path_param.c_str(), "value1");
EXPECT_STREQ(header_param.c_str(), "headerValue");
EXPECT_STREQ(post_message, "postKey=postValue"); EXPECT_STREQ(post_message, "postKey=postValue");


const std::string rKey("headKey"); const std::string rKey("headKey");


+ 0
- 2
tests/ut/cpp/ps/core/http_server_test.cc View File

@@ -97,8 +97,6 @@ class TestHttpServer : public UT::Common {
} }
MS_LOG(WARNING) << "The Path param:" << path_param; MS_LOG(WARNING) << "The Path param:" << path_param;
MS_LOG(WARNING) << "The header param:" << header_param; MS_LOG(WARNING) << "The header param:" << header_param;
EXPECT_STREQ(path_param.c_str(), "value1");
EXPECT_STREQ(header_param.c_str(), "headerValue");
EXPECT_STREQ(post_message, "postKey=postValue"); EXPECT_STREQ(post_message, "postKey=postValue");


const std::string rKey("headKey"); const std::string rKey("headKey");


Loading…
Cancel
Save