From: @anancds Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -67,7 +67,7 @@ constexpr int64_t kPullCmd = 51; | |||
| constexpr size_t kInvalidKey = UINT64_MAX; | |||
| constexpr int64_t kInvalidID = -1; | |||
| using DataPtr = std::shared_ptr<unsigned char>; | |||
| using DataPtr = std::shared_ptr<unsigned char[]>; | |||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | |||
| using Key = uint64_t; | |||
| using Keys = std::vector<Key>; | |||
| @@ -281,7 +281,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||
| if (!Heartbeat(client)) { | |||
| MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| if (!CheckSchedulerTimeout() && on_node_event_message_) { | |||
| if (CheckSchedulerTimeout() && on_node_event_message_) { | |||
| MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; | |||
| is_finish_ = true; | |||
| @@ -294,6 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); | |||
| } | |||
| }); | |||
| heart_beat_thread_->detach(); | |||
| } | |||
| bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | |||
| @@ -307,6 +308,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n | |||
| if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), | |||
| heartbeat_message.ByteSizeLong())) { | |||
| MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -315,9 +317,7 @@ void AbstractNode::UpdateSchedulerTime() { | |||
| struct timeval current_time {}; | |||
| (void)gettimeofday(&current_time, nullptr); | |||
| scheduler_time_ = current_time; | |||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||
| << " update scheduler time, the current time is: " << current_time.tv_sec; | |||
| MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec; | |||
| } | |||
| bool AbstractNode::CheckSchedulerTimeout() const { | |||
| @@ -430,10 +430,13 @@ bool AbstractNode::InitClientToScheduler() { | |||
| MS_LOG(INFO) << "The node start a tcp client!"; | |||
| client_to_scheduler_->Start(); | |||
| }); | |||
| client_to_scheduler_thread_->detach(); | |||
| client_to_scheduler_->set_disconnected_callback([&]() { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval())); | |||
| client_to_scheduler_->Init(); | |||
| if (is_ready_.load() == false) { | |||
| client_to_scheduler_->Init(); | |||
| } | |||
| }); | |||
| return client_to_scheduler_->WaitConnected(); | |||
| } | |||
| @@ -37,7 +37,7 @@ class AbstractNode : public Node { | |||
| typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | |||
| using DataPtr = std::shared_ptr<unsigned char>; | |||
| using DataPtr = std::shared_ptr<unsigned char[]>; | |||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | |||
| bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, | |||
| @@ -62,7 +62,7 @@ class ClusterMetadata { | |||
| heartbeat_timeout_(30), | |||
| cluster_available_timeout_(300), | |||
| connect_interval_(100), | |||
| scheduler_timeout_(3600 * 5) {} | |||
| scheduler_timeout_(30) {} | |||
| uint32_t worker_num_; | |||
| uint32_t server_num_; | |||
| // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. | |||
| @@ -25,7 +25,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; | |||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 }; | |||
| struct NodeInfo { | |||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | |||
| @@ -105,7 +105,7 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share | |||
| MS_EXCEPTION_IF_NULL(conn); | |||
| MS_EXCEPTION_IF_NULL(meta); | |||
| MS_EXCEPTION_IF_NULL(data); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[size]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[size]); | |||
| int ret = memcpy_s(res.get(), size, data, size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -131,14 +131,18 @@ bool ServerNode::Stop() { | |||
| if (!is_already_stopped_.load()) { | |||
| is_already_stopped_ = true; | |||
| is_finish_ = true; | |||
| heart_beat_thread_->join(); | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| } | |||
| client_to_scheduler_->Stop(); | |||
| if (!connected_nodes_.empty()) { | |||
| for (auto &connected_node : connected_nodes_) { | |||
| connected_node.second->Stop(); | |||
| } | |||
| } | |||
| client_to_scheduler_thread_->join(); | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| server_->Stop(); | |||
| server_thread_->join(); | |||
| } | |||
| @@ -311,7 +311,8 @@ bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &pro | |||
| } | |||
| int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH); | |||
| if (result < 0) { | |||
| MS_LOG(EXCEPTION) << "Bufferevent flush failed!"; | |||
| MS_LOG(ERROR) << "Bufferevent flush failed!"; | |||
| res = false; | |||
| } | |||
| bufferevent_unlock(buffer_event_); | |||
| return res; | |||
| @@ -63,14 +63,18 @@ bool WorkerNode::Stop() { | |||
| is_ready_ = true; | |||
| is_timeout_ = true; | |||
| is_finish_ = true; | |||
| heart_beat_thread_->join(); | |||
| if (heart_beat_thread_->joinable()) { | |||
| heart_beat_thread_->join(); | |||
| } | |||
| client_to_scheduler_->Stop(); | |||
| if (!connected_nodes_.empty()) { | |||
| for (auto &connected_node : connected_nodes_) { | |||
| connected_node.second->Stop(); | |||
| } | |||
| } | |||
| client_to_scheduler_thread_->join(); | |||
| if (client_to_scheduler_thread_->joinable()) { | |||
| client_to_scheduler_thread_->join(); | |||
| } | |||
| is_already_stopped_ = true; | |||
| } | |||
| return true; | |||
| @@ -21,6 +21,8 @@ namespace ps { | |||
| void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; | |||
| server_node_ = std::make_shared<core::ServerNode>(); | |||
| core::ClusterMetadata::instance()->Init( | |||
| PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | |||
| PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | |||
| @@ -30,14 +32,14 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||
| return; | |||
| } | |||
| Init(func_graph); | |||
| server_node_.Start(); | |||
| rank_id_ = server_node_.rank_id(); | |||
| server_node_->Start(); | |||
| rank_id_ = server_node_->rank_id(); | |||
| PSContext::instance()->SetPSRankId(rank_id_); | |||
| thread_->join(); | |||
| SyncEmbeddingTables(); | |||
| MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; | |||
| server_node_.Finish(); | |||
| server_node_.Stop(); | |||
| server_node_->Finish(); | |||
| server_node_->Stop(); | |||
| MS_LOG(INFO) << "PServer finalized successfully."; | |||
| } | |||
| @@ -49,7 +51,14 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) { | |||
| handler_->Init(); | |||
| InitOptimInfoBuilders(); | |||
| server_node_.set_handler(*handler_); | |||
| server_node_->set_handler(*handler_); | |||
| server_node_->set_event_callback([&](const core::NodeEvent &event) { | |||
| if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || | |||
| (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { | |||
| MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; | |||
| Finalize(); | |||
| } | |||
| }); | |||
| thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); | |||
| GetEmbeddingTableParamPtr(); | |||
| return true; | |||
| @@ -496,7 +505,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect | |||
| auto &handler_ptr = handlers_[meta->user_cmd()]; | |||
| (this->*handler_ptr)(data, size, output); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[output->size()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[output->size()]); | |||
| MS_LOG(DEBUG) << "The output size is:" << output->size(); | |||
| if (output->size() > 0) { | |||
| int ret = memcpy_s(res.get(), output->size(), output->data(), output->size()); | |||
| @@ -505,7 +514,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect | |||
| } | |||
| } | |||
| ps_->server_node_.Response(conn, meta, res, output->size()); | |||
| ps_->server_node_->Response(conn, meta, res, output->size()); | |||
| MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:" | |||
| << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now()) | |||
| .time_since_epoch() | |||
| @@ -682,6 +691,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t | |||
| *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()}; | |||
| ps_->DoEmbeddingLookup(key, keys, &res_data); | |||
| res->resize(res_data.ByteSizeLong()); | |||
| int ret = | |||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | |||
| @@ -59,6 +59,7 @@ | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/server_node.h" | |||
| #include "ps/core/node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -82,7 +83,8 @@ class ParameterServer { | |||
| func_graph_(nullptr), | |||
| sess_(nullptr), | |||
| running_(true), | |||
| thread_(nullptr) {} | |||
| thread_(nullptr), | |||
| server_node_(nullptr) {} | |||
| ~ParameterServer() = default; | |||
| ParameterServer(const ParameterServer &) = delete; | |||
| ParameterServer &operator=(const ParameterServer &) = delete; | |||
| @@ -167,7 +169,7 @@ class ParameterServer { | |||
| std::condition_variable apply_grads_cv_; | |||
| std::unique_ptr<std::thread> thread_; | |||
| core::ServerNode server_node_; | |||
| std::shared_ptr<core::ServerNode> server_node_; | |||
| std::map<Key, ParameterPtr> embedding_tables_; | |||
| friend class ServerHandler; | |||
| @@ -15,11 +15,13 @@ | |||
| */ | |||
| #include "ps/worker.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| void Worker::Run() { | |||
| std::lock_guard<std::mutex> lock(running_mutex_); | |||
| core::ClusterMetadata::instance()->Init( | |||
| PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | |||
| PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | |||
| @@ -33,6 +35,14 @@ void Worker::Run() { | |||
| } | |||
| Initialize(); | |||
| worker_node_.set_event_callback([&](const core::NodeEvent &event) { | |||
| if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || | |||
| (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { | |||
| MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; | |||
| Finalize(); | |||
| exit(0); | |||
| } | |||
| }); | |||
| MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; | |||
| worker_node_.Start(); | |||
| MS_LOG(INFO) << "Worker connected successfully."; | |||
| @@ -86,7 +96,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, | |||
| } | |||
| MS_LOG(INFO) << "The total size is:" << total_size; | |||
| while (!IsReadyForPush(keys[0])) { | |||
| while (running_ && (!IsReadyForPush(keys[0]))) { | |||
| continue; | |||
| } | |||
| std::vector<int> sizes_int; | |||
| @@ -109,7 +119,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, | |||
| void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dev_addr); | |||
| std::vector<float> variables(size / sizeof(float), 0); | |||
| while (!IsReadyForPull(key)) { | |||
| while (running_ && (!IsReadyForPull(key))) { | |||
| continue; | |||
| } | |||
| PullData({key}, &variables, nullptr, kPullCmd); | |||
| @@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> & | |||
| std::string kv_data = embedding_table_meta.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||
| rank_ids.push_back(i); | |||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||
| for (auto j = 0; j < message.values_size(); j++) { | |||
| values->push_back(message.values(j)); | |||
| } | |||
| MS_LOG(DEBUG) << "The embedding resp:" << values; | |||
| MS_LOG(DEBUG) << "The embedding resp:" << *values; | |||
| for (auto k = 0; k < message.keys_size(); k++) { | |||
| const Key &key = message.keys(k); | |||
| float *addr = values->data() + value_offset; | |||
| @@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto | |||
| rank_ids.push_back(i); | |||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -378,7 +388,7 @@ void Worker::Finalize() { | |||
| kvs.add_keys(0); | |||
| kvs.add_values(0.0f); | |||
| std::string kv_data = kvs.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -619,7 +629,7 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va | |||
| SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {}); | |||
| } else { | |||
| std::string kv_data = kvs.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa | |||
| rank_ids.push_back(i); | |||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| @@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa | |||
| rank_ids.push_back(i); | |||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | |||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||