Browse Source

!15118 fixed codex

From: @anancds
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
pull/15118/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
d65b370bdf
11 changed files with 55 additions and 32 deletions
  1. +6
    -2
      mindspore/ccsrc/ps/core/abstract_node.cc
  2. +1
    -2
      mindspore/ccsrc/ps/core/communicator/http_client.cc
  3. +3
    -1
      mindspore/ccsrc/ps/core/communicator/http_message_handler.cc
  4. +4
    -2
      mindspore/ccsrc/ps/core/communicator/tcp_message_handler.cc
  5. +1
    -1
      mindspore/ccsrc/ps/core/communicator/tcp_message_handler.h
  6. +3
    -1
      mindspore/ccsrc/ps/core/server_node.cc
  7. +12
    -8
      mindspore/ccsrc/ps/parameter_server.cc
  8. +2
    -1
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
  9. +1
    -1
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
  10. +19
    -10
      mindspore/ccsrc/ps/worker.cc
  11. +3
    -3
      mindspore/ccsrc/ps/worker.h

+ 6
- 2
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -526,7 +526,9 @@ void AbstractNode::ProcessSendDataResp(std::shared_ptr<MessageMeta> meta, const
auto it = receive_messages_.find(request_id); auto it = receive_messages_.find(request_id);
VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0); VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
if (size > 0) { if (size > 0) {
auto ret = memcpy_s(received_data.get()->data(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
if (ret != EOK) { if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
@@ -586,7 +588,9 @@ void AbstractNode::RunReceiveCallback(std::shared_ptr<MessageMeta> meta, const P
// If they are equal, then call the callback function // If they are equal, then call the callback function
uint64_t rank_request_id = NextActualRankRequestId(rank_id); uint64_t rank_request_id = NextActualRankRequestId(rank_id);
std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0); std::shared_ptr<std::vector<unsigned char>> received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
int ret = memcpy_s(received_data->data(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
int ret = memcpy_s(received_data->data(), dest_size, data, src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


+ 1
- 2
mindspore/ccsrc/ps/core/communicator/http_client.cc View File

@@ -106,7 +106,7 @@ int HttpClient::ReadHeaderDoneCallback(struct evhttp_request *request, void *arg
handler->set_request(request); handler->set_request(request);
struct evkeyvalq *headers = evhttp_request_get_input_headers(request); struct evkeyvalq *headers = evhttp_request_get_input_headers(request);
MS_EXCEPTION_IF_NULL(headers); MS_EXCEPTION_IF_NULL(headers);
struct evkeyval *header;
struct evkeyval *header = nullptr;
TAILQ_FOREACH(header, headers, next) { TAILQ_FOREACH(header, headers, next) {
MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value; MS_LOG(DEBUG) << "The key:" << header->key << ",The value:" << header->value;
std::string len = "Content-Length"; std::string len = "Content-Length";
@@ -204,7 +204,6 @@ Status HttpClient::CreateRequest(std::shared_ptr<HttpMessageHandler> handler, st


bool HttpClient::Start() { bool HttpClient::Start() {
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
// int ret = event_base_dispatch(event_base_);
int ret = event_base_loop(event_base_, 0); int ret = event_base_loop(event_base_, 0);
if (ret == 0) { if (ret == 0) {
MS_LOG(DEBUG) << "Event base dispatch success!"; MS_LOG(DEBUG) << "Event base dispatch success!";


+ 3
- 1
mindspore/ccsrc/ps/core/communicator/http_message_handler.cc View File

@@ -237,7 +237,9 @@ void HttpMessageHandler::RespError(int nCode, const std::string &message) {


void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) { void HttpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
MS_EXCEPTION_IF_NULL(buffer); MS_EXCEPTION_IF_NULL(buffer);
int ret = memcpy_s(body_->data() + offset_, num, buffer, num);
size_t dest_size = num;
size_t src_size = num;
int ret = memcpy_s(body_->data() + offset_, dest_size, buffer, src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


+ 4
- 2
mindspore/ccsrc/ps/core/communicator/tcp_message_handler.cc View File

@@ -53,10 +53,12 @@ void TcpMessageHandler::ReceiveMessage(const void *buffer, size_t num) {
remaining_length_ -= copy_len; remaining_length_ -= copy_len;
num -= copy_len; num -= copy_len;


int ret = memcpy_s(message_buffer_.get() + last_copy_len_, copy_len, buffer_data, copy_len);
size_t dest_size = copy_len;
size_t src_size = copy_len;
auto ret = memcpy_s(message_buffer_.get() + last_copy_len_, dest_size, buffer_data, src_size);
last_copy_len_ += copy_len; last_copy_len_ += copy_len;
buffer_data += copy_len; buffer_data += copy_len;
if (ret != 0) {
if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }




+ 1
- 1
mindspore/ccsrc/ps/core/communicator/tcp_message_handler.h View File

@@ -48,7 +48,7 @@ class TcpMessageHandler {
bool is_parsed_; bool is_parsed_;
std::unique_ptr<unsigned char> message_buffer_; std::unique_ptr<unsigned char> message_buffer_;
size_t remaining_length_; size_t remaining_length_;
char header_[16];
char header_[16]{0};
int header_index_; int header_index_;
size_t last_copy_len_; size_t last_copy_len_;
MessageHeader message_header_; MessageHeader message_header_;


+ 3
- 1
mindspore/ccsrc/ps/core/server_node.cc View File

@@ -102,7 +102,9 @@ void ServerNode::ProcessSendData(std::shared_ptr<TcpConnection> conn, std::share
MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
std::shared_ptr<unsigned char[]> res(new unsigned char[size]); std::shared_ptr<unsigned char[]> res(new unsigned char[size]);
auto ret = memcpy_s(res.get(), size, data, size);
size_t dest_size = size;
size_t src_size = size;
auto ret = memcpy_s(res.get(), dest_size, data, src_size);
if (ret != EOK) { if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


+ 12
- 8
mindspore/ccsrc/ps/parameter_server.cc View File

@@ -548,8 +548,9 @@ void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, Ve
auto weight = ps_->weight(key); auto weight = ps_->weight(key);
*res_data.mutable_values() = {weight->begin(), weight->end()}; *res_data.mutable_values() = {weight->begin(), weight->end()};
res->resize(res_data.ByteSizeLong()); res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
@@ -662,8 +663,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_
res_data.add_keys(key); res_data.add_keys(key);
res_data.add_values(ready); res_data.add_values(ready);
res->resize(res_data.ByteSizeLong()); res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
@@ -679,8 +681,9 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_
res_data.add_keys(key); res_data.add_keys(key);
res_data.add_values(ready); res_data.add_values(ready);
res->resize(res_data.ByteSizeLong()); res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
@@ -699,8 +702,9 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t
ps_->DoEmbeddingLookup(key, keys, &res_data); ps_->DoEmbeddingLookup(key, keys, &res_data);


res->resize(res_data.ByteSizeLong()); res->resize(res_data.ByteSizeLong());
int ret =
memcpy_s(res->data(), res_data.ByteSizeLong(), res_data.SerializeAsString().data(), res_data.ByteSizeLong());
size_t dest_size = res_data.ByteSizeLong();
size_t src_size = res_data.ByteSizeLong();
int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }


+ 2
- 1
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -936,7 +936,8 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in
return true; return true;
} }


bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) {
bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids,
size_t key) {
MS_ERROR_IF_NULL(embedding_device_cache_); MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_); MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_out_ids); MS_ERROR_IF_NULL(swap_out_ids);


+ 1
- 1
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h View File

@@ -165,7 +165,7 @@ class PsCacheManager {
const float *insert_data, float *hash_table_addr); const float *insert_data, float *hash_table_addr);
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr); const int *indices_addr, float *output_addr);
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key);
bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *const swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr); const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const; bool CheckFinishInsertInitInfo() const;


+ 19
- 10
mindspore/ccsrc/ps/worker.cc View File

@@ -84,7 +84,9 @@ void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs,
MS_EXCEPTION_IF_NULL(dst_data); MS_EXCEPTION_IF_NULL(dst_data);
MS_EXCEPTION_IF_NULL(src_data); MS_EXCEPTION_IF_NULL(src_data);
int size = sizes[i] * sizeof(float); int size = sizes[i] * sizeof(float);
auto ret = memcpy_s(dst_data, size, src_data, size);
size_t dest_size = IntToSize(size);
size_t src_size = IntToSize(size);
auto ret = memcpy_s(dst_data, dest_size, src_data, src_size);
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -222,7 +224,8 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &
std::string kv_data = embedding_table_meta.SerializeAsString(); std::string kv_data = embedding_table_meta.SerializeAsString();


std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -288,7 +291,8 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_
std::string kv_data = messages.at(i).second.SerializeAsString(); std::string kv_data = messages.at(i).second.SerializeAsString();


std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -371,7 +375,8 @@ void Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vecto
std::string kv_data = messages.at(i).second.SerializeAsString(); std::string kv_data = messages.at(i).second.SerializeAsString();


std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -391,7 +396,8 @@ void Worker::Finalize() {
kvs.add_values(0.0f); kvs.add_values(0.0f);
std::string kv_data = kvs.SerializeAsString(); std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -482,7 +488,7 @@ void Worker::InitPSOptimInputShapes(const size_t key) {
PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
} }


void Worker::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) {
void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) {
MS_EXCEPTION_IF_NULL(origin_addr); MS_EXCEPTION_IF_NULL(origin_addr);
std::vector<float> addr{reinterpret_cast<float *>(origin_addr), std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
reinterpret_cast<float *>(origin_addr) + size / sizeof(float)}; reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
@@ -632,7 +638,8 @@ void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &va
} else { } else {
std::string kv_data = kvs.SerializeAsString(); std::string kv_data = kvs.SerializeAsString();
std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -658,7 +665,7 @@ void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<floa
} }
} }


void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens, int cmd,
void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd,
int64_t priority) { int64_t priority) {
MS_EXCEPTION_IF_NULL(vals); MS_EXCEPTION_IF_NULL(vals);
KVMessage kvs; KVMessage kvs;
@@ -933,7 +940,8 @@ void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &pa
std::string kv_data = messages.at(i).second.SerializeAsString(); std::string kv_data = messages.at(i).second.SerializeAsString();


std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;
@@ -959,7 +967,8 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa
std::string kv_data = messages.at(i).second.SerializeAsString(); std::string kv_data = messages.at(i).second.SerializeAsString();


std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]); std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
int ret = memcpy_s(res.get(), kv_data.length(), kv_data.data(), kv_data.length());
size_t dest_size = kv_data.length();
int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
if (ret != 0) { if (ret != 0) {
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return; return;


+ 3
- 3
mindspore/ccsrc/ps/worker.h View File

@@ -92,7 +92,7 @@ class Worker {
void AddKeyByHashMod(const Key &key); void AddKeyByHashMod(const Key &key);
void InitPSOptimId(const size_t param_key); void InitPSOptimId(const size_t param_key);
void InitPSOptimInputShapes(const size_t key); void InitPSOptimInputShapes(const size_t key);
void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size);
void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size);
bool IsReadyForPush(const Key &key); bool IsReadyForPush(const Key &key);
bool IsReadyForPull(const Key &key); bool IsReadyForPull(const Key &key);
void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
@@ -105,8 +105,8 @@ class Worker {
int command = 0, int64_t priority = 0); int command = 0, int64_t priority = 0);
void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0,
int64_t priority = 0);
void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr,
int cmd = 0, int64_t priority = 0);


void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
const std::map<int64_t, int64_t> &attrs); const std::map<int64_t, int64_t> &attrs);


Loading…
Cancel
Save