From: @anancds Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -67,7 +67,7 @@ constexpr int64_t kPullCmd = 51; | |||||
| constexpr size_t kInvalidKey = UINT64_MAX; | constexpr size_t kInvalidKey = UINT64_MAX; | ||||
| constexpr int64_t kInvalidID = -1; | constexpr int64_t kInvalidID = -1; | ||||
| using DataPtr = std::shared_ptr<unsigned char>; | |||||
| using DataPtr = std::shared_ptr<unsigned char[]>; | |||||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | ||||
| using Key = uint64_t; | using Key = uint64_t; | ||||
| using Keys = std::vector<Key>; | using Keys = std::vector<Key>; | ||||
| @@ -281,7 +281,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||||
| if (!Heartbeat(client)) { | if (!Heartbeat(client)) { | ||||
| MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | ||||
| << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; | << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; | ||||
| if (!CheckSchedulerTimeout() && on_node_event_message_) { | |||||
| if (CheckSchedulerTimeout() && on_node_event_message_) { | |||||
| MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) | ||||
| << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; | << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; | ||||
| is_finish_ = true; | is_finish_ = true; | ||||
| @@ -294,6 +294,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client) | |||||
| std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); | std::this_thread::sleep_for(std::chrono::seconds(ClusterMetadata::instance()->heartbeat_interval())); | ||||
| } | } | ||||
| }); | }); | ||||
| heart_beat_thread_->detach(); | |||||
| } | } | ||||
| bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_node_finish) { | ||||
| @@ -307,6 +308,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client, bool is_n | |||||
| if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), | if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), | ||||
| heartbeat_message.ByteSizeLong())) { | heartbeat_message.ByteSizeLong())) { | ||||
| MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; | ||||
| return false; | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -315,9 +317,7 @@ void AbstractNode::UpdateSchedulerTime() { | |||||
| struct timeval current_time {}; | struct timeval current_time {}; | ||||
| (void)gettimeofday(¤t_time, nullptr); | (void)gettimeofday(¤t_time, nullptr); | ||||
| scheduler_time_ = current_time; | scheduler_time_ = current_time; | ||||
| MS_LOG(DEBUG) << "The node role: " << CommUtil::NodeRoleToString(node_info_.node_role_) | |||||
| << ", the node id:" << node_info_.node_id_ << ", the node rank id:" << node_info_.rank_id_ | |||||
| << " update scheduler time, the current time is: " << current_time.tv_sec; | |||||
| MS_LOG(DEBUG) << "Update scheduler time, the current time is: " << current_time.tv_sec; | |||||
| } | } | ||||
| bool AbstractNode::CheckSchedulerTimeout() const { | bool AbstractNode::CheckSchedulerTimeout() const { | ||||
| @@ -430,10 +430,13 @@ bool AbstractNode::InitClientToScheduler() { | |||||
| MS_LOG(INFO) << "The node start a tcp client!"; | MS_LOG(INFO) << "The node start a tcp client!"; | ||||
| client_to_scheduler_->Start(); | client_to_scheduler_->Start(); | ||||
| }); | }); | ||||
| client_to_scheduler_thread_->detach(); | |||||
| client_to_scheduler_->set_disconnected_callback([&]() { | client_to_scheduler_->set_disconnected_callback([&]() { | ||||
| std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval())); | std::this_thread::sleep_for(std::chrono::milliseconds(ClusterMetadata::instance()->connect_interval())); | ||||
| client_to_scheduler_->Init(); | |||||
| if (is_ready_.load() == false) { | |||||
| client_to_scheduler_->Init(); | |||||
| } | |||||
| }); | }); | ||||
| return client_to_scheduler_->WaitConnected(); | return client_to_scheduler_->WaitConnected(); | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ class AbstractNode : public Node { | |||||
| typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | typedef void (AbstractNode::*ResponseHandler)(std::shared_ptr<MessageMeta> meta, const void *data, size_t size); | ||||
| using DataPtr = std::shared_ptr<unsigned char>; | |||||
| using DataPtr = std::shared_ptr<unsigned char[]>; | |||||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | ||||
| bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, | bool Broadcast(const enum NodeRole &node_role, const DataPtr &message, size_t size, int command, | ||||
| @@ -62,7 +62,7 @@ class ClusterMetadata { | |||||
| heartbeat_timeout_(30), | heartbeat_timeout_(30), | ||||
| cluster_available_timeout_(300), | cluster_available_timeout_(300), | ||||
| connect_interval_(100), | connect_interval_(100), | ||||
| scheduler_timeout_(3600 * 5) {} | |||||
| scheduler_timeout_(30) {} | |||||
| uint32_t worker_num_; | uint32_t worker_num_; | ||||
| uint32_t server_num_; | uint32_t server_num_; | ||||
| // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. | // The interval for sending heartbeat packets between worker node,server node and scheduler node is 3 seconds. | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT }; | |||||
| enum NodeEvent { CLUSTER_TIMEOUT = 0, NODE_TIMEOUT = 1, SCHEDULER_TIMEOUT = 2 }; | |||||
| struct NodeInfo { | struct NodeInfo { | ||||
| NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | NodeInfo() : port_(0), node_role_(NodeRole::SCHEDULER), rank_id_(0) {} | ||||
| @@ -105,7 +105,7 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share | |||||
| MS_EXCEPTION_IF_NULL(conn); | MS_EXCEPTION_IF_NULL(conn); | ||||
| MS_EXCEPTION_IF_NULL(meta); | MS_EXCEPTION_IF_NULL(meta); | ||||
| MS_EXCEPTION_IF_NULL(data); | MS_EXCEPTION_IF_NULL(data); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[size]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[size]); | |||||
| int ret = memcpy_s(res.get(), size, data, size); | int ret = memcpy_s(res.get(), size, data, size); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -131,14 +131,18 @@ bool ServerNode::Stop() { | |||||
| if (!is_already_stopped_.load()) { | if (!is_already_stopped_.load()) { | ||||
| is_already_stopped_ = true; | is_already_stopped_ = true; | ||||
| is_finish_ = true; | is_finish_ = true; | ||||
| heart_beat_thread_->join(); | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| client_to_scheduler_->Stop(); | client_to_scheduler_->Stop(); | ||||
| if (!connected_nodes_.empty()) { | if (!connected_nodes_.empty()) { | ||||
| for (auto &connected_node : connected_nodes_) { | for (auto &connected_node : connected_nodes_) { | ||||
| connected_node.second->Stop(); | connected_node.second->Stop(); | ||||
| } | } | ||||
| } | } | ||||
| client_to_scheduler_thread_->join(); | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| server_->Stop(); | server_->Stop(); | ||||
| server_thread_->join(); | server_thread_->join(); | ||||
| } | } | ||||
| @@ -311,7 +311,8 @@ bool TcpClient::SendMessage(std::shared_ptr<MessageMeta> meta, const Protos &pro | |||||
| } | } | ||||
| int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH); | int result = bufferevent_flush(buffer_event_, EV_READ | EV_WRITE, BEV_FLUSH); | ||||
| if (result < 0) { | if (result < 0) { | ||||
| MS_LOG(EXCEPTION) << "Bufferevent flush failed!"; | |||||
| MS_LOG(ERROR) << "Bufferevent flush failed!"; | |||||
| res = false; | |||||
| } | } | ||||
| bufferevent_unlock(buffer_event_); | bufferevent_unlock(buffer_event_); | ||||
| return res; | return res; | ||||
| @@ -63,14 +63,18 @@ bool WorkerNode::Stop() { | |||||
| is_ready_ = true; | is_ready_ = true; | ||||
| is_timeout_ = true; | is_timeout_ = true; | ||||
| is_finish_ = true; | is_finish_ = true; | ||||
| heart_beat_thread_->join(); | |||||
| if (heart_beat_thread_->joinable()) { | |||||
| heart_beat_thread_->join(); | |||||
| } | |||||
| client_to_scheduler_->Stop(); | client_to_scheduler_->Stop(); | ||||
| if (!connected_nodes_.empty()) { | if (!connected_nodes_.empty()) { | ||||
| for (auto &connected_node : connected_nodes_) { | for (auto &connected_node : connected_nodes_) { | ||||
| connected_node.second->Stop(); | connected_node.second->Stop(); | ||||
| } | } | ||||
| } | } | ||||
| client_to_scheduler_thread_->join(); | |||||
| if (client_to_scheduler_thread_->joinable()) { | |||||
| client_to_scheduler_thread_->join(); | |||||
| } | |||||
| is_already_stopped_ = true; | is_already_stopped_ = true; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -21,6 +21,8 @@ namespace ps { | |||||
| void ParameterServer::Run(const FuncGraphPtr &func_graph) { | void ParameterServer::Run(const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; | MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; | ||||
| server_node_ = std::make_shared<core::ServerNode>(); | |||||
| core::ClusterMetadata::instance()->Init( | core::ClusterMetadata::instance()->Init( | ||||
| PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | ||||
| PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | ||||
| @@ -30,14 +32,14 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||||
| return; | return; | ||||
| } | } | ||||
| Init(func_graph); | Init(func_graph); | ||||
| server_node_.Start(); | |||||
| rank_id_ = server_node_.rank_id(); | |||||
| server_node_->Start(); | |||||
| rank_id_ = server_node_->rank_id(); | |||||
| PSContext::instance()->SetPSRankId(rank_id_); | PSContext::instance()->SetPSRankId(rank_id_); | ||||
| thread_->join(); | thread_->join(); | ||||
| SyncEmbeddingTables(); | SyncEmbeddingTables(); | ||||
| MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; | MS_LOG(INFO) << "PServer finished updating models, starts finalizing..."; | ||||
| server_node_.Finish(); | |||||
| server_node_.Stop(); | |||||
| server_node_->Finish(); | |||||
| server_node_->Stop(); | |||||
| MS_LOG(INFO) << "PServer finalized successfully."; | MS_LOG(INFO) << "PServer finalized successfully."; | ||||
| } | } | ||||
| @@ -49,7 +51,14 @@ bool ParameterServer::Init(const FuncGraphPtr &func_graph) { | |||||
| handler_->Init(); | handler_->Init(); | ||||
| InitOptimInfoBuilders(); | InitOptimInfoBuilders(); | ||||
| server_node_.set_handler(*handler_); | |||||
| server_node_->set_handler(*handler_); | |||||
| server_node_->set_event_callback([&](const core::NodeEvent &event) { | |||||
| if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || | |||||
| (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { | |||||
| MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; | |||||
| Finalize(); | |||||
| } | |||||
| }); | |||||
| thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); | thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); | ||||
| GetEmbeddingTableParamPtr(); | GetEmbeddingTableParamPtr(); | ||||
| return true; | return true; | ||||
| @@ -496,7 +505,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect | |||||
| auto &handler_ptr = handlers_[meta->user_cmd()]; | auto &handler_ptr = handlers_[meta->user_cmd()]; | ||||
| (this->*handler_ptr)(data, size, output); | (this->*handler_ptr)(data, size, output); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[output->size()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[output->size()]); | |||||
| MS_LOG(DEBUG) << "The output size is:" << output->size(); | MS_LOG(DEBUG) << "The output size is:" << output->size(); | ||||
| if (output->size() > 0) { | if (output->size() > 0) { | ||||
| int ret = memcpy_s(res.get(), output->size(), output->data(), output->size()); | int ret = memcpy_s(res.get(), output->size(), output->data(), output->size()); | ||||
| @@ -505,7 +514,7 @@ void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnect | |||||
| } | } | ||||
| } | } | ||||
| ps_->server_node_.Response(conn, meta, res, output->size()); | |||||
| ps_->server_node_->Response(conn, meta, res, output->size()); | |||||
| MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:" | MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:" | ||||
| << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now()) | << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now()) | ||||
| .time_since_epoch() | .time_since_epoch() | ||||
| @@ -682,6 +691,7 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t | |||||
| *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()}; | *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()}; | ||||
| ps_->DoEmbeddingLookup(key, keys, &res_data); | ps_->DoEmbeddingLookup(key, keys, &res_data); | ||||
| res->resize(res_data.ByteSizeLong()); | res->resize(res_data.ByteSizeLong()); | ||||
| int ret = | int ret = | ||||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | ||||
| @@ -59,6 +59,7 @@ | |||||
| #include "proto/comm.pb.h" | #include "proto/comm.pb.h" | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "ps/core/server_node.h" | #include "ps/core/server_node.h" | ||||
| #include "ps/core/node.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -82,7 +83,8 @@ class ParameterServer { | |||||
| func_graph_(nullptr), | func_graph_(nullptr), | ||||
| sess_(nullptr), | sess_(nullptr), | ||||
| running_(true), | running_(true), | ||||
| thread_(nullptr) {} | |||||
| thread_(nullptr), | |||||
| server_node_(nullptr) {} | |||||
| ~ParameterServer() = default; | ~ParameterServer() = default; | ||||
| ParameterServer(const ParameterServer &) = delete; | ParameterServer(const ParameterServer &) = delete; | ||||
| ParameterServer &operator=(const ParameterServer &) = delete; | ParameterServer &operator=(const ParameterServer &) = delete; | ||||
| @@ -167,7 +169,7 @@ class ParameterServer { | |||||
| std::condition_variable apply_grads_cv_; | std::condition_variable apply_grads_cv_; | ||||
| std::unique_ptr<std::thread> thread_; | std::unique_ptr<std::thread> thread_; | ||||
| core::ServerNode server_node_; | |||||
| std::shared_ptr<core::ServerNode> server_node_; | |||||
| std::map<Key, ParameterPtr> embedding_tables_; | std::map<Key, ParameterPtr> embedding_tables_; | ||||
| friend class ServerHandler; | friend class ServerHandler; | ||||
| @@ -15,11 +15,13 @@ | |||||
| */ | */ | ||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| #include "pipeline/jit/pipeline.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| void Worker::Run() { | void Worker::Run() { | ||||
| std::lock_guard<std::mutex> lock(running_mutex_); | std::lock_guard<std::mutex> lock(running_mutex_); | ||||
| core::ClusterMetadata::instance()->Init( | core::ClusterMetadata::instance()->Init( | ||||
| PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | ||||
| PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | ||||
| @@ -33,6 +35,14 @@ void Worker::Run() { | |||||
| } | } | ||||
| Initialize(); | Initialize(); | ||||
| worker_node_.set_event_callback([&](const core::NodeEvent &event) { | |||||
| if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || | |||||
| (event == core::NodeEvent::SCHEDULER_TIMEOUT || (event == core::NodeEvent::NODE_TIMEOUT))) { | |||||
| MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!"; | |||||
| Finalize(); | |||||
| exit(0); | |||||
| } | |||||
| }); | |||||
| MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; | MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; | ||||
| worker_node_.Start(); | worker_node_.Start(); | ||||
| MS_LOG(INFO) << "Worker connected successfully."; | MS_LOG(INFO) << "Worker connected successfully."; | ||||
| @@ -86,7 +96,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, | |||||
| } | } | ||||
| MS_LOG(INFO) << "The total size is:" << total_size; | MS_LOG(INFO) << "The total size is:" << total_size; | ||||
| while (!IsReadyForPush(keys[0])) { | |||||
| while (running_ && (!IsReadyForPush(keys[0]))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| std::vector<int> sizes_int; | std::vector<int> sizes_int; | ||||
| @@ -109,7 +119,7 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, | |||||
| void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { | void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { | ||||
| MS_EXCEPTION_IF_NULL(dev_addr); | MS_EXCEPTION_IF_NULL(dev_addr); | ||||
| std::vector<float> variables(size / sizeof(float), 0); | std::vector<float> variables(size / sizeof(float), 0); | ||||
| while (!IsReadyForPull(key)) { | |||||
| while (running_ && (!IsReadyForPull(key))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| PullData({key}, &variables, nullptr, kPullCmd); | PullData({key}, &variables, nullptr, kPullCmd); | ||||
| @@ -214,7 +224,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> & | |||||
| std::string kv_data = embedding_table_meta.SerializeAsString(); | std::string kv_data = embedding_table_meta.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -280,7 +290,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||||
| rank_ids.push_back(i); | rank_ids.push_back(i); | ||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -303,7 +313,7 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||||
| for (auto j = 0; j < message.values_size(); j++) { | for (auto j = 0; j < message.values_size(); j++) { | ||||
| values->push_back(message.values(j)); | values->push_back(message.values(j)); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "The embedding resp:" << values; | |||||
| MS_LOG(DEBUG) << "The embedding resp:" << *values; | |||||
| for (auto k = 0; k < message.keys_size(); k++) { | for (auto k = 0; k < message.keys_size(); k++) { | ||||
| const Key &key = message.keys(k); | const Key &key = message.keys(k); | ||||
| float *addr = values->data() + value_offset; | float *addr = values->data() + value_offset; | ||||
| @@ -358,7 +368,7 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto | |||||
| rank_ids.push_back(i); | rank_ids.push_back(i); | ||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -378,7 +388,7 @@ void Worker::Finalize() { | |||||
| kvs.add_keys(0); | kvs.add_keys(0); | ||||
| kvs.add_values(0.0f); | kvs.add_values(0.0f); | ||||
| std::string kv_data = kvs.SerializeAsString(); | std::string kv_data = kvs.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -619,7 +629,7 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va | |||||
| SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {}); | SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {}); | ||||
| } else { | } else { | ||||
| std::string kv_data = kvs.SerializeAsString(); | std::string kv_data = kvs.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -920,7 +930,7 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa | |||||
| rank_ids.push_back(i); | rank_ids.push_back(i); | ||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| @@ -945,7 +955,7 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa | |||||
| rank_ids.push_back(i); | rank_ids.push_back(i); | ||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()]); | |||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | |||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||