
!12929 added event callback

From: @anancds
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
Committed by: mindspore-ci-bot
Commit: 37560318ef
11 changed files with 68 additions and 34 deletions:

  1. mindspore/ccsrc/ps/constants.h (+1, -1)
  2. mindspore/ccsrc/ps/core/abstract_node.cc (+8, -5)
  3. mindspore/ccsrc/ps/core/abstract_node.h (+1, -1)
  4. mindspore/ccsrc/ps/core/cluster_metadata.h (+1, -1)
  5. mindspore/ccsrc/ps/core/node_info.h (+1, -1)
  6. mindspore/ccsrc/ps/core/server_node.cc (+7, -3)
  7. mindspore/ccsrc/ps/core/tcp_client.cc (+2, -1)
  8. mindspore/ccsrc/ps/core/worker_node.cc (+6, -2)
  9. mindspore/ccsrc/ps/parameter_server.cc (+17, -7)
  10. mindspore/ccsrc/ps/parameter_server.h (+4, -2)
  11. mindspore/ccsrc/ps/worker.cc (+20, -10)

mindspore/ccsrc/ps/constants.h (+1, -1)

@@ -67,7 +67,7 @@ constexpr int64_t kPullCmd = 51;
constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1;

- using DataPtr = std::shared_ptr<unsigned char>;
+ using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;
using Keys = std::vector<Key>;
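
Why the DataPtr change matters: std::shared_ptr<unsigned char[]> (C++17) releases its buffer with delete[], while the old std::shared_ptr<unsigned char> form calls plain delete on memory obtained from new unsigned char[n] unless a custom deleter is supplied. A minimal standalone sketch of the difference (illustrative only, not code from this patch):

    #include <memory>

    int main() {
      // Pre-C++17 style: an array allocation needs an explicit array deleter,
      // otherwise plain delete is used and the behaviour is undefined.
      std::shared_ptr<unsigned char> old_style(new unsigned char[16],
                                               std::default_delete<unsigned char[]>());
      // C++17 array form, matching the new DataPtr alias: delete[] is used automatically.
      std::shared_ptr<unsigned char[]> new_style(new unsigned char[16]);
      new_style[0] = 0;  // operator[] is only provided by the array specialization
      return 0;
    }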


mindspore/ccsrc/ps/core/abstract_node.cc (+8, -5)

@@ -281,7 +281,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
if (!Heartbeat(client)) {
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
- if (!CheckSchedulerTimeout() && on_node_event_message_) {
+ if (CheckSchedulerTimeout() && on_node_event_message_) {
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
is_finish_ = true;
@@ -294,6 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval()));
}
});
+ heart_beat_thread_->detach();
}

bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) {
@@ -307,6 +308,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong())) {
MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
+ return false;
}
return true;
}
@@ -315,9 +317,7 @@ void AbstractNode::UpdateSchedulerTime() {
struct timeval current_time {};
(void)gettimeofday(&current_time, nullptr);
scheduler_time_ = current_time;
- MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_)
- << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_
- << " update scheduler time, the current time is: " << current_time.tv_sec;
+ MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec;
}

bool AbstractNode::CheckSchedulerTimeout() const {
@@ -430,10 +430,13 @@ bool AbstractNode::InitClientToScheduler() {
MS_LOG(INFO) << "The node start a tcp client!";
client_to_scheduler_->Start();
});
+ client_to_scheduler_thread_->detach();

client_to_scheduler_->set_disconnected_callback([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval()));
- client_to_scheduler_->Init();
+ if (is_ready_.load() == false) {
+   client_to_scheduler_->Init();
+ }
});
return client_to_scheduler_->WaitConnected();
}
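
Taken together, the abstract_node.cc hunks fix the inverted scheduler-timeout check (the timeout event should fire when CheckSchedulerTimeout() returns true), detach the heartbeat and scheduler-client threads (which is why the Stop() methods below now check joinable()), and reconnect in the disconnected callback only while the node is not yet ready. A rough sketch of that reconnect guard, using a hypothetical FakeClient type in place of the real TcpClient:

    #include <atomic>
    #include <chrono>
    #include <functional>
    #include <thread>

    // Hypothetical stand-in for TcpClient: only the pieces needed to show the guard.
    struct FakeClient {
      std::function<void()> on_disconnected;
      void Init() { /* re-establish the connection here */ }
    };

    int main() {
      std::atomic<bool> is_ready{false};
      FakeClient client;
      // Mirrors the patched callback: retry only while the node has not finished
      // start-up, so a deliberate shutdown does not keep re-initializing the client.
      client.on_disconnected = [&]() {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        if (!is_ready.load()) {
          client.Init();
        }
      };
      if (client.on_disconnected) {
        client.on_disconnected();  // simulate one disconnect notification
      }
      return 0;
    }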


mindspore/ccsrc/ps/core/abstract_node.h (+1, -1)

@@ -37,7 +37,7 @@ class AbstractNode : public Node {

typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size);

- using DataPtr = std::shared_ptr<unsigned char>;
+ using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;

bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command,


mindspore/ccsrc/ps/core/cluster_metadata.h (+1, -1)

@@ -62,7 +62,7 @@ class ClusterMetadata {
heartbeat_timeout_(30),
cluster_available_timeout_(300),
connect_interval_(100),
- scheduler_timeout_(3600 * 5) {}
+ scheduler_timeout_(30) {}
uint32_t worker_num_;
uint32_t server_num_;
// The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds.


mindspore/ccsrc/ps/core/node_info.h (+1, -1)

@@ -25,7 +25,7 @@
namespace mindspore {
namespace ps {
namespace core {
- enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT };
+ enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };

struct NodeInfo {
NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {}


mindspore/ccsrc/ps/core/server_node.cc (+7, -3)

@@ -105,7 +105,7 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
- std::shared_ptr<unsigned char> res(new unsigned char[size]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
int ret = memcpy_s(res.get(), size, data, size);
if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
@@ -131,14 +131,18 @@ bool ServerNode::Stop() {
if (!is_already_stopped_.load()) {
is_already_stopped_ = true;
is_finish_ = true;
- heart_beat_thread_->join();
+ if (heart_beat_thread_->joinable()) {
+   heart_beat_thread_->join();
+ }
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
- client_to_scheduler_thread_->join();
+ if (client_to_scheduler_thread_->joinable()) {
+   client_to_scheduler_thread_->join();
+ }
server_->Stop();
server_thread_->join();
}
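
The joinable() guards above are needed because the heartbeat and scheduler-client threads may now be detached during start-up, and calling join() on a non-joinable std::thread throws std::system_error. A minimal standalone illustration of the guard (not code from this patch):

    #include <thread>

    int main() {
      std::thread worker([] { /* background work */ });
      worker.detach();          // after detach(), joinable() returns false
      if (worker.joinable()) {  // same guard as the Stop() changes above
        worker.join();          // not reached for a detached thread; safe either way
      }
      return 0;
    }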


mindspore/ccsrc/ps/core/tcp_client.cc (+2, -1)

@@ -311,7 +311,8 @@ bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &pro
}
int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH);
if (result < 0) {
- MS_LOG(EXCEPTION) << "Bufferevent flush failed!";
+ MS_LOG(ERROR) << "Bufferevent flush failed!";
+ res = false;
}
bufferevent_unlock(buffer_event_);
return res;


mindspore/ccsrc/ps/core/worker_node.cc (+6, -2)

@@ -63,14 +63,18 @@ bool WorkerNode::Stop() {
is_ready_ = true;
is_timeout_ = true;
is_finish_ = true;
- heart_beat_thread_->join();
+ if (heart_beat_thread_->joinable()) {
+   heart_beat_thread_->join();
+ }
client_to_scheduler_->Stop();
if (!connected_nodes_.empty()) {
for (auto &connected_node : connected_nodes_) {
connected_node.second->Stop();
}
}
- client_to_scheduler_thread_->join();
+ if (client_to_scheduler_thread_->joinable()) {
+   client_to_scheduler_thread_->join();
+ }
is_already_stopped_ = true;
}
return true;


mindspore/ccsrc/ps/parameter_server.cc (+17, -7)

@@ -21,6 +21,8 @@ namespace ps {
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
+ server_node_ = std::make_shared<core::ServerNode>();

core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
@@ -30,14 +32,14 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
return;
}
Init(func_graph);
- server_node_.Start();
- rank_id_ = server_node_.rank_id();
+ server_node_->Start();
+ rank_id_ = server_node_->rank_id();
PSContext::instance()->SetPSRankId(rank_id_);
thread_->join();
SyncEmbeddingTables();
MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
- server_node_.Finish();
- server_node_.Stop();
+ server_node_->Finish();
+ server_node_->Stop();
MS_LOG(INFO) << "PServer finalized successfully.";
}

@@ -49,7 +51,14 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
handler_->Init();

InitOptimInfoBuilders();
- server_node_.set_handler(*handler_);
+ server_node_->set_handler(*handler_);
+ server_node_->set_event_callback([&](const core::NodeEvent &event) {
+   if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
+       (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
+     MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
+     Finalize();
+   }
+ });
thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
GetEmbeddingTableParamPtr();
return true;
@@ -496,7 +505,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect

auto &handler_ptr = handlers_[meta->user_cmd()];
(this->*handler_ptr)(data, size, output);
- std::shared_ptr<unsigned char> res(new unsigned char[output->size()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[output->size()]);
MS_LOG(DEBUG) << "The output size is:" << output->size();
if (output->size() > 0) {
int ret = memcpy_s(res.get(), output->size(), output->data(), output->size());
@@ -505,7 +514,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect
}
}

- ps_->server_node_.Response(conn, meta, res, output->size());
+ ps_->server_node_->Response(conn, meta, res, output->size());
MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:"
<< std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
.time_since_epoch()
@@ -682,6 +691,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
*res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};

ps_->DoEmbeddingLookup(key, keys, &res_data);

res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
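
The new set_event_callback hook lets the parameter server shut itself down via Finalize() when a cluster, node, or scheduler timeout is reported. A simplified sketch of that callback-registration pattern, using a hypothetical DemoNode class in place of the real core::ServerNode:

    #include <functional>
    #include <iostream>
    #include <utility>

    // Hypothetical enum mirroring ps/core/node_info.h, for a self-contained example.
    enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 };

    // Hypothetical stand-in for core::ServerNode, reduced to the callback hook.
    class DemoNode {
     public:
      void set_event_callback(std::function<void(const NodeEvent &)> cb) { event_cb_ = std::move(cb); }
      void SimulateTimeout(NodeEvent e) {
        if (event_cb_) event_cb_(e);
      }

     private:
      std::function<void(const NodeEvent &)> event_cb_;
    };

    int main() {
      DemoNode node;
      node.set_event_callback([](const NodeEvent &event) {
        if (event == CLUSTER_TIMEOUT || event == SCHEDULER_TIMEOUT || event == NODE_TIMEOUT) {
          std::cout << "Trigger timeout event: " << event << ", begin to exit the system!\n";
          // The real patch calls ParameterServer::Finalize() here.
        }
      });
      node.SimulateTimeout(SCHEDULER_TIMEOUT);
      return 0;
    }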


mindspore/ccsrc/ps/parameter_server.h (+4, -2)

@@ -59,6 +59,7 @@
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
#include "ps/core/server_node.h"
+ #include "ps/core/node.h"

namespace mindspore {
namespace ps {
@@ -82,7 +83,8 @@ class ParameterServer {
func_graph_(nullptr),
sess_(nullptr),
running_(true),
- thread_(nullptr) {}
+ thread_(nullptr),
+ server_node_(nullptr) {}
~ParameterServer() = default;
ParameterServer(const ParameterServer &) = delete;
ParameterServer &operator=(const ParameterServer &) = delete;
@@ -167,7 +169,7 @@ class ParameterServer {
std::condition_variable apply_grads_cv_;

std::unique_ptr<std::thread> thread_;
- core::ServerNode server_node_;
+ std::shared_ptr<core::ServerNode> server_node_;
std::map<Key, ParameterPtr> embedding_tables_;

friend class ServerHandler;
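
Changing server_node_ from a core::ServerNode value member to a std::shared_ptr<core::ServerNode> (initialized to nullptr here and created in Run(), see parameter_server.cc above) defers construction of the node until the cluster metadata has been initialized. A generic sketch of this deferred-construction pattern with placeholder Node/Server types, not the real classes:

    #include <memory>

    // Placeholder types; the real code holds core::ServerNode inside ParameterServer.
    struct Node {
      explicit Node(int port) : port_(port) {}
      int port_;
    };

    class Server {
     public:
      Server() : node_(nullptr) {}  // starts empty, like server_node_(nullptr) above
      void Run(int port) {
        // Construct only once the configuration (cluster metadata) is known.
        node_ = std::make_shared<Node>(port);
      }

     private:
      std::shared_ptr<Node> node_;
    };

    int main() {
      Server s;
      s.Run(8080);
      return 0;
    }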


mindspore/ccsrc/ps/worker.cc (+20, -10)

@@ -15,11 +15,13 @@
*/

#include "ps/worker.h"
#include "pipeline/jit/pipeline.h"

namespace mindspore {
namespace ps {
void Worker::Run() {
std::lock_guard<std::mutex> lock(running_mutex_);

core::ClusterMetadata::instance()->Init(
PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(),
PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port());
@@ -33,6 +35,14 @@ void Worker::Run() {
}

Initialize();
+ worker_node_.set_event_callback([&](const core::NodeEvent &event) {
+   if ((event == core::NodeEvent::CLUSTER_TIMEOUT) ||
+       (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) {
+     MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
+     Finalize();
+     exit(0);
+   }
+ });
MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
worker_node_.Start();
MS_LOG(INFO) << "Worker connected successfully.";
@@ -86,7 +96,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
}
MS_LOG(INFO) << "The total size is:" << total_size;

- while (!IsReadyForPush(keys[0])) {
+ while (running_ && (!IsReadyForPush(keys[0]))) {
continue;
}
std::vector<int> sizes_int;
@@ -109,7 +119,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
MS_EXCEPTION_IF_NULL(dev_addr);
std::vector<float> variables(size / sizeof(float), 0);
- while (!IsReadyForPull(key)) {
+ while (running_ && (!IsReadyForPull(key))) {
continue;
}
PullData({key}, &variables, nullptr, kPullCmd);
@@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &

std::string kv_data = embedding_table_meta.SerializeAsString();

- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
for (auto j = 0; j < message.values_size(); j++) {
values->push_back(message.values(j));
}
- MS_LOG(DEBUG) << "The embedding resp:" << values;
+ MS_LOG(DEBUG) << "The embedding resp:" << *values;
for (auto k = 0; k < message.keys_size(); k++) {
const Key &key = message.keys(k);
float *addr = values->data() + value_offset;
@@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -378,7 +388,7 @@ void Worker::Finalize() {
kvs.add_keys(0);
kvs.add_values(0.0f);
std::string kv_data = kvs.SerializeAsString();
- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -619,7 +629,7 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va
SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
} else {
std::string kv_data = kvs.SerializeAsString();
- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
@@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
rank_ids.push_back(i);
std::string kv_data = messages.at(i).second.SerializeAsString();

- std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]);
+ std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
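
The while (running_ && ...) guards keep the Push/Pull wait loops from spinning forever once the new event callback calls Finalize(), which is expected to clear running_, so a timeout can actually bring the worker down. A small standalone sketch of such a cancellable busy-wait, using a plain std::atomic<bool> instead of the Worker members:

    #include <atomic>
    #include <thread>

    int main() {
      std::atomic<bool> running{true};
      std::atomic<bool> ready{false};

      // Mirrors the patched Push/Pull loops: spin only while the worker is running,
      // so clearing `running` (as Finalize() would) breaks the wait.
      std::thread waiter([&] {
        while (running.load() && !ready.load()) {
          std::this_thread::yield();
        }
      });

      running.store(false);  // simulate Finalize() being triggered by a timeout event
      waiter.join();
      return 0;
    }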

