From: @anancds Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107pull/15118/MERGE
| @@ -526,7 +526,9 @@ void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const | |||||
| auto it = receive_messages_.find(request_id); | auto it = receive_messages_.find(request_id); | ||||
| VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | ||||
| if (size > 0) { | if (size > 0) { | ||||
| auto ret = memcpy_s(received_data.get()->data(), size, data, size); | |||||
| size_t dest_size = size; | |||||
| size_t src_size = size; | |||||
| auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -586,7 +588,9 @@ void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const P | |||||
| // If they are equal, then call the callback function | // If they are equal, then call the callback function | ||||
| uint64_t rank_request_id = NextActualRankRequestId(rank_id); | uint64_t rank_request_id = NextActualRankRequestId(rank_id); | ||||
| std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0); | ||||
| int ret = memcpy_s(received_data->data(), size, data, size); | |||||
| size_t dest_size = size; | |||||
| size_t src_size = size; | |||||
| int ret = memcpy_s(received_data->data(), dest_size, data, src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -106,7 +106,7 @@ int HttpClient::ReadHeaderDoneCallback(struct evhttp_request *request, void *arg | |||||
| handler->set_request(request); | handler->set_request(request); | ||||
| struct evkeyvalq *headers = evhttp_request_get_input_headers(request); | struct evkeyvalq *headers = evhttp_request_get_input_headers(request); | ||||
| MS_EXCEPTION_IF_NULL(headers); | MS_EXCEPTION_IF_NULL(headers); | ||||
| struct evkeyval *header; | |||||
| struct evkeyval *header = nullptr; | |||||
| TAILQ_FOREACH(header, headers, next) { | TAILQ_FOREACH(header, headers, next) { | ||||
| MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value; | MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value; | ||||
| std::string len = "Content-Length"; | std::string len = "Content-Length"; | ||||
| @@ -204,7 +204,6 @@ Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, st | |||||
| bool HttpClient::Start() { | bool HttpClient::Start() { | ||||
| MS_EXCEPTION_IF_NULL(event_base_); | MS_EXCEPTION_IF_NULL(event_base_); | ||||
| // int ret = event_base_dispatch(event_base_); | |||||
| int ret = event_base_loop(event_base_, 0); | int ret = event_base_loop(event_base_, 0); | ||||
| if (ret == 0) { | if (ret == 0) { | ||||
| MS_LOG(DEBUG) << "Event base dispatch success!"; | MS_LOG(DEBUG) << "Event base dispatch success!"; | ||||
| @@ -237,7 +237,9 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) { | |||||
| void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | ||||
| MS_EXCEPTION_IF_NULL(buffer); | MS_EXCEPTION_IF_NULL(buffer); | ||||
| int ret = memcpy_s(body_->data() + offset_, num, buffer, num); | |||||
| size_t dest_size = num; | |||||
| size_t src_size = num; | |||||
| int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -53,10 +53,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { | |||||
| remaining_length_ -= copy_len; | remaining_length_ -= copy_len; | ||||
| num -= copy_len; | num -= copy_len; | ||||
| int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len); | |||||
| size_t dest_size = copy_len; | |||||
| size_t src_size = copy_len; | |||||
| auto ret = memcpy_s(message_buffer_.get() + last_copy_len_, dest_size, buffer_data, src_size); | |||||
| last_copy_len_ += copy_len; | last_copy_len_ += copy_len; | ||||
| buffer_data += copy_len; | buffer_data += copy_len; | ||||
| if (ret != 0) { | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -48,7 +48,7 @@ class TcpMessageHandler { | |||||
| bool is_parsed_; | bool is_parsed_; | ||||
| std::unique_ptr<unsigned char> message_buffer_; | std::unique_ptr<unsigned char> message_buffer_; | ||||
| size_t remaining_length_; | size_t remaining_length_; | ||||
| char header_[16]; | |||||
| char header_[16]{0}; | |||||
| int header_index_; | int header_index_; | ||||
| size_t last_copy_len_; | size_t last_copy_len_; | ||||
| MessageHeader message_header_; | MessageHeader message_header_; | ||||
| @@ -102,7 +102,9 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share | |||||
| MS_EXCEPTION_IF_NULL(meta); | MS_EXCEPTION_IF_NULL(meta); | ||||
| MS_EXCEPTION_IF_NULL(data); | MS_EXCEPTION_IF_NULL(data); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[size]); | std::shared_ptr<unsigned char[]> res(new unsigned char[size]); | ||||
| auto ret = memcpy_s(res.get(), size, data, size); | |||||
| size_t dest_size = size; | |||||
| size_t src_size = size; | |||||
| auto ret = memcpy_s(res.get(), dest_size, data, src_size); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -548,8 +548,9 @@ void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, Ve | |||||
| auto weight = ps_->weight(key); | auto weight = ps_->weight(key); | ||||
| *res_data.mutable_values() = {weight->begin(), weight->end()}; | *res_data.mutable_values() = {weight->begin(), weight->end()}; | ||||
| res->resize(res_data.ByteSizeLong()); | res->resize(res_data.ByteSizeLong()); | ||||
| int ret = | |||||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | |||||
| size_t dest_size = res_data.ByteSizeLong(); | |||||
| size_t src_size = res_data.ByteSizeLong(); | |||||
| int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -662,8 +663,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_ | |||||
| res_data.add_keys(key); | res_data.add_keys(key); | ||||
| res_data.add_values(ready); | res_data.add_values(ready); | ||||
| res->resize(res_data.ByteSizeLong()); | res->resize(res_data.ByteSizeLong()); | ||||
| int ret = | |||||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | |||||
| size_t dest_size = res_data.ByteSizeLong(); | |||||
| size_t src_size = res_data.ByteSizeLong(); | |||||
| int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -679,8 +681,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_ | |||||
| res_data.add_keys(key); | res_data.add_keys(key); | ||||
| res_data.add_values(ready); | res_data.add_values(ready); | ||||
| res->resize(res_data.ByteSizeLong()); | res->resize(res_data.ByteSizeLong()); | ||||
| int ret = | |||||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | |||||
| size_t dest_size = res_data.ByteSizeLong(); | |||||
| size_t src_size = res_data.ByteSizeLong(); | |||||
| int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -699,8 +702,9 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t | |||||
| ps_->DoEmbeddingLookup(key, keys, &res_data); | ps_->DoEmbeddingLookup(key, keys, &res_data); | ||||
| res->resize(res_data.ByteSizeLong()); | res->resize(res_data.ByteSizeLong()); | ||||
| int ret = | |||||
| memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong()); | |||||
| size_t dest_size = res_data.ByteSizeLong(); | |||||
| size_t src_size = res_data.ByteSizeLong(); | |||||
| int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; | ||||
| } | } | ||||
| @@ -936,7 +936,8 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||||
| bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, | |||||
| size_t key) { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | MS_ERROR_IF_NULL(embedding_device_cache_); | ||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | ||||
| MS_ERROR_IF_NULL(swap_out_ids); | MS_ERROR_IF_NULL(swap_out_ids); | ||||
| @@ -165,7 +165,7 @@ class PsCacheManager { | |||||
| const float *insert_data, float *hash_table_addr); | const float *insert_data, float *hash_table_addr); | ||||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | ||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key); | |||||
| bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, size_t key); | |||||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | ||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| bool CheckFinishInsertInitInfo() const; | bool CheckFinishInsertInitInfo() const; | ||||
| @@ -84,7 +84,9 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, | |||||
| MS_EXCEPTION_IF_NULL(dst_data); | MS_EXCEPTION_IF_NULL(dst_data); | ||||
| MS_EXCEPTION_IF_NULL(src_data); | MS_EXCEPTION_IF_NULL(src_data); | ||||
| int size = sizes[i] * sizeof(float); | int size = sizes[i] * sizeof(float); | ||||
| auto ret = memcpy_s(dst_data, size, src_data, size); | |||||
| size_t dest_size = IntToSize(size); | |||||
| size_t src_size = IntToSize(size); | |||||
| auto ret = memcpy_s(dst_data, dest_size, src_data, src_size); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -222,7 +224,8 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> & | |||||
| std::string kv_data = embedding_table_meta.SerializeAsString(); | std::string kv_data = embedding_table_meta.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -288,7 +291,8 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -371,7 +375,8 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto | |||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -391,7 +396,8 @@ void Worker::Finalize() { | |||||
| kvs.add_values(0.0f); | kvs.add_values(0.0f); | ||||
| std::string kv_data = kvs.SerializeAsString(); | std::string kv_data = kvs.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -482,7 +488,7 @@ void Worker::InitPSOptimInputShapes(const size_t key) { | |||||
| PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); | PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); | ||||
| } | } | ||||
| void Worker::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) { | |||||
| void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(origin_addr); | MS_EXCEPTION_IF_NULL(origin_addr); | ||||
| std::vector<float> addr{reinterpret_cast<float *>(origin_addr), | std::vector<float> addr{reinterpret_cast<float *>(origin_addr), | ||||
| reinterpret_cast<float *>(origin_addr) + size / sizeof(float)}; | reinterpret_cast<float *>(origin_addr) + size / sizeof(float)}; | ||||
| @@ -632,7 +638,8 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va | |||||
| } else { | } else { | ||||
| std::string kv_data = kvs.SerializeAsString(); | std::string kv_data = kvs.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -658,7 +665,7 @@ void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<floa | |||||
| } | } | ||||
| } | } | ||||
| void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens, int cmd, | |||||
| void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd, | |||||
| int64_t priority) { | int64_t priority) { | ||||
| MS_EXCEPTION_IF_NULL(vals); | MS_EXCEPTION_IF_NULL(vals); | ||||
| KVMessage kvs; | KVMessage kvs; | ||||
| @@ -933,7 +940,8 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa | |||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -959,7 +967,8 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa | |||||
| std::string kv_data = messages.at(i).second.SerializeAsString(); | std::string kv_data = messages.at(i).second.SerializeAsString(); | ||||
| std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); | ||||
| int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length()); | |||||
| size_t dest_size = kv_data.length(); | |||||
| int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length()); | |||||
| if (ret != 0) { | if (ret != 0) { | ||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | ||||
| return; | return; | ||||
| @@ -92,7 +92,7 @@ class Worker { | |||||
| void AddKeyByHashMod(const Key &key); | void AddKeyByHashMod(const Key &key); | ||||
| void InitPSOptimId(const size_t param_key); | void InitPSOptimId(const size_t param_key); | ||||
| void InitPSOptimInputShapes(const size_t key); | void InitPSOptimInputShapes(const size_t key); | ||||
| void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | |||||
| void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size); | |||||
| bool IsReadyForPush(const Key &key); | bool IsReadyForPush(const Key &key); | ||||
| bool IsReadyForPull(const Key &key); | bool IsReadyForPull(const Key &key); | ||||
| void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, | void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, | ||||
| @@ -105,8 +105,8 @@ class Worker { | |||||
| int command = 0, int64_t priority = 0); | int command = 0, int64_t priority = 0); | ||||
| void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, | void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, | ||||
| size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); | size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); | ||||
| void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0, | |||||
| int64_t priority = 0); | |||||
| void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr, | |||||
| int cmd = 0, int64_t priority = 0); | |||||
| void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, | void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, | ||||
| const std::map<int64_t, int64_t> &attrs); | const std::map<int64_t, int64_t> &attrs); | ||||